# app.py — Chichewa text-to-speech demo (Gradio)
# Source: Chithekitale's Hugging Face Space, commit 7bb4523
import gradio as gr
import numpy as np
import torch
import tempfile
import os
from scipy.io.wavfile import write
from transformers import (
SpeechT5Processor,
SpeechT5ForTextToSpeech,
SpeechT5HifiGan
)
# =========================
# Model loading
# =========================
# Fine-tuned Chichewa SpeechT5 checkpoint; the text processor and the
# HiFi-GAN vocoder come from the original Microsoft releases.
checkpoint = "Chithekitale/Chichewa_tts_v2"
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
# Make all keys consistent
# Map UI speaker keys to precomputed speaker-embedding (.npy) files
# shipped alongside the app.
# NOTE(review): "SPK1" -> speaker_2.npy and "SPK2" -> speaker_1.npy
# looks swapped — confirm this mapping is intentional.
speaker_embeddings = {
    "SPK1": "speaker_2.npy",
    "SPK2": "speaker_1.npy",
    "SPK3": "cmu_us_ksp_arctic-wav-arctic_b0087.npy",
    "SPK4": "cmu_us_rms_arctic-wav-arctic_b0353.npy",
    "SPK5": "cmu_us_slt_arctic-wav-arctic_a0508.npy",
}
# Labels shown in the speaker radio control; the first whitespace-
# delimited token of each label must be a key of speaker_embeddings.
SPEAKER_CHOICES = [
    "SPK1 (female)",
    "SPK2 (male)",
    "SPK3 (male)",
    "SPK4 (male)",
    "SPK5 (female)"
]
# (text, speaker label) pairs fed to the gr.Examples widget.
EXAMPLES = [
    ["Ndapita, koma ndibweranso pompano.", "SPK1 (female)"],
    ["Koma apapa zikuoneka kuti ziyenda bwino.", "SPK2 (male)"],
    ["Ineyo ndikuona kuti sizizasithanso.", "SPK3 (male)"],
    ["Mwina kusogolo kuno anthu ena azalimba mtima, koma panopana ndakaika.", "SPK4 (male)"],
    ["Simungasankhe munthu oti bola linamukana.", "SPK5 (female)"],
    ["Kodi chimanga panopa chikugulisidwa zingati, kapena nanunso simukudziwa?", "SPK5 (female)"],
]
# SpeechT5 / HiFi-GAN generate audio at 16 kHz.
SAMPLE_RATE = 16000
# =========================
# Helpers
# =========================
def get_speaker_key(speaker_label: str) -> str:
    """Return the bare speaker ID from a UI label, e.g. "SPK1 (female)" -> "SPK1".

    The label format is "<KEY> (<gender>)"; the key is the first
    whitespace-delimited token.
    """
    tokens = speaker_label.split()
    return tokens[0]
def load_speaker_embedding(speaker: str) -> np.ndarray:
    """Load the speaker embedding for a UI speaker label.

    Returns a float32 vector of shape (512,).

    Raises:
        ValueError: if the speaker key is unknown or the embedding does
            not reduce to a (512,) vector.
        FileNotFoundError: if the .npy file cannot be loaded.
    """
    key = get_speaker_key(speaker)
    if key not in speaker_embeddings:
        raise ValueError(f"Unknown speaker key: {key}")
    path = speaker_embeddings[key]
    try:
        embedding = np.load(path).astype(np.float32)
    except Exception as e:
        raise FileNotFoundError(
            f"Could not load speaker embedding file: {path}. Error: {e}"
        )
    # Some embedding files hold one vector per frame; average them down
    # to a single speaker vector.
    if embedding.ndim == 2:
        embedding = embedding.mean(axis=0)
    embedding = np.squeeze(embedding)
    # SpeechT5 expects a 512-dim x-vector.
    if embedding.shape != (512,):
        raise ValueError(
            f"Unexpected speaker embedding shape after processing: "
            f"{embedding.shape}. Expected (512,)"
        )
    return embedding
def save_audio_to_wav(audio: np.ndarray, sample_rate: int = SAMPLE_RATE) -> str:
    """Write int16 *audio* samples to a fresh temporary WAV file.

    The file is not deleted automatically; the returned path is handed
    to the Gradio download widget.
    """
    fd, wav_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # scipy reopens the path itself; release our handle
    write(wav_path, sample_rate, audio)
    return wav_path
# =========================
# Inference
# =========================
def predict(text, speaker):
    """Synthesize Chichewa speech for *text* using the selected speaker.

    Returns a 3-tuple matching the Gradio outputs:
    ((sample_rate, int16 waveform), wav_file_path, status_message),
    or (None, None, message) on empty input or any error.
    """
    try:
        if not text or not text.strip():
            return None, None, "Please enter some Chichewa text."
        tokenized = processor(text=text, return_tensors="pt")
        # Truncate to the maximum text length the model supports.
        input_ids = tokenized["input_ids"][..., :model.config.max_text_positions]
        embedding = load_speaker_embedding(speaker)
        # Add a batch dimension: (512,) -> (1, 512).
        embedding_batch = torch.tensor(embedding, dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            waveform = model.generate_speech(
                input_ids,
                embedding_batch,
                vocoder=vocoder
            )
        waveform = waveform.cpu().numpy()
        # Peak-normalize so the int16 conversion below cannot clip.
        peak = np.max(np.abs(waveform))
        if peak > 0:
            waveform = waveform / peak
        pcm = (waveform * 32767).astype(np.int16)
        # Persist a WAV copy for the download widget.
        wav_path = save_audio_to_wav(pcm, SAMPLE_RATE)
        status = f"Generated speech successfully using speaker: {speaker}"
        return (SAMPLE_RATE, pcm), wav_path, status
    except Exception as e:
        # Surface the failure in the status box instead of crashing the UI.
        return None, None, f"Error during generation: {str(e)}"
def clear_all():
    """Reset the UI: empty text, default speaker, cleared audio/file, 'Ready.' status."""
    default_speaker = "SPK1 (female)"
    return "", default_speaker, None, None, "Ready."
# =========================
# UI
# =========================
# CSS injected into the Gradio Blocks app: caps and centers the layout
# and styles the hero header / note text used in the gr.HTML block below.
custom_css = """
.gradio-container {
max-width: 1100px !important;
margin: 0 auto;
}
.hero {
text-align: center;
padding: 10px 0 0 0;
}
.section-note {
font-size: 0.95rem;
opacity: 0.9;
}
"""
# Build the Gradio UI: text input + speaker choice on the left,
# generated audio + download file on the right, plus example prompts.
with gr.Blocks(css=custom_css, title="Chichewa Speech Synthesis Demo") as demo:
    # Hero header. Fixed typo: "Intergrated" -> "Integrated".
    gr.HTML(
        """
        <div class="hero">
            <h1>Rule-Integrated Chichewa Speech Synthesis</h1>
            <p class="section-note">
                Enter Chichewa text, choose a speaker voice, and generate speech audio.
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Column(scale=5):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Type Chichewa text here...",
                lines=6
            )
            speaker_input = gr.Radio(
                label="Speaker Voice",
                choices=SPEAKER_CHOICES,
                value="SPK1 (female)"
            )
            with gr.Row():
                generate_btn = gr.Button("Generate Speech", variant="primary")
                clear_btn = gr.Button("Clear")
            status_box = gr.Textbox(
                label="System Status",
                value="Ready.",
                interactive=False
            )
        with gr.Column(scale=5):
            # predict() returns (sample_rate, int16 array), so type="numpy".
            audio_output = gr.Audio(
                label="Generated Speech",
                type="numpy",
                autoplay=False
            )
            download_file = gr.File(
                label="Download Audio File"
            )
    gr.Markdown("### Example Inputs")
    gr.Examples(
        examples=EXAMPLES,
        inputs=[text_input, speaker_input]
    )
    # Wire the buttons to the inference and reset handlers.
    generate_btn.click(
        fn=predict,
        inputs=[text_input, speaker_input],
        outputs=[audio_output, download_file, status_box],
        show_progress="full"
    )
    clear_btn.click(
        fn=clear_all,
        inputs=[],
        outputs=[text_input, speaker_input, audio_output, download_file, status_box]
    )

# Launch only when run as a script, so the module stays importable
# (e.g. for testing) without starting a server.
if __name__ == "__main__":
    demo.launch()