Nekochu commited on
Commit
5c2e4e7
·
1 Parent(s): 882ed5c

add LoRA adapter dropdown to inference UI

Browse files
Files changed (1) hide show
  1. app.py +13 -2
app.py CHANGED
@@ -234,8 +234,16 @@ def gradio_main():
234
  import gradio as gr
235
 
236
  # -- Generate tab handler --
 
 
 
 
 
 
 
 
237
  def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
238
- steps, progress=gr.Progress(track_tqdm=True)):
239
  if not _server_ok():
240
  return None, "ace-server not running. Check logs."
241
 
@@ -243,6 +251,7 @@ def gradio_main():
243
  lyrics = "[Instrumental]"
244
 
245
  actual_seed = None if seed is None or int(seed) < 0 else int(seed)
 
246
 
247
  progress_map = {
248
  "lm_submit": (0.05, "Submitting LM job..."),
@@ -267,6 +276,7 @@ def gradio_main():
267
  seed=actual_seed,
268
  steps=steps,
269
  output_format="mp3",
 
270
  progress_cb=gr_progress,
271
  )
272
  return audio_path, status
@@ -502,6 +512,7 @@ finally:
502
  value=8, step=1, scale=1,
503
  )
504
  seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
 
505
 
506
  with gr.Row(elem_classes="compact-row"):
507
  gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
@@ -510,7 +521,7 @@ finally:
510
  gen_btn.click(
511
  fn=generate_music,
512
  inputs=[caption, lyrics, instrumental, bpm, duration,
513
- seed, steps],
514
  outputs=[audio_out, status],
515
  api_name="generate",
516
  )
 
234
  import gradio as gr
235
 
236
  # -- Generate tab handler --
237
+ def get_trained_loras():
238
+ loras = ["None (no LoRA)"]
239
+ if os.path.isdir(ADAPTER_DIR):
240
+ for d in os.listdir(ADAPTER_DIR):
241
+ if os.path.isdir(os.path.join(ADAPTER_DIR, d)):
242
+ loras.append(d)
243
+ return loras
244
+
245
  def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
246
+ steps, lora_select, progress=gr.Progress(track_tqdm=True)):
247
  if not _server_ok():
248
  return None, "ace-server not running. Check logs."
249
 
 
251
  lyrics = "[Instrumental]"
252
 
253
  actual_seed = None if seed is None or int(seed) < 0 else int(seed)
254
+ adapter = None if lora_select == "None (no LoRA)" else lora_select
255
 
256
  progress_map = {
257
  "lm_submit": (0.05, "Submitting LM job..."),
 
276
  seed=actual_seed,
277
  steps=steps,
278
  output_format="mp3",
279
+ adapter=adapter,
280
  progress_cb=gr_progress,
281
  )
282
  return audio_path, status
 
512
  value=8, step=1, scale=1,
513
  )
514
  seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
515
+ lora_select = gr.Dropdown(label="LoRA", choices=get_trained_loras(), value="None (no LoRA)", scale=1)
516
 
517
  with gr.Row(elem_classes="compact-row"):
518
  gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
 
521
  gen_btn.click(
522
  fn=generate_music,
523
  inputs=[caption, lyrics, instrumental, bpm, duration,
524
+ seed, steps, lora_select],
525
  outputs=[audio_out, status],
526
  api_name="generate",
527
  )