Spaces:
Running
Running
run training as detached subprocess to survive Gradio session timeout
Browse files
app.py
CHANGED
|
@@ -285,15 +285,21 @@ def gradio_main():
|
|
| 285 |
lines.append(json.dumps(props, indent=2))
|
| 286 |
return "\n".join(lines)
|
| 287 |
|
| 288 |
-
# -- Training --
|
|
|
|
|
|
|
| 289 |
def train_lora(audio_files, lora_name, epochs, lr, rank,
|
| 290 |
progress=gr.Progress(track_tqdm=True)):
|
| 291 |
-
import shutil
|
| 292 |
-
import gc
|
| 293 |
|
| 294 |
if not audio_files:
|
| 295 |
return "No audio files uploaded."
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
lora_name = (lora_name or "").strip() or "my-lora"
|
| 298 |
epochs = max(1, min(int(epochs), 10))
|
| 299 |
lr = float(lr)
|
|
@@ -301,152 +307,126 @@ def gradio_main():
|
|
| 301 |
|
| 302 |
output_dir = os.path.join(ADAPTER_DIR, lora_name)
|
| 303 |
os.makedirs(output_dir, exist_ok=True)
|
| 304 |
-
|
| 305 |
audio_dir = os.path.join(output_dir, "audio_input")
|
| 306 |
os.makedirs(audio_dir, exist_ok=True)
|
| 307 |
for f in audio_files:
|
| 308 |
src = f.name if hasattr(f, "name") else str(f)
|
| 309 |
shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
|
| 310 |
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
|
|
|
| 384 |
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
)
|
| 391 |
-
model = model.bfloat16()
|
| 392 |
-
|
| 393 |
-
adapter_cfg = LoRAConfigV2(r=rank, alpha=rank, dropout=0.0)
|
| 394 |
-
train_cfg = TrainingConfigV2(
|
| 395 |
-
checkpoint_dir=ACE_CHECKPOINT_DIR,
|
| 396 |
-
model_variant="turbo",
|
| 397 |
-
dataset_dir=tensor_dir,
|
| 398 |
-
output_dir=output_dir,
|
| 399 |
-
max_epochs=epochs,
|
| 400 |
-
batch_size=1,
|
| 401 |
-
learning_rate=lr,
|
| 402 |
-
device="cpu",
|
| 403 |
-
precision="bfloat16",
|
| 404 |
-
seed=42,
|
| 405 |
-
num_workers=0,
|
| 406 |
-
pin_memory=False,
|
| 407 |
-
)
|
| 408 |
-
|
| 409 |
-
trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)
|
| 410 |
-
|
| 411 |
-
step_count = 0
|
| 412 |
-
last_loss = 0.0
|
| 413 |
-
for update in trainer.train():
|
| 414 |
-
if hasattr(update, "step"):
|
| 415 |
-
step_count = update.step
|
| 416 |
-
last_loss = update.loss
|
| 417 |
-
elif isinstance(update, tuple) and len(update) >= 2:
|
| 418 |
-
step_count = update[0]
|
| 419 |
-
last_loss = update[1]
|
| 420 |
-
if step_count % 5 == 0:
|
| 421 |
-
log_lines.append(f" Step {step_count}: loss={last_loss:.4f}")
|
| 422 |
-
pct = 0.30 + 0.65 * min(step_count / max(epochs * processed, 1), 1.0)
|
| 423 |
-
progress(pct, desc=f"Step {step_count}, loss={last_loss:.4f}")
|
| 424 |
-
|
| 425 |
-
_log(f"Training complete! Final: step {step_count}, loss={last_loss:.4f}")
|
| 426 |
-
_log(f"LoRA saved to: {output_dir}")
|
| 427 |
-
|
| 428 |
-
del model, trainer
|
| 429 |
-
gc.collect()
|
| 430 |
-
|
| 431 |
-
except ImportError as e:
|
| 432 |
-
_log(f"Import error: {e}")
|
| 433 |
-
_log(f"Check ACE-Step source at {ACE_SOURCE_DIR}")
|
| 434 |
-
import traceback
|
| 435 |
-
log_lines.append(traceback.format_exc())
|
| 436 |
-
except Exception as e:
|
| 437 |
-
import traceback
|
| 438 |
-
_log(f"ERROR: {e}")
|
| 439 |
-
log_lines.append(traceback.format_exc())
|
| 440 |
-
finally:
|
| 441 |
-
_log("Restarting ace-server...")
|
| 442 |
-
import subprocess
|
| 443 |
-
subprocess.Popen([
|
| 444 |
-
"/app/ace-server", "--host", "127.0.0.1", "--port", "8085",
|
| 445 |
-
"--models", "/app/models", "--adapters", "/app/adapters",
|
| 446 |
-
"--max-batch", "1",
|
| 447 |
-
], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
| 448 |
-
|
| 449 |
-
return "\n".join(log_lines)
|
| 450 |
|
| 451 |
# -- Build UI --
|
| 452 |
CSS = """
|
|
@@ -548,11 +528,13 @@ def gradio_main():
|
|
| 548 |
lr = gr.Number(label="Learning Rate", value=1e-4)
|
| 549 |
rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
|
| 550 |
|
| 551 |
-
|
|
|
|
|
|
|
| 552 |
train_log = gr.Textbox(
|
| 553 |
label="Training Log",
|
| 554 |
interactive=False,
|
| 555 |
-
lines=
|
| 556 |
elem_classes="status-box",
|
| 557 |
)
|
| 558 |
|
|
@@ -562,6 +544,12 @@ def gradio_main():
|
|
| 562 |
outputs=[train_log],
|
| 563 |
api_name="train_lora",
|
| 564 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 565 |
|
| 566 |
demo.launch(
|
| 567 |
server_name="0.0.0.0",
|
|
|
|
| 285 |
lines.append(json.dumps(props, indent=2))
|
| 286 |
return "\n".join(lines)
|
| 287 |
|
| 288 |
+
# -- Training (runs as detached subprocess to survive Gradio session timeout) --
|
| 289 |
+
TRAIN_LOG = "/app/outputs/train.log"
|
| 290 |
+
|
| 291 |
def train_lora(audio_files, lora_name, epochs, lr, rank,
|
| 292 |
progress=gr.Progress(track_tqdm=True)):
|
| 293 |
+
import shutil, subprocess
|
|
|
|
| 294 |
|
| 295 |
if not audio_files:
|
| 296 |
return "No audio files uploaded."
|
| 297 |
|
| 298 |
+
if os.path.exists(TRAIN_LOG):
|
| 299 |
+
last_line = open(TRAIN_LOG).readlines()[-1] if os.path.getsize(TRAIN_LOG) > 0 else ""
|
| 300 |
+
if "DONE" not in last_line and "ERROR" not in last_line and last_line.strip():
|
| 301 |
+
return f"Training already in progress. Click 'Check Log' to monitor.\n\nLast: {last_line.strip()}"
|
| 302 |
+
|
| 303 |
lora_name = (lora_name or "").strip() or "my-lora"
|
| 304 |
epochs = max(1, min(int(epochs), 10))
|
| 305 |
lr = float(lr)
|
|
|
|
| 307 |
|
| 308 |
output_dir = os.path.join(ADAPTER_DIR, lora_name)
|
| 309 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
| 310 |
audio_dir = os.path.join(output_dir, "audio_input")
|
| 311 |
os.makedirs(audio_dir, exist_ok=True)
|
| 312 |
for f in audio_files:
|
| 313 |
src = f.name if hasattr(f, "name") else str(f)
|
| 314 |
shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
|
| 315 |
|
| 316 |
+
train_script = f"""
|
| 317 |
+
import os, sys, time, gc
|
| 318 |
+
sys.path.insert(0, "{ACE_SOURCE_DIR}")
|
| 319 |
+
os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "1"
|
| 320 |
+
|
| 321 |
+
LOG = "{TRAIN_LOG}"
|
| 322 |
+
def log(msg):
|
| 323 |
+
print(f"[train] {{msg}}", flush=True)
|
| 324 |
+
with open(LOG, "a") as f:
|
| 325 |
+
f.write(msg + "\\n")
|
| 326 |
+
f.flush()
|
| 327 |
+
|
| 328 |
+
open(LOG, "w").close()
|
| 329 |
+
log("LoRA Training: '{lora_name}' | files={len(audio_files)} | epochs={epochs} lr={lr} rank={rank}")
|
| 330 |
+
|
| 331 |
+
import subprocess
|
| 332 |
+
log("Stopping ace-server...")
|
| 333 |
+
subprocess.run(["pkill", "-f", "ace-server"], stderr=subprocess.DEVNULL)
|
| 334 |
+
time.sleep(2)
|
| 335 |
+
gc.collect()
|
| 336 |
+
|
| 337 |
+
try:
|
| 338 |
+
import torchaudio
|
| 339 |
+
_orig = torchaudio.load
|
| 340 |
+
def _sf(p, *a, **kw):
|
| 341 |
+
kw.setdefault("backend", "soundfile")
|
| 342 |
+
return _orig(p, *a, **kw)
|
| 343 |
+
torchaudio.load = _sf
|
| 344 |
+
|
| 345 |
+
log("[Step 1/2] Preprocessing audio...")
|
| 346 |
+
from acestep.training_v2.preprocess import preprocess_audio_files
|
| 347 |
+
result = preprocess_audio_files(
|
| 348 |
+
audio_dir="{audio_dir}",
|
| 349 |
+
output_dir="{output_dir}/preprocessed_tensors",
|
| 350 |
+
checkpoint_dir="{ACE_CHECKPOINT_DIR}",
|
| 351 |
+
variant="turbo", max_duration=60.0,
|
| 352 |
+
device="cpu", precision="bfloat16",
|
| 353 |
+
)
|
| 354 |
+
processed = result.get("processed", 0)
|
| 355 |
+
failed = result.get("failed", 0)
|
| 356 |
+
log(f" Preprocessed: {{processed}}/{{result.get('total',0)}} (failed: {{failed}})")
|
| 357 |
+
if processed == 0:
|
| 358 |
+
log("ERROR: No files preprocessed. DONE")
|
| 359 |
+
raise SystemExit(1)
|
| 360 |
+
|
| 361 |
+
gc.collect()
|
| 362 |
+
log("[Step 2/2] Training LoRA...")
|
| 363 |
+
from acestep.training_v2.model_loader import load_decoder_for_training
|
| 364 |
+
from acestep.training_v2.trainer_fixed import FixedLoRATrainer
|
| 365 |
+
from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
|
| 366 |
+
|
| 367 |
+
model = load_decoder_for_training(
|
| 368 |
+
checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
|
| 369 |
+
device="cpu", precision="bfloat16",
|
| 370 |
+
).bfloat16()
|
| 371 |
+
|
| 372 |
+
trainer = FixedLoRATrainer(model,
|
| 373 |
+
LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),
|
| 374 |
+
TrainingConfigV2(
|
| 375 |
+
checkpoint_dir="{ACE_CHECKPOINT_DIR}", model_variant="turbo",
|
| 376 |
+
dataset_dir="{output_dir}/preprocessed_tensors",
|
| 377 |
+
output_dir="{output_dir}",
|
| 378 |
+
max_epochs={epochs}, batch_size=1, learning_rate={lr},
|
| 379 |
+
device="cpu", precision="bfloat16", seed=42,
|
| 380 |
+
num_workers=0, pin_memory=False,
|
| 381 |
+
))
|
| 382 |
+
|
| 383 |
+
step_count, last_loss = 0, 0.0
|
| 384 |
+
for update in trainer.train():
|
| 385 |
+
if hasattr(update, "step"):
|
| 386 |
+
step_count, last_loss = update.step, update.loss
|
| 387 |
+
elif isinstance(update, tuple) and len(update) >= 2:
|
| 388 |
+
step_count, last_loss = update[0], update[1]
|
| 389 |
+
if step_count % 5 == 0:
|
| 390 |
+
log(f" Step {{step_count}}: loss={{last_loss:.4f}}")
|
| 391 |
+
|
| 392 |
+
log(f"Training complete! step={{step_count}} loss={{last_loss:.4f}}")
|
| 393 |
+
log(f"LoRA saved to: {output_dir}")
|
| 394 |
+
del model, trainer
|
| 395 |
+
gc.collect()
|
| 396 |
+
log("DONE")
|
| 397 |
+
|
| 398 |
+
except Exception as e:
|
| 399 |
+
import traceback
|
| 400 |
+
log(f"ERROR: {{e}}")
|
| 401 |
+
log(traceback.format_exc())
|
| 402 |
+
log("DONE")
|
| 403 |
+
finally:
|
| 404 |
+
log("Restarting ace-server...")
|
| 405 |
+
subprocess.Popen(["/app/ace-server", "--host", "127.0.0.1", "--port", "8085",
|
| 406 |
+
"--models", "/app/models", "--adapters", "/app/adapters", "--max-batch", "1"],
|
| 407 |
+
stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
| 408 |
+
"""
|
| 409 |
+
script_path = os.path.join(output_dir, "_train.py")
|
| 410 |
+
with open(script_path, "w") as f:
|
| 411 |
+
f.write(train_script)
|
| 412 |
+
|
| 413 |
+
subprocess.Popen(
|
| 414 |
+
["python3", script_path],
|
| 415 |
+
stdout=open("/dev/null", "w"),
|
| 416 |
+
stderr=open("/dev/null", "w"),
|
| 417 |
+
start_new_session=True,
|
| 418 |
+
)
|
| 419 |
|
| 420 |
+
return (f"Training started in background for '{lora_name}'.\n"
|
| 421 |
+
f"Audio: {len(audio_files)} files, Epochs: {epochs}, Rank: {rank}\n\n"
|
| 422 |
+
f"Click 'Check Log' to monitor progress.\n"
|
| 423 |
+
f"Inference will be unavailable until training completes (ace-server stopped).")
|
| 424 |
|
| 425 |
+
def check_train_log():
|
| 426 |
+
if not os.path.exists(TRAIN_LOG):
|
| 427 |
+
return "No training log found."
|
| 428 |
+
with open(TRAIN_LOG) as f:
|
| 429 |
+
return f.read() or "Log is empty."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
|
| 431 |
# -- Build UI --
|
| 432 |
CSS = """
|
|
|
|
| 528 |
lr = gr.Number(label="Learning Rate", value=1e-4)
|
| 529 |
rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
|
| 530 |
|
| 531 |
+
with gr.Row(elem_classes="compact-row"):
|
| 532 |
+
train_btn = gr.Button("Train", variant="primary", scale=2)
|
| 533 |
+
log_btn = gr.Button("Check Log", scale=1)
|
| 534 |
train_log = gr.Textbox(
|
| 535 |
label="Training Log",
|
| 536 |
interactive=False,
|
| 537 |
+
lines=12,
|
| 538 |
elem_classes="status-box",
|
| 539 |
)
|
| 540 |
|
|
|
|
| 544 |
outputs=[train_log],
|
| 545 |
api_name="train_lora",
|
| 546 |
)
|
| 547 |
+
log_btn.click(
|
| 548 |
+
fn=check_train_log,
|
| 549 |
+
inputs=[],
|
| 550 |
+
outputs=[train_log],
|
| 551 |
+
api_name="check_train_log",
|
| 552 |
+
)
|
| 553 |
|
| 554 |
demo.launch(
|
| 555 |
server_name="0.0.0.0",
|