Spaces:
Running
Running
fix adapter save path, smart LM fallback, compact training UI, remove Server Status
Browse files- app.py +22 -35
- train_engine.py +6 -9
app.py
CHANGED
|
@@ -671,6 +671,9 @@ def gradio_main():
|
|
| 671 |
full_path, timeout=600,
|
| 672 |
cancel_check=lambda: _training_cancel.is_set(),
|
| 673 |
)
|
|
|
|
|
|
|
|
|
|
| 674 |
if caption_data:
|
| 675 |
bpm_s = caption_data.get("bpm", "?")
|
| 676 |
key_s = caption_data.get("keyscale", caption_data.get("key", "?"))
|
|
@@ -685,7 +688,7 @@ def gradio_main():
|
|
| 685 |
fallback = {"caption": "", "bpm": bpm_val, "key": "", "signature": "", "lyrics": ""}
|
| 686 |
with open(sidecar_json, "w") as cj:
|
| 687 |
json.dump(fallback, cj)
|
| 688 |
-
_log(f" {audio_fname}: librosa
|
| 689 |
except Exception as cap_exc:
|
| 690 |
_log(f" {audio_fname}: caption failed: {cap_exc}")
|
| 691 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
|
@@ -921,9 +924,7 @@ def gradio_main():
|
|
| 921 |
value=DEFAULT_LM, scale=1,
|
| 922 |
)
|
| 923 |
|
| 924 |
-
|
| 925 |
-
gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
|
| 926 |
-
status_btn = gr.Button("Server Status", scale=1)
|
| 927 |
|
| 928 |
gen_btn.click(
|
| 929 |
fn=generate_music,
|
|
@@ -933,31 +934,23 @@ def gradio_main():
|
|
| 933 |
api_name="generate",
|
| 934 |
)
|
| 935 |
|
| 936 |
-
status_btn.click(
|
| 937 |
-
fn=get_server_status,
|
| 938 |
-
inputs=[],
|
| 939 |
-
outputs=[status],
|
| 940 |
-
api_name="server_status",
|
| 941 |
-
)
|
| 942 |
-
|
| 943 |
# ============================================================
|
| 944 |
# Tab 2: Train LoRA
|
| 945 |
# ============================================================
|
| 946 |
with gr.Tab("Train LoRA"):
|
| 947 |
-
gr.Markdown(
|
| 948 |
-
"### LoRA Training\n"
|
| 949 |
-
"Fine-tune ACE-Step on your audio. "
|
| 950 |
-
"CPU training is slow -- ace-server stops during training."
|
| 951 |
-
)
|
| 952 |
-
|
| 953 |
with gr.Row(elem_classes="compact-row"):
|
| 954 |
-
with gr.Column(scale=
|
| 955 |
-
|
| 956 |
-
label="Training
|
| 957 |
-
|
| 958 |
-
|
|
|
|
| 959 |
)
|
| 960 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 961 |
lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
|
| 962 |
train_epochs = gr.Slider(
|
| 963 |
label="Epochs", minimum=1, maximum=1000,
|
|
@@ -968,18 +961,12 @@ def gradio_main():
|
|
| 968 |
label="Rank (r)", minimum=1, maximum=128,
|
| 969 |
value=32, step=1,
|
| 970 |
)
|
| 971 |
-
|
| 972 |
-
|
| 973 |
-
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
| 977 |
-
train_log = gr.Textbox(
|
| 978 |
-
label="Training Log",
|
| 979 |
-
interactive=False,
|
| 980 |
-
lines=10,
|
| 981 |
-
elem_classes="status-box",
|
| 982 |
-
)
|
| 983 |
|
| 984 |
# Button swap on click (separate handler, like rvc-beatrice)
|
| 985 |
# This fires immediately so user sees Cancel even if training
|
|
|
|
| 671 |
full_path, timeout=600,
|
| 672 |
cancel_check=lambda: _training_cancel.is_set(),
|
| 673 |
)
|
| 674 |
+
if not caption_data:
|
| 675 |
+
use_understand = False
|
| 676 |
+
_log(f" {audio_fname}: GGUF LM too slow, switching to librosa for remaining files")
|
| 677 |
if caption_data:
|
| 678 |
bpm_s = caption_data.get("bpm", "?")
|
| 679 |
key_s = caption_data.get("keyscale", caption_data.get("key", "?"))
|
|
|
|
| 688 |
fallback = {"caption": "", "bpm": bpm_val, "key": "", "signature": "", "lyrics": ""}
|
| 689 |
with open(sidecar_json, "w") as cj:
|
| 690 |
json.dump(fallback, cj)
|
| 691 |
+
_log(f" {audio_fname}: librosa BPM={bpm_val}")
|
| 692 |
except Exception as cap_exc:
|
| 693 |
_log(f" {audio_fname}: caption failed: {cap_exc}")
|
| 694 |
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
|
|
|
|
| 924 |
value=DEFAULT_LM, scale=1,
|
| 925 |
)
|
| 926 |
|
| 927 |
+
gen_btn = gr.Button("Generate Music", variant="primary")
|
|
|
|
|
|
|
| 928 |
|
| 929 |
gen_btn.click(
|
| 930 |
fn=generate_music,
|
|
|
|
| 934 |
api_name="generate",
|
| 935 |
)
|
| 936 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 937 |
# ============================================================
|
| 938 |
# Tab 2: Train LoRA
|
| 939 |
# ============================================================
|
| 940 |
with gr.Tab("Train LoRA"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
with gr.Row(elem_classes="compact-row"):
|
| 942 |
+
with gr.Column(scale=3):
|
| 943 |
+
train_log = gr.Textbox(
|
| 944 |
+
label="Training Log",
|
| 945 |
+
interactive=False,
|
| 946 |
+
lines=12,
|
| 947 |
+
elem_classes="status-box",
|
| 948 |
)
|
| 949 |
+
train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
|
| 950 |
+
with gr.Column(scale=2):
|
| 951 |
+
with gr.Row(elem_classes="compact-row"):
|
| 952 |
+
train_btn = gr.Button("Train", variant="primary", scale=2)
|
| 953 |
+
cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
|
| 954 |
lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
|
| 955 |
train_epochs = gr.Slider(
|
| 956 |
label="Epochs", minimum=1, maximum=1000,
|
|
|
|
| 961 |
label="Rank (r)", minimum=1, maximum=128,
|
| 962 |
value=32, step=1,
|
| 963 |
)
|
| 964 |
+
train_audio = gr.File(
|
| 965 |
+
label="Training Audio (optional caption .txt)",
|
| 966 |
+
file_count="multiple",
|
| 967 |
+
file_types=["audio", ".txt", ".json"],
|
| 968 |
+
height=120,
|
| 969 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 970 |
|
| 971 |
# Button swap on click (separate handler, like rvc-beatrice)
|
| 972 |
# This fires immediately so user sees Cancel even if training
|
train_engine.py
CHANGED
|
@@ -2539,10 +2539,9 @@ def train_lora_generator(
|
|
| 2539 |
if _training_cancel.is_set():
|
| 2540 |
_training_cancel.clear()
|
| 2541 |
if epoch > start_epoch:
|
| 2542 |
-
early_path = str(out_path / "early_exit")
|
| 2543 |
model.decoder.eval()
|
| 2544 |
-
save_lora_adapter(model,
|
| 2545 |
-
yield f"[OK] Cancelled at epoch {epoch + 1},
|
| 2546 |
else:
|
| 2547 |
yield f"[CANCELLED] Stopped before any epoch completed"
|
| 2548 |
yield "[DONE]"
|
|
@@ -2554,10 +2553,9 @@ def train_lora_generator(
|
|
| 2554 |
# Timeout check
|
| 2555 |
elapsed = time.time() - train_start
|
| 2556 |
if elapsed > MAX_TRAINING_TIME:
|
| 2557 |
-
early_path = str(out_path / "timeout_exit")
|
| 2558 |
model.decoder.eval()
|
| 2559 |
-
save_lora_adapter(model,
|
| 2560 |
-
yield f"[WARN] Training timed out after {int(elapsed)}s,
|
| 2561 |
yield "[DONE]"
|
| 2562 |
_cuda_sync(device)
|
| 2563 |
unload_models(model)
|
|
@@ -2721,11 +2719,10 @@ def train_lora_generator(
|
|
| 2721 |
f"Loss: {avg_epoch_loss:.4f}{best_str}"
|
| 2722 |
)
|
| 2723 |
|
| 2724 |
-
# Save best
|
| 2725 |
if is_best and epoch + 1 >= 10:
|
| 2726 |
-
best_path = str(out_path / "best")
|
| 2727 |
model.decoder.eval()
|
| 2728 |
-
save_lora_adapter(model,
|
| 2729 |
model.decoder.train()
|
| 2730 |
yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})"
|
| 2731 |
|
|
|
|
| 2539 |
if _training_cancel.is_set():
|
| 2540 |
_training_cancel.clear()
|
| 2541 |
if epoch > start_epoch:
|
|
|
|
| 2542 |
model.decoder.eval()
|
| 2543 |
+
save_lora_adapter(model, str(out_path))
|
| 2544 |
+
yield f"[OK] Cancelled at epoch {epoch + 1}, adapter saved"
|
| 2545 |
else:
|
| 2546 |
yield f"[CANCELLED] Stopped before any epoch completed"
|
| 2547 |
yield "[DONE]"
|
|
|
|
| 2553 |
# Timeout check
|
| 2554 |
elapsed = time.time() - train_start
|
| 2555 |
if elapsed > MAX_TRAINING_TIME:
|
|
|
|
| 2556 |
model.decoder.eval()
|
| 2557 |
+
save_lora_adapter(model, str(out_path))
|
| 2558 |
+
yield f"[WARN] Training timed out after {int(elapsed)}s, adapter saved"
|
| 2559 |
yield "[DONE]"
|
| 2560 |
_cuda_sync(device)
|
| 2561 |
unload_models(model)
|
|
|
|
| 2719 |
f"Loss: {avg_epoch_loss:.4f}{best_str}"
|
| 2720 |
)
|
| 2721 |
|
| 2722 |
+
# Save best (directly to output dir so ace-server finds it)
|
| 2723 |
if is_best and epoch + 1 >= 10:
|
|
|
|
| 2724 |
model.decoder.eval()
|
| 2725 |
+
save_lora_adapter(model, str(out_path))
|
| 2726 |
model.decoder.train()
|
| 2727 |
yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})"
|
| 2728 |
|