Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import tempfile | |
| import os | |
| from scipy.io.wavfile import write | |
| from transformers import ( | |
| SpeechT5Processor, | |
| SpeechT5ForTextToSpeech, | |
| SpeechT5HifiGan | |
| ) | |
| # ========================= | |
| # Model loading | |
| # ========================= | |
# Hub repository holding the text-to-speech model weights used by this demo.
checkpoint = "Chithekitale/Chichewa_tts_v2"

# The text processor and the HiFi-GAN vocoder are loaded from the stock
# Microsoft SpeechT5 repositories; only the TTS model comes from `checkpoint`.
# NOTE(review): confirm the checkpoint was trained with the base tokenizer,
# otherwise the processor should be loaded from `checkpoint` as well.
_BASE_PROCESSOR_REPO = "microsoft/speecht5_tts"
_BASE_VOCODER_REPO = "microsoft/speecht5_hifigan"

processor = SpeechT5Processor.from_pretrained(_BASE_PROCESSOR_REPO)
model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained(_BASE_VOCODER_REPO)
# One row per voice: (speaker key, gender tag shown in the UI, embedding file).
# The key is the first word of the UI label, so labels and lookup keys are
# always derived from the same source.
# NOTE(review): SPK1 points at speaker_2.npy and SPK2 at speaker_1.npy --
# confirm this crossover is intentional.
_SPEAKER_TABLE = (
    ("SPK1", "female", "speaker_2.npy"),
    ("SPK2", "male", "speaker_1.npy"),
    ("SPK3", "male", "cmu_us_ksp_arctic-wav-arctic_b0087.npy"),
    ("SPK4", "male", "cmu_us_rms_arctic-wav-arctic_b0353.npy"),
    ("SPK5", "female", "cmu_us_slt_arctic-wav-arctic_a0508.npy"),
)

# Speaker key -> path of the .npy file holding that speaker's embedding.
speaker_embeddings = {key: path for key, _, path in _SPEAKER_TABLE}

# Labels offered by the speaker selector, e.g. "SPK1 (female)".
SPEAKER_CHOICES = [f"{key} ({gender})" for key, gender, _ in _SPEAKER_TABLE]

# Example rows for the UI: [input text, speaker label].
EXAMPLES = [
    [
        "Ndapita, koma ndibweranso pompano.",
        "SPK1 (female)",
    ],
    [
        "Koma apapa zikuoneka kuti ziyenda bwino.",
        "SPK2 (male)",
    ],
    [
        "Ineyo ndikuona kuti sizizasithanso.",
        "SPK3 (male)",
    ],
    [
        "Mwina kusogolo kuno anthu ena azalimba mtima, koma panopana ndakaika.",
        "SPK4 (male)",
    ],
    [
        "Simungasankhe munthu oti bola linamukana.",
        "SPK5 (female)",
    ],
    [
        "Kodi chimanga panopa chikugulisidwa zingati, kapena nanunso simukudziwa?",
        "SPK5 (female)",
    ],
]

# Sample rate used when writing the WAV file and when returning audio to the UI.
SAMPLE_RATE = 16_000
| # ========================= | |
| # Helpers | |
| # ========================= | |
def get_speaker_key(speaker_label: str) -> str:
    """Return the speaker key at the start of a UI speaker label.

    The label is split on whitespace and its first token is returned, so
    ``"SPK1 (female)"`` becomes ``"SPK1"``.

    Args:
        speaker_label: Label such as ``"SPK1 (female)"``.

    Returns:
        The first whitespace-delimited token of ``speaker_label``.

    Raises:
        ValueError: If ``speaker_label`` is ``None``, empty, or contains only
            whitespace (previously a bare ``IndexError``/``AttributeError``).
    """
    # Only the first token is needed, so stop splitting after it.
    tokens = speaker_label.split(maxsplit=1) if speaker_label else []
    if not tokens:
        raise ValueError(f"No speaker selected (got {speaker_label!r}).")
    return tokens[0]
def load_speaker_embedding(speaker: str) -> np.ndarray:
    """Load the embedding vector for the speaker named by a UI label.

    The speaker key is extracted from ``speaker`` with ``get_speaker_key`` and
    looked up in ``speaker_embeddings`` to find the ``.npy`` file to read.

    Args:
        speaker: Speaker label, e.g. ``"SPK1 (female)"``.

    Returns:
        A ``float32`` array of shape ``(512,)``.

    Raises:
        ValueError: If the speaker key is unknown, or the loaded array does
            not reduce to shape ``(512,)``.
        FileNotFoundError: If the embedding file cannot be loaded for any
            reason; the underlying error is chained as ``__cause__``.
    """
    speaker_key = get_speaker_key(speaker)
    if speaker_key not in speaker_embeddings:
        raise ValueError(f"Unknown speaker key: {speaker_key}")

    path = speaker_embeddings[speaker_key]
    try:
        speaker_embedding = np.load(path).astype(np.float32)
    except Exception as e:
        # Keep the historical exception type, but chain the real cause so the
        # original traceback is not lost.
        raise FileNotFoundError(
            f"Could not load speaker embedding file: {path}. Error: {e}"
        ) from e

    # A 2-D file is collapsed to a single vector by averaging over its rows.
    if speaker_embedding.ndim == 2:
        speaker_embedding = speaker_embedding.mean(axis=0)

    # Drop any remaining singleton dimensions before validating the shape.
    speaker_embedding = np.squeeze(speaker_embedding)
    if speaker_embedding.shape != (512,):
        raise ValueError(
            f"Unexpected speaker embedding shape after processing: "
            f"{speaker_embedding.shape}. Expected (512,)"
        )
    return speaker_embedding
def save_audio_to_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Write ``audio`` to a new temporary WAV file and return the file's path.

    Args:
        audio: Samples to write (the caller passes int16 data).
        sample_rate: Sample rate stored in the WAV header.

    Returns:
        Path of the created ``.wav`` file. The file is not deleted
        automatically; it has to outlive this call so it can be served.

    Raises:
        Exception: Whatever ``scipy.io.wavfile.write`` raises; the temporary
            file is removed first so a failed write leaves nothing behind.
    """
    # Reserve a unique path, then close the handle before writing so the path
    # can be reopened by name on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        wav_path = temp_file.name

    try:
        write(wav_path, sample_rate, audio)
    except Exception:
        # Best-effort cleanup: do not orphan an empty or partial file.
        try:
            os.remove(wav_path)
        except OSError:
            pass
        raise
    return wav_path
| # ========================= | |
| # Inference | |
| # ========================= | |
def _to_int16_pcm(waveform: np.ndarray) -> np.ndarray:
    """Peak-normalise ``waveform`` and scale it to int16 samples."""
    peak = np.max(np.abs(waveform))
    # An all-zero signal is left as-is to avoid dividing by zero.
    if peak > 0:
        waveform = waveform / peak
    return (waveform * 32767).astype(np.int16)


def predict(text, speaker):
    """Synthesise speech for ``text`` using the selected speaker voice.

    Args:
        text: Text entered in the UI.
        speaker: Speaker label, e.g. ``"SPK1 (female)"``.

    Returns:
        A tuple ``(audio, wav_path, status)``. On success ``audio`` is
        ``(SAMPLE_RATE, int16 samples)``, ``wav_path`` is the path of a
        temporary WAV file, and ``status`` is a success message (with a note
        when the input had to be truncated). For empty input or any failure
        the first two items are ``None`` and ``status`` explains why.
    """
    # This is the UI boundary: every failure is reported through the status
    # message instead of propagating to the caller.
    try:
        if not text or not text.strip():
            return None, None, "Please enter some Chichewa text."

        inputs = processor(text=text, return_tensors="pt")
        token_ids = inputs["input_ids"]

        # Clip to model.config.max_text_positions, remembering whether
        # anything was cut so the user can be told.
        max_tokens = model.config.max_text_positions
        was_truncated = token_ids.shape[-1] > max_tokens
        input_ids = token_ids[..., :max_tokens]

        # Add a leading batch dimension to the speaker vector.
        speaker_embedding = torch.tensor(
            load_speaker_embedding(speaker), dtype=torch.float32
        ).unsqueeze(0)

        with torch.no_grad():
            speech = model.generate_speech(
                input_ids,
                speaker_embedding,
                vocoder=vocoder,
            )
        speech = speech.cpu().numpy()

        # np.max on an empty array raises a cryptic reduction error; report
        # the real problem instead.
        if speech.size == 0:
            return (
                None,
                None,
                "Error during generation: the model produced no audio for this input.",
            )

        speech = _to_int16_pcm(speech)

        # Save WAV file for downloading
        wav_path = save_audio_to_wav(speech, SAMPLE_RATE)

        status = f"Generated speech successfully using speaker: {speaker}"
        if was_truncated:
            status += f" (input was truncated to {max_tokens} tokens)"
        return (SAMPLE_RATE, speech), wav_path, status
    except Exception as e:
        return None, None, f"Error during generation: {str(e)}"
def clear_all():
    """Return the reset values for the components wired to the Clear button.

    Returns:
        A 5-tuple, in output order: empty input text, the default speaker
        label, ``None`` for the audio player, ``None`` for the download file,
        and the idle status message.
    """
    blank_text = ""
    default_speaker = "SPK1 (female)"
    idle_status = "Ready."
    return blank_text, default_speaker, None, None, idle_status
| # ========================= | |
| # UI | |
| # ========================= | |
# Stylesheet handed to gr.Blocks(css=...). Built from one literal per line so
# each rule can carry its own comment; the resulting string is unchanged.
custom_css = (
    "\n"
    # Cap the app width and centre it.
    ".gradio-container {\n"
    "max-width: 1100px !important;\n"
    "margin: 0 auto;\n"
    "}\n"
    # Centred title banner.
    ".hero {\n"
    "text-align: center;\n"
    "padding: 10px 0 0 0;\n"
    "}\n"
    # Slightly smaller, dimmed explanatory text.
    ".section-note {\n"
    "font-size: 0.95rem;\n"
    "opacity: 0.9;\n"
    "}\n"
)
# Build the demo page: inputs and controls on the left, generated audio on the
# right, example inputs below, then the button wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the components -- confirm it matches the intended layout.
with gr.Blocks(css=custom_css, title="Chichewa Speech Synthesis Demo") as demo:
    # Title banner (heading typo "Intergrated" corrected to "Integrated").
    gr.HTML(
        "\n"
        '<div class="hero">\n'
        "<h1>Rule-Integrated Chichewa Speech Synthesis</h1>\n"
        '<p class="section-note">\n'
        "Enter Chichewa text, choose a speaker voice, and generate speech audio.\n"
        "</p>\n"
        "</div>\n"
    )

    with gr.Row():
        # Left column: text, voice selection, action buttons and status.
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Type Chichewa text here...",
                lines=6,
            )
            speaker_input = gr.Radio(
                label="Speaker Voice",
                choices=SPEAKER_CHOICES,
                value="SPK1 (female)",
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
            status_box = gr.Textbox(
                label="System Status",
                value="Ready.",
                interactive=False,
            )

        # Right column: playback widget and downloadable WAV file.
        with gr.Column(scale=5):
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                autoplay=False,
            )
            download_file = gr.File(
                label="Download Audio File"
            )

    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_input, speaker_input],
    )

    # predict returns (audio, wav_path, status), matching the outputs order.
    generate_btn.click(
        fn=predict,
        inputs=[text_input, speaker_input],
        outputs=[audio_output, download_file, status_box],
        show_progress="full",
    )

    # clear_all returns one reset value per listed output, in this order.
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[text_input, speaker_input, audio_output, download_file, status_box],
    )

# Start the app when the module is executed.
demo.launch()