Spaces:
Running
Running
add LoRA adapter dropdown to inference UI
Browse files
app.py
CHANGED
|
@@ -234,8 +234,16 @@ def gradio_main():
|
|
| 234 |
import gradio as gr
|
| 235 |
|
| 236 |
# -- Generate tab handler --
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
|
| 238 |
-
steps, progress=gr.Progress(track_tqdm=True)):
|
| 239 |
if not _server_ok():
|
| 240 |
return None, "ace-server not running. Check logs."
|
| 241 |
|
|
@@ -243,6 +251,7 @@ def gradio_main():
|
|
| 243 |
lyrics = "[Instrumental]"
|
| 244 |
|
| 245 |
actual_seed = None if seed is None or int(seed) < 0 else int(seed)
|
|
|
|
| 246 |
|
| 247 |
progress_map = {
|
| 248 |
"lm_submit": (0.05, "Submitting LM job..."),
|
|
@@ -267,6 +276,7 @@ def gradio_main():
|
|
| 267 |
seed=actual_seed,
|
| 268 |
steps=steps,
|
| 269 |
output_format="mp3",
|
|
|
|
| 270 |
progress_cb=gr_progress,
|
| 271 |
)
|
| 272 |
return audio_path, status
|
|
@@ -502,6 +512,7 @@ finally:
|
|
| 502 |
value=8, step=1, scale=1,
|
| 503 |
)
|
| 504 |
seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
|
|
|
|
| 505 |
|
| 506 |
with gr.Row(elem_classes="compact-row"):
|
| 507 |
gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
|
|
@@ -510,7 +521,7 @@ finally:
|
|
| 510 |
gen_btn.click(
|
| 511 |
fn=generate_music,
|
| 512 |
inputs=[caption, lyrics, instrumental, bpm, duration,
|
| 513 |
-
seed, steps],
|
| 514 |
outputs=[audio_out, status],
|
| 515 |
api_name="generate",
|
| 516 |
)
|
|
|
|
| 234 |
import gradio as gr
|
| 235 |
|
| 236 |
# -- Generate tab handler --
|
| 237 |
+
def get_trained_loras():
|
| 238 |
+
loras = ["None (no LoRA)"]
|
| 239 |
+
if os.path.isdir(ADAPTER_DIR):
|
| 240 |
+
for d in os.listdir(ADAPTER_DIR):
|
| 241 |
+
if os.path.isdir(os.path.join(ADAPTER_DIR, d)):
|
| 242 |
+
loras.append(d)
|
| 243 |
+
return loras
|
| 244 |
+
|
| 245 |
def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
|
| 246 |
+
steps, lora_select, progress=gr.Progress(track_tqdm=True)):
|
| 247 |
if not _server_ok():
|
| 248 |
return None, "ace-server not running. Check logs."
|
| 249 |
|
|
|
|
| 251 |
lyrics = "[Instrumental]"
|
| 252 |
|
| 253 |
actual_seed = None if seed is None or int(seed) < 0 else int(seed)
|
| 254 |
+
adapter = None if lora_select == "None (no LoRA)" else lora_select
|
| 255 |
|
| 256 |
progress_map = {
|
| 257 |
"lm_submit": (0.05, "Submitting LM job..."),
|
|
|
|
| 276 |
seed=actual_seed,
|
| 277 |
steps=steps,
|
| 278 |
output_format="mp3",
|
| 279 |
+
adapter=adapter,
|
| 280 |
progress_cb=gr_progress,
|
| 281 |
)
|
| 282 |
return audio_path, status
|
|
|
|
| 512 |
value=8, step=1, scale=1,
|
| 513 |
)
|
| 514 |
seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
|
| 515 |
+
lora_select = gr.Dropdown(label="LoRA", choices=get_trained_loras(), value="None (no LoRA)", scale=1)
|
| 516 |
|
| 517 |
with gr.Row(elem_classes="compact-row"):
|
| 518 |
gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
|
|
|
|
| 521 |
gen_btn.click(
|
| 522 |
fn=generate_music,
|
| 523 |
inputs=[caption, lyrics, instrumental, bpm, duration,
|
| 524 |
+
seed, steps, lora_select],
|
| 525 |
outputs=[audio_out, status],
|
| 526 |
api_name="generate",
|
| 527 |
)
|