Side-Step training engine, tested locally on CPU

- train_engine.py: 1347-line standalone LoRA training (ported from Side-Step)
- Adafactor optimizer, 2-pass sequential preprocessing, FA->SDPA->eager fallback
- Generator-based progress (yield per epoch), cancel button, checkpoint saves
- Pin transformers<4.58.0 (meta tensor fix), skip device_map on CPU
- Tested locally: preprocessing 174s + training 51s (2 epochs, 10s audio)

Files changed:
- Dockerfile +1 -1
- app.py +333 -168
- train_engine.py +1353 -0
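
For orientation, the UI drives the new module in two phases: preprocess audio to tensors, then iterate the training generator. A minimal driver sketch outside Gradio, mirroring the exact calls the handler makes in the app.py diff below (the /app paths are assumed from the container layout):

# Minimal two-phase driver sketch (illustrative; mirrors the calls the UI
# handler makes below, with /app paths assumed from the container layout).
from train_engine import preprocess_audio, train_lora_generator

result = preprocess_audio(
    audio_dir="/app/adapters/my-lora/audio_input",
    output_dir="/app/adapters/my-lora/preprocessed_tensors",
    checkpoint_dir="/app/checkpoints",
    device="cpu",
    variant="turbo",
    max_duration=240.0,
    progress_callback=lambda cur, total, desc: print(f"{desc} ({cur}/{total})"),
    cancel_check=lambda: False,
)
print(f"preprocessed {result.get('processed', 0)}/{result.get('total', 0)}")

for msg in train_lora_generator(
    dataset_dir="/app/adapters/my-lora/preprocessed_tensors",
    output_dir="/app/adapters/my-lora",
    checkpoint_dir="/app/checkpoints",
    epochs=2, lr=1e-4, rank=16, alpha=16, dropout=0.0,
    batch_size=1, gradient_accumulation_steps=4, warmup_steps=100,
    weight_decay=0.01, max_grad_norm=1.0, save_every_n_epochs=1,
    seed=42, variant="turbo", device="cpu", log_every=5,
):
    print(msg)                    # one line per logged step/epoch
    if msg.strip() == "[DONE]":
        break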
Dockerfile CHANGED
@@ -72,7 +72,7 @@ RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
 # Install Python deps for Gradio UI + training
 RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
     "gradio[mcp]==5.29.0" requests torch safetensors \
-    transformers>=4.51.0 peft>=0.18.0 \
+    "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
    loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
    einops vector_quantize_pytorch
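The new `<4.58.0` upper bound is the meta-tensor workaround named in the commit message. Purely as an illustration (not part of the commit), a runtime guard that fails fast when the installed transformers drifts outside the tested range could look like:

# Illustrative guard only: fail fast if transformers drifted outside the
# range this image was tested with (4.58+ hit the meta-tensor issue noted
# in the commit message).
import transformers
from packaging.version import Version

v = Version(transformers.__version__)
assert Version("4.51.0") <= v < Version("4.58.0"), (
    f"transformers {v} is outside the pinned range 4.51.0 <= v < 4.58.0"
)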
app.py CHANGED
@@ -6,7 +6,31 @@ import time
 import json
 import argparse
 import tempfile
+import subprocess
+import shutil
 import requests
+import logging
+
+from train_engine import (
+    preprocess_audio,
+    train_lora_generator,
+    cancel_training,
+    get_trained_loras as _get_trained_loras_engine,
+)
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Configurable limits (edit here, not buried in code)
+# ---------------------------------------------------------------------------
+
+MAX_AUDIO_DURATION = 240   # seconds, cap per audio file for training
+MAX_TRAINING_TIME = 28800  # 8 hours hard training timeout (seconds)
+MAX_AUDIO_FILES = 50       # max number of training audio files per run
+
+# ---------------------------------------------------------------------------
+# Paths & constants
+# ---------------------------------------------------------------------------
 
 ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
 OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
@@ -16,6 +40,12 @@ ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
 ACE_SOURCE_DIR = "/app/ace-step-source"
 ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
 ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
+MODELS_DIR = os.environ.get("ACE_MODELS_DIR", "/app/models")
+
+ACE_SERVER_BIN = "/app/ace-server"
+
+# HF repo for on-demand GGUF downloads
+GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
 
 # ---------------------------------------------------------------------------
 # ace-server helpers
@@ -68,7 +98,7 @@ def _fetch_result(job_id, timeout=60):
 
 
 def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
-                  adapter=None, progress_cb=None):
+                  adapter=None, lm_model=None, progress_cb=None):
     """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises."""
     t0 = time.time()
 
@@ -86,6 +116,8 @@ def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
     req["inference_steps"] = int(steps)
     if adapter:
         req["adapter"] = adapter
+    if lm_model:
+        req["model"] = lm_model
 
     fmt = output_format if output_format in ("wav", "mp3") else "mp3"
     synth_fmt = "wav16" if fmt == "wav" else "mp3"
@@ -143,6 +175,98 @@ def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
     return tmp.name, msg
 
 
+# ---------------------------------------------------------------------------
+# LM model scanning & on-demand download
+# ---------------------------------------------------------------------------
+
+def _scan_lm_models():
+    """Scan /app/models for *-lm-*.gguf files, return list of filenames."""
+    models = []
+    if os.path.isdir(MODELS_DIR):
+        for f in sorted(os.listdir(MODELS_DIR)):
+            if "-lm-" in f and f.endswith(".gguf"):
+                models.append(f)
+    return models
+
+
+def _download_lm_model(filename):
+    """Download a GGUF LM model from HF if not already present."""
+    dest = os.path.join(MODELS_DIR, filename)
+    if os.path.isfile(dest):
+        return dest
+    try:
+        from huggingface_hub import hf_hub_download
+        path = hf_hub_download(
+            repo_id=GGUF_HF_REPO,
+            filename=filename,
+            local_dir=MODELS_DIR,
+        )
+        return path
+    except Exception as exc:
+        logger.error("Failed to download %s: %s", filename, exc)
+        return None
+
+
+# ---------------------------------------------------------------------------
+# LoRA listing for UI dropdowns
+# ---------------------------------------------------------------------------
+
+def _list_lora_choices():
+    """Return list of LoRA choices for dropdown, including 'None'."""
+    choices = ["None (no LoRA)"]
+    if os.path.isdir(ADAPTER_DIR):
+        for d in os.listdir(ADAPTER_DIR):
+            if os.path.isdir(os.path.join(ADAPTER_DIR, d)):
+                choices.append(d)
+    return choices
+
+
+# ---------------------------------------------------------------------------
+# ace-server stop/start helpers
+# ---------------------------------------------------------------------------
+
+_ace_proc = None
+
+def _stop_ace_server():
+    """Stop ace-server process."""
+    global _ace_proc
+    if _ace_proc and _ace_proc.poll() is None:
+        _ace_proc.terminate()
+        try:
+            _ace_proc.wait(timeout=10)
+        except subprocess.TimeoutExpired:
+            _ace_proc.kill()
+        _ace_proc = None
+    else:
+        try:
+            subprocess.run(["pkill", "ace-server"],
+                           stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
+                           timeout=10)
+        except Exception:
+            pass
+    time.sleep(1)
+
+
+def _start_ace_server():
+    """Start ace-server in background and wait for health."""
+    global _ace_proc
+    try:
+        _ace_proc = subprocess.Popen(
+            [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
+             "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
+            stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
+        )
+    except Exception as exc:
+        logger.error("Failed to start ace-server: %s", exc)
+        return False
+
+    for _ in range(30):
+        if _server_ok():
+            return True
+        time.sleep(2)
+    return False
+
+
 # ---------------------------------------------------------------------------
 # CLI mode
 # ---------------------------------------------------------------------------
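The restart loop above polls `_server_ok()`, which is defined earlier in app.py and outside this diff. A hypothetical minimal version is sketched below; the real helper and its endpoint may differ:

# Hypothetical sketch of the health probe polled by _start_ace_server();
# the actual _server_ok() in app.py is not part of this diff and its
# endpoint may differ. Uses the requests import and ACE_SERVER constant
# already present in app.py.
def _server_ok(timeout=2):
    try:
        r = requests.get(f"{ACE_SERVER}/health", timeout=timeout)
        return r.status_code == 200
    except requests.RequestException:
        return False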
@@ -216,7 +340,6 @@ def cli_main():
 
     # Move to requested output path if specified
     if args.output:
-        import shutil
         out_dir = os.path.dirname(os.path.abspath(args.output))
         os.makedirs(out_dir, exist_ok=True)
         shutil.move(audio_path, args.output)
@@ -232,18 +355,15 @@ def cli_main():
 
 def gradio_main():
     import gradio as gr
+    import gc
 
-    loras = ["None (no LoRA)"]
-    if os.path.isdir(ADAPTER_DIR):
-        for d in os.listdir(ADAPTER_DIR):
-            if os.path.isdir(os.path.join(ADAPTER_DIR, d)):
-                loras.append(d)
-    return loras
+    # -- Persistent training log buffer (survives across yields) --
+    _train_log_lines = []
 
+    # -- Generate tab handler --
     def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
-                       steps, lora_select,
+                       steps, lora_select, lm_model_select,
+                       progress=gr.Progress(track_tqdm=True)):
         if not _server_ok():
             return None, "ace-server not running. Check logs."
 
@@ -252,6 +372,7 @@ def gradio_main():
 
         actual_seed = None if seed is None or int(seed) < 0 else int(seed)
         adapter = None if lora_select == "None (no LoRA)" else lora_select
+        lm_model = None if not lm_model_select or lm_model_select == "Default" else lm_model_select
 
         progress_map = {
             "lm_submit": (0.05, "Submitting LM job..."),
@@ -277,6 +398,7 @@ def gradio_main():
             steps=steps,
             output_format="mp3",
             adapter=adapter,
+            lm_model=lm_model,
             progress_cb=gr_progress,
         )
         return audio_path, status
@@ -295,22 +417,35 @@ def gradio_main():
         lines.append(json.dumps(props, indent=2))
         return "\n".join(lines)
 
-    # -- Training (
-        import shutil, subprocess
+    # -- Training generator (direct integration, no subprocess) --
+    def train_lora_ui(audio_files, lora_name, epochs, lr, rank):
+        """Generator that yields (train_log, train_btn_update, cancel_btn_update)."""
+        import gc as _gc
+
+        _train_log_lines.clear()
+        train_start = time.time()
+
+        def _log(msg):
+            _train_log_lines.append(msg)
+
+        def _log_text():
+            return "\n".join(_train_log_lines)
 
+        # -- Validation --
         if not audio_files:
+            _log("[FAIL] No audio files uploaded.")
+            yield _log_text(), gr.update(visible=True), gr.update(visible=False)
+            return
 
+        if len(audio_files) > MAX_AUDIO_FILES:
+            _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
+            yield _log_text(), gr.update(visible=True), gr.update(visible=False)
+            return
 
         lora_name = (lora_name or "").strip() or "my-lora"
+        # Sanitize: alphanumeric, dash, underscore only
+        lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name)
+
         epochs = max(1, min(int(epochs), 10))
         lr = float(lr)
         rank = max(1, min(int(rank), 64))
@@ -319,145 +454,136 @@ def gradio_main():
         os.makedirs(output_dir, exist_ok=True)
         audio_dir = os.path.join(output_dir, "audio_input")
         os.makedirs(audio_dir, exist_ok=True)
+
+        # Copy uploaded audio files
+        _log(f"[INFO] Preparing {len(audio_files)} audio files...")
+        yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
         for f in audio_files:
             src = f.name if hasattr(f, "name") else str(f)
             shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
 
-        os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "1"
-        f.flush()
-            f"Click 'Check Log' to monitor progress.\n"
-            f"Inference will be unavailable until training completes (ace-server stopped).")
-
-    def check_train_log():
-        parts = []
-        if os.path.exists(TRAIN_LOG):
-            parts.append(open(TRAIN_LOG).read())
-        stderr_log = os.path.join(ADAPTER_DIR, "test-lora", "train_stderr.log")
-        if os.path.exists(stderr_log) and os.path.getsize(stderr_log) > 0:
-            stderr = open(stderr_log).read()[-8000:]
-            parts.append(f"\n--- stderr (last 8000 chars, {os.path.getsize(stderr_log)} bytes total) ---\n{stderr}")
-        return "\n".join(parts) if parts else "No training log found."
+        _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
+             f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
+        yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+        # Stop ace-server before training (frees memory)
+        _log("[INFO] Stopping ace-server for training...")
+        yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+        _stop_ace_server()
+        _gc.collect()
+
+        try:
+            # -- Phase 1: Preprocessing --
+            _log("[Step 1/2] Preprocessing audio...")
+            yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+            preprocessed_dir = os.path.join(output_dir, "preprocessed_tensors")
+
+            def preprocess_progress(current, total, desc):
+                _log(f"  {desc} ({current}/{total})")
+
+            result = preprocess_audio(
+                audio_dir=audio_dir,
+                output_dir=preprocessed_dir,
+                checkpoint_dir=ACE_CHECKPOINT_DIR,
+                device="cpu",
+                variant="turbo",
+                max_duration=float(MAX_AUDIO_DURATION),
+                progress_callback=preprocess_progress,
+                cancel_check=lambda: False,
+            )
+            yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+            processed = result.get("processed", 0)
+            failed = result.get("failed", 0)
+            total = result.get("total", 0)
+            _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
+            yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+            if processed == 0:
+                _log("[FAIL] No files preprocessed successfully. Cannot train.")
+                yield _log_text(), gr.update(visible=True), gr.update(visible=False)
+                return
+
+            _gc.collect()
+
+            # -- Phase 2: Training --
+            _log("[Step 2/2] Training LoRA...")
+            yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+            for msg in train_lora_generator(
+                dataset_dir=preprocessed_dir,
+                output_dir=output_dir,
+                checkpoint_dir=ACE_CHECKPOINT_DIR,
+                epochs=epochs,
+                lr=lr,
+                rank=rank,
+                alpha=rank,
+                dropout=0.0,
+                batch_size=1,
+                gradient_accumulation_steps=4,
+                warmup_steps=100,
+                weight_decay=0.01,
+                max_grad_norm=1.0,
+                save_every_n_epochs=max(1, epochs // 2),
+                seed=42,
+                variant="turbo",
+                device="cpu",
+                log_every=5,
+            ):
+                # Timeout check
+                elapsed = time.time() - train_start
+                if elapsed > MAX_TRAINING_TIME:
+                    _log(f"[WARN] Training timed out after {int(elapsed)}s")
+                    cancel_training()
+                    break
+
+                _log(msg)
+                yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+                if msg.strip() == "[DONE]":
+                    break
+
+            _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
+            yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+
+        except Exception as exc:
+            _log(f"[FAIL] Training error: {exc}")
+            import traceback
+            _log(traceback.format_exc())
+            yield _log_text(), gr.update(visible=True), gr.update(visible=False)
+
+        finally:
+            # Always restart ace-server
+            _log("[INFO] Restarting ace-server...")
+            yield _log_text(), gr.update(visible=False), gr.update(visible=True)
+            _gc.collect()
+            ok = _start_ace_server()
+            if ok:
+                _log("[OK] ace-server restarted successfully")
+            else:
+                _log("[WARN] ace-server may not have restarted -- check logs")
+            yield _log_text(), gr.update(visible=True), gr.update(visible=False)
+
+    # -- Cancel handler --
+    def _on_cancel():
+        cancel_training()
+        logger.info("Cancel requested by user")
+        return "Cancelling after current epoch... please wait"
+
+    # -- Check log handler --
+    def _check_log():
+        if _train_log_lines:
+            return "\n".join(_train_log_lines)
+        return "No training log available."
+
+    # -- Build LM model choices --
+    def _lm_model_choices():
+        models = _scan_lm_models()
+        choices = ["Default"]
+        choices.extend(models)
+        return choices
 
     # -- Build UI --
     CSS = """
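Gradio runs a generator handler as a streaming event: every `yield` pushes fresh values to all declared outputs, which is why `train_lora_ui` yields a `(log_text, train_btn_update, cancel_btn_update)` triple at each step. A stripped-down, self-contained sketch of the same pattern (assuming the Gradio 5.x API pinned in the Dockerfile):

# Self-contained sketch of the streaming-generator pattern used above
# (assumes the gradio 5.x API pinned in the Dockerfile).
import time
import gradio as gr

def slow_task():
    lines = []
    for epoch in range(3):
        time.sleep(1)  # stand-in for one training epoch
        lines.append(f"epoch {epoch} done")
        # Each yield updates the textbox and toggles the two buttons
        yield "\n".join(lines), gr.update(visible=False), gr.update(visible=True)
    yield "\n".join(lines), gr.update(visible=True), gr.update(visible=False)

with gr.Blocks() as demo:
    log = gr.Textbox(label="Log")
    start = gr.Button("Start")
    cancel = gr.Button("Cancel", visible=False)
    start.click(slow_task, outputs=[log, start, cancel])

demo.launch()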
@@ -512,7 +638,16 @@ finally:
                 value=8, step=1, scale=1,
             )
             seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
+
+        with gr.Row(elem_classes="compact-row"):
+            lora_select = gr.Dropdown(
+                label="LoRA", choices=_list_lora_choices(),
+                value="None (no LoRA)", scale=1,
+            )
+            lm_model_select = gr.Dropdown(
+                label="LM Model", choices=_lm_model_choices(),
+                value="Default", scale=1,
+            )
 
         with gr.Row(elem_classes="compact-row"):
             gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
@@ -521,7 +656,7 @@ finally:
         gen_btn.click(
             fn=generate_music,
             inputs=[caption, lyrics, instrumental, bpm, duration,
-                    seed, steps, lora_select],
+                    seed, steps, lora_select, lm_model_select],
             outputs=[audio_out, status],
             api_name="generate",
         )
@@ -539,8 +674,8 @@ finally:
         with gr.Tab("Train LoRA"):
             gr.Markdown(
                 "### LoRA Training\n"
-                "Fine-tune ACE-Step on your
-                "CPU training is
+                "Fine-tune ACE-Step on your audio. "
+                "CPU training is slow -- ace-server stops during training."
             )
 
             with gr.Row(elem_classes="compact-row"):
@@ -552,13 +687,21 @@ finally:
                 )
                 with gr.Column(scale=1):
                     lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
+                    train_epochs = gr.Slider(
+                        label="Epochs", minimum=1, maximum=10,
+                        value=3, step=1,
+                    )
+                    train_lr = gr.Number(label="Learning Rate", value=1e-4)
+                    train_rank = gr.Slider(
+                        label="Rank (r)", minimum=1, maximum=64,
+                        value=16, step=1,
+                    )
 
             with gr.Row(elem_classes="compact-row"):
                 train_btn = gr.Button("Train", variant="primary", scale=2)
+                cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
                 log_btn = gr.Button("Check Log", scale=1)
+
             train_log = gr.Textbox(
                 label="Training Log",
                 interactive=False,
@@ -566,17 +709,39 @@ finally:
                 elem_classes="status-box",
             )
 
+            # Training generator -- yields (log, train_btn, cancel_btn)
+            train_event = train_btn.click(
+                train_lora_ui,
+                inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
+                outputs=[train_log, train_btn, cancel_btn],
                 api_name="train_lora",
+                concurrency_limit=1,
             )
+
+            # After training completes, restore buttons and refresh LoRA dropdown
+            def _post_training():
+                return (
+                    gr.update(visible=True),
+                    gr.update(visible=False),
+                    gr.update(choices=_list_lora_choices()),
+                )
+
+            train_event.then(
+                _post_training,
+                outputs=[train_btn, cancel_btn, lora_select],
+            )
+
+            # Cancel: set the flag, update status
+            cancel_btn.click(
+                _on_cancel,
+                outputs=[train_log],
+            )
+
+            # Check log: show last training output
             log_btn.click(
+                _check_log,
-                inputs=[],
                 outputs=[train_log],
-                api_name="
+                api_name="check_log",
             )
 
     demo.launch(
train_engine.py ADDED
@@ -0,0 +1,1353 @@
| 1 |
+
"""
|
| 2 |
+
Standalone ACE-Step CPU LoRA Training Engine.
|
| 3 |
+
|
| 4 |
+
Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
|
| 5 |
+
module. No external Side-Step dependency required.
|
| 6 |
+
|
| 7 |
+
Exports:
|
| 8 |
+
preprocess_audio() - 2-pass sequential preprocessing
|
| 9 |
+
train_lora_generator() - Generator-based LoRA training loop
|
| 10 |
+
cancel_training() - Set the cancel flag
|
| 11 |
+
get_trained_loras() - List saved adapters
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import gc
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import random
|
| 22 |
+
import sys
|
| 23 |
+
import time
|
| 24 |
+
import types
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
from torch.optim import AdamW
|
| 33 |
+
from torch.optim.lr_scheduler import (
|
| 34 |
+
CosineAnnealingLR,
|
| 35 |
+
LinearLR,
|
| 36 |
+
SequentialLR,
|
| 37 |
+
)
|
| 38 |
+
from torch.utils.data import DataLoader, Dataset
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
# Configurable caps (edit these at the top of the file)
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
|
| 46 |
+
MAX_AUDIO_DURATION = 240.0 # seconds, cap per audio file
|
| 47 |
+
MAX_TRAINING_TIME = 28800 # 8 hours hard timeout
|
| 48 |
+
TARGET_SR = 48000
|
| 49 |
+
AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})
|
| 50 |
+
|
| 51 |
+
# bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
|
| 52 |
+
CPU_DTYPE = torch.float32
|
| 53 |
+
|
| 54 |
+
import threading
|
| 55 |
+
_training_cancel = threading.Event()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def cancel_training() -> None:
|
| 59 |
+
_training_cancel.set()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ============================================================================
|
| 63 |
+
# CONFIGS
|
| 64 |
+
# ============================================================================
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class LoRAConfig:
|
| 68 |
+
r: int = 64
|
| 69 |
+
alpha: int = 128
|
| 70 |
+
dropout: float = 0.1
|
| 71 |
+
target_modules: List[str] = field(default_factory=lambda: [
|
| 72 |
+
"q_proj", "k_proj", "v_proj", "o_proj",
|
| 73 |
+
])
|
| 74 |
+
bias: str = "none"
|
| 75 |
+
attention_type: str = "both"
|
| 76 |
+
target_mlp: bool = True
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ============================================================================
|
| 80 |
+
# TIMESTEP SAMPLING & CFG DROPOUT
|
| 81 |
+
# ============================================================================
|
| 82 |
+
|
| 83 |
+
def sample_timesteps(
|
| 84 |
+
batch_size: int,
|
| 85 |
+
device: torch.device,
|
| 86 |
+
dtype: torch.dtype,
|
| 87 |
+
timestep_mu: float = -0.4,
|
| 88 |
+
timestep_sigma: float = 1.0,
|
| 89 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 90 |
+
t = torch.sigmoid(
|
| 91 |
+
torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu
|
| 92 |
+
)
|
| 93 |
+
r = torch.sigmoid(
|
| 94 |
+
torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu
|
| 95 |
+
)
|
| 96 |
+
t, r = torch.maximum(t, r), torch.minimum(t, r)
|
| 97 |
+
# use_meanflow=False forces r=t (ACE-Step convention)
|
| 98 |
+
return t, t
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def apply_cfg_dropout(
|
| 102 |
+
encoder_hidden_states: torch.Tensor,
|
| 103 |
+
null_condition_emb: torch.Tensor,
|
| 104 |
+
cfg_ratio: float = 0.15,
|
| 105 |
+
) -> torch.Tensor:
|
| 106 |
+
bsz = encoder_hidden_states.shape[0]
|
| 107 |
+
device = encoder_hidden_states.device
|
| 108 |
+
dtype = encoder_hidden_states.dtype
|
| 109 |
+
mask = torch.where(
|
| 110 |
+
torch.rand(size=(bsz,), device=device, dtype=dtype) < cfg_ratio,
|
| 111 |
+
torch.zeros(size=(bsz,), device=device, dtype=dtype),
|
| 112 |
+
torch.ones(size=(bsz,), device=device, dtype=dtype),
|
| 113 |
+
).view(-1, 1, 1)
|
| 114 |
+
return torch.where(
|
| 115 |
+
mask > 0,
|
| 116 |
+
encoder_hidden_states,
|
| 117 |
+
null_condition_emb.expand_as(encoder_hidden_states),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ============================================================================
|
| 122 |
+
# OPTIMIZER (Adafactor preferred for CPU -- 1.5 bytes/param)
|
| 123 |
+
# ============================================================================
|
| 124 |
+
|
| 125 |
+
def build_optimizer(
|
| 126 |
+
params, lr: float = 1e-4, weight_decay: float = 0.01,
|
| 127 |
+
) -> torch.optim.Optimizer:
|
| 128 |
+
try:
|
| 129 |
+
from transformers.optimization import Adafactor
|
| 130 |
+
logger.info("Using Adafactor optimizer (minimal state memory)")
|
| 131 |
+
return Adafactor(
|
| 132 |
+
params, lr=lr, weight_decay=weight_decay,
|
| 133 |
+
scale_parameter=False, relative_step=False,
|
| 134 |
+
)
|
| 135 |
+
except ImportError:
|
| 136 |
+
logger.warning("transformers not installed, falling back to AdamW")
|
| 137 |
+
return AdamW(params, lr=lr, weight_decay=weight_decay)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def build_scheduler(
|
| 141 |
+
optimizer, total_steps: int, warmup_steps: int, lr: float,
|
| 142 |
+
):
|
| 143 |
+
_max_warmup = max(1, total_steps // 10)
|
| 144 |
+
if warmup_steps > _max_warmup:
|
| 145 |
+
warmup_steps = _max_warmup
|
| 146 |
+
|
| 147 |
+
warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps)
|
| 148 |
+
remaining = max(1, total_steps - warmup_steps)
|
| 149 |
+
main = CosineAnnealingLR(optimizer, T_max=remaining, eta_min=lr * 0.01)
|
| 150 |
+
return SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps])
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ============================================================================
|
| 154 |
+
# DATASET
|
| 155 |
+
# ============================================================================
|
| 156 |
+
|
| 157 |
+
def _collate_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
|
| 158 |
+
max_t = max(s["target_latents"].shape[0] for s in batch)
|
| 159 |
+
max_e = max(s["encoder_hidden_states"].shape[0] for s in batch)
|
| 160 |
+
|
| 161 |
+
def pad(t, max_len, dim=0):
|
| 162 |
+
diff = max_len - t.shape[dim]
|
| 163 |
+
if diff <= 0:
|
| 164 |
+
return t
|
| 165 |
+
shape = list(t.shape)
|
| 166 |
+
shape[dim] = diff
|
| 167 |
+
return torch.cat([t, t.new_zeros(*shape)], dim=dim)
|
| 168 |
+
|
| 169 |
+
return {
|
| 170 |
+
"target_latents": torch.stack([pad(s["target_latents"], max_t) for s in batch]),
|
| 171 |
+
"attention_mask": torch.stack([pad(s["attention_mask"], max_t) for s in batch]),
|
| 172 |
+
"encoder_hidden_states": torch.stack([pad(s["encoder_hidden_states"], max_e) for s in batch]),
|
| 173 |
+
"encoder_attention_mask": torch.stack([pad(s["encoder_attention_mask"], max_e) for s in batch]),
|
| 174 |
+
"context_latents": torch.stack([pad(s["context_latents"], max_t) for s in batch]),
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
class TensorDataset(Dataset):
|
| 179 |
+
_REQUIRED = frozenset([
|
| 180 |
+
"target_latents", "attention_mask", "encoder_hidden_states",
|
| 181 |
+
"encoder_attention_mask", "context_latents",
|
| 182 |
+
])
|
| 183 |
+
|
| 184 |
+
def __init__(self, tensor_dir: str):
|
| 185 |
+
self.paths: List[str] = []
|
| 186 |
+
for f in sorted(os.listdir(tensor_dir)):
|
| 187 |
+
if f.endswith(".pt") and not f.endswith(".tmp.pt") and f != "manifest.json":
|
| 188 |
+
self.paths.append(str(Path(tensor_dir) / f))
|
| 189 |
+
|
| 190 |
+
def __len__(self) -> int:
|
| 191 |
+
return len(self.paths)
|
| 192 |
+
|
| 193 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 194 |
+
data = torch.load(self.paths[idx], map_location="cpu", weights_only=True)
|
| 195 |
+
missing = self._REQUIRED - data.keys()
|
| 196 |
+
if missing:
|
| 197 |
+
raise KeyError(f"Missing keys {sorted(missing)} in {self.paths[idx]}")
|
| 198 |
+
for k in ("target_latents", "encoder_hidden_states", "context_latents"):
|
| 199 |
+
t = data[k]
|
| 200 |
+
if torch.isnan(t).any() or torch.isinf(t).any():
|
| 201 |
+
t.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
|
| 202 |
+
return {k: data[k] for k in self._REQUIRED}
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ============================================================================
|
| 206 |
+
# GRADIENT CHECKPOINTING
|
| 207 |
+
# ============================================================================
|
| 208 |
+
|
| 209 |
+
def _find_decoder_layers(decoder: nn.Module) -> Optional[nn.ModuleList]:
|
| 210 |
+
for attr in ("layers", "blocks", "transformer_blocks"):
|
| 211 |
+
c = getattr(decoder, attr, None)
|
| 212 |
+
if isinstance(c, nn.ModuleList) and len(c) > 0:
|
| 213 |
+
return c
|
| 214 |
+
for child in decoder.children():
|
| 215 |
+
for attr in ("layers", "blocks", "transformer_blocks"):
|
| 216 |
+
c = getattr(child, attr, None)
|
| 217 |
+
if isinstance(c, nn.ModuleList) and len(c) > 0:
|
| 218 |
+
return c
|
| 219 |
+
return None
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
|
| 223 |
+
"""Enable gradient checkpointing on the decoder to save memory."""
|
| 224 |
+
enabled = False
|
| 225 |
+
|
| 226 |
+
# Walk wrapper chain
|
| 227 |
+
stack = [decoder]
|
| 228 |
+
visited = set()
|
| 229 |
+
while stack:
|
| 230 |
+
mod = stack.pop()
|
| 231 |
+
if not isinstance(mod, nn.Module):
|
| 232 |
+
continue
|
| 233 |
+
mid = id(mod)
|
| 234 |
+
if mid in visited:
|
| 235 |
+
continue
|
| 236 |
+
visited.add(mid)
|
| 237 |
+
|
| 238 |
+
if hasattr(mod, "gradient_checkpointing_enable"):
|
| 239 |
+
try:
|
| 240 |
+
mod.gradient_checkpointing_enable()
|
| 241 |
+
enabled = True
|
| 242 |
+
except Exception:
|
| 243 |
+
pass
|
| 244 |
+
elif hasattr(mod, "gradient_checkpointing"):
|
| 245 |
+
try:
|
| 246 |
+
mod.gradient_checkpointing = True
|
| 247 |
+
enabled = True
|
| 248 |
+
except Exception:
|
| 249 |
+
pass
|
| 250 |
+
|
| 251 |
+
if hasattr(mod, "enable_input_require_grads"):
|
| 252 |
+
try:
|
| 253 |
+
mod.enable_input_require_grads()
|
| 254 |
+
except Exception:
|
| 255 |
+
pass
|
| 256 |
+
|
| 257 |
+
cfg = getattr(mod, "config", None)
|
| 258 |
+
if cfg is not None and hasattr(cfg, "use_cache"):
|
| 259 |
+
try:
|
| 260 |
+
cfg.use_cache = False
|
| 261 |
+
except Exception:
|
| 262 |
+
pass
|
| 263 |
+
|
| 264 |
+
for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
|
| 265 |
+
child = getattr(mod, a, None)
|
| 266 |
+
if isinstance(child, nn.Module):
|
| 267 |
+
stack.append(child)
|
| 268 |
+
|
| 269 |
+
return enabled
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def force_disable_cache(decoder: nn.Module) -> None:
|
| 273 |
+
stack = [decoder]
|
| 274 |
+
visited = set()
|
| 275 |
+
while stack:
|
| 276 |
+
mod = stack.pop()
|
| 277 |
+
if not isinstance(mod, nn.Module):
|
| 278 |
+
continue
|
| 279 |
+
mid = id(mod)
|
| 280 |
+
if mid in visited:
|
| 281 |
+
continue
|
| 282 |
+
visited.add(mid)
|
| 283 |
+
cfg = getattr(mod, "config", None)
|
| 284 |
+
if cfg is not None and hasattr(cfg, "use_cache"):
|
| 285 |
+
try:
|
| 286 |
+
cfg.use_cache = False
|
| 287 |
+
except Exception:
|
| 288 |
+
pass
|
| 289 |
+
for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
|
| 290 |
+
child = getattr(mod, a, None)
|
| 291 |
+
if isinstance(child, nn.Module):
|
| 292 |
+
stack.append(child)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# ============================================================================
|
| 296 |
+
# LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT)
|
| 297 |
+
# ============================================================================
|
| 298 |
+
|
| 299 |
+
def _unwrap_decoder(model):
|
| 300 |
+
decoder = model.decoder if hasattr(model, "decoder") else model
|
| 301 |
+
while hasattr(decoder, "_forward_module"):
|
| 302 |
+
decoder = decoder._forward_module
|
| 303 |
+
if hasattr(decoder, "base_model"):
|
| 304 |
+
bm = decoder.base_model
|
| 305 |
+
decoder = bm.model if hasattr(bm, "model") else bm
|
| 306 |
+
if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
|
| 307 |
+
decoder = decoder.model
|
| 308 |
+
return decoder
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]:
|
| 312 |
+
from peft import get_peft_model, LoraConfig as PeftLoraConfig, TaskType
|
| 313 |
+
|
| 314 |
+
decoder = _unwrap_decoder(model)
|
| 315 |
+
model.decoder = decoder
|
| 316 |
+
|
| 317 |
+
# Guard enable_input_require_grads for DiT (no get_input_embeddings)
|
| 318 |
+
if hasattr(decoder, "enable_input_require_grads"):
|
| 319 |
+
orig = decoder.enable_input_require_grads
|
| 320 |
+
|
| 321 |
+
def _safe(self):
|
| 322 |
+
try:
|
| 323 |
+
return orig()
|
| 324 |
+
except NotImplementedError:
|
| 325 |
+
return None
|
| 326 |
+
|
| 327 |
+
decoder.enable_input_require_grads = types.MethodType(_safe, decoder)
|
| 328 |
+
|
| 329 |
+
if hasattr(decoder, "is_gradient_checkpointing"):
|
| 330 |
+
try:
|
| 331 |
+
decoder.is_gradient_checkpointing = False
|
| 332 |
+
except Exception:
|
| 333 |
+
pass
|
| 334 |
+
|
| 335 |
+
peft_cfg = PeftLoraConfig(
|
| 336 |
+
r=lora_cfg.r,
|
| 337 |
+
lora_alpha=lora_cfg.alpha,
|
| 338 |
+
lora_dropout=lora_cfg.dropout,
|
| 339 |
+
target_modules=lora_cfg.target_modules,
|
| 340 |
+
bias=lora_cfg.bias,
|
| 341 |
+
task_type=TaskType.FEATURE_EXTRACTION,
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
model.decoder = get_peft_model(decoder, peft_cfg)
|
| 345 |
+
|
| 346 |
+
for name, param in model.named_parameters():
|
| 347 |
+
if "lora_" not in name:
|
| 348 |
+
param.requires_grad = False
|
| 349 |
+
|
| 350 |
+
total = sum(p.numel() for p in model.parameters())
|
| 351 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 352 |
+
|
| 353 |
+
return model, {
|
| 354 |
+
"total_params": total,
|
| 355 |
+
"trainable_params": trainable,
|
| 356 |
+
"trainable_ratio": trainable / total if total > 0 else 0,
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def save_lora_adapter(model, output_dir: str) -> None:
|
| 361 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 362 |
+
decoder = model.decoder if hasattr(model, "decoder") else model
|
| 363 |
+
while hasattr(decoder, "_forward_module"):
|
| 364 |
+
decoder = decoder._forward_module
|
| 365 |
+
|
| 366 |
+
if hasattr(decoder, "save_pretrained"):
|
| 367 |
+
decoder.save_pretrained(output_dir)
|
| 368 |
+
# Scrub base_model path for portability
|
| 369 |
+
cfg_path = os.path.join(output_dir, "adapter_config.json")
|
| 370 |
+
if os.path.isfile(cfg_path):
|
| 371 |
+
try:
|
| 372 |
+
with open(cfg_path, "r") as f:
|
| 373 |
+
cfg = json.load(f)
|
| 374 |
+
if cfg.get("base_model_name_or_path"):
|
| 375 |
+
cfg["base_model_name_or_path"] = ""
|
| 376 |
+
with open(cfg_path, "w") as f:
|
| 377 |
+
json.dump(cfg, f, indent=2)
|
| 378 |
+
except Exception:
|
| 379 |
+
pass
|
| 380 |
+
logger.info("LoRA adapter saved to %s", output_dir)
|
| 381 |
+
else:
|
| 382 |
+
# Fallback: manual extraction
|
| 383 |
+
state = {}
|
| 384 |
+
for name, param in decoder.named_parameters():
|
| 385 |
+
if "lora_" in name:
|
| 386 |
+
state[name] = param.data.clone()
|
| 387 |
+
if state:
|
| 388 |
+
try:
|
| 389 |
+
from safetensors.torch import save_file
|
| 390 |
+
save_file(state, str(Path(output_dir) / "adapter_model.safetensors"))
|
| 391 |
+
except ImportError:
|
| 392 |
+
torch.save(state, str(Path(output_dir) / "lora_weights.pt"))
|
| 393 |
+
logger.info("LoRA adapter saved (fallback) to %s", output_dir)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# ============================================================================
|
| 397 |
+
# MODEL LOADING (FA2 -> SDPA -> eager fallback)
|
| 398 |
+
# ============================================================================
|
| 399 |
+
|
| 400 |
+
_VARIANT_DIR = {
|
| 401 |
+
"turbo": "acestep-v15-turbo",
|
| 402 |
+
"base": "acestep-v15-base",
|
| 403 |
+
"sft": "acestep-v15-sft",
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def _resolve_model_dir(checkpoint_dir: str, variant: str) -> Path:
|
| 408 |
+
base = Path(checkpoint_dir).resolve()
|
| 409 |
+
subdir = _VARIANT_DIR.get(variant)
|
| 410 |
+
if subdir:
|
| 411 |
+
p = (Path(checkpoint_dir) / subdir).resolve()
|
| 412 |
+
if p.is_dir():
|
| 413 |
+
return p
|
| 414 |
+
p = (Path(checkpoint_dir) / variant).resolve()
|
| 415 |
+
if p.is_dir():
|
| 416 |
+
return p
|
| 417 |
+
raise FileNotFoundError(
|
| 418 |
+
f"Model directory not found: tried {_VARIANT_DIR.get(variant, variant)!r} "
|
| 419 |
+
f"and {variant!r} under {checkpoint_dir}"
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
def _ensure_acestep_imports():
|
| 424 |
+
"""Register stub modules so AutoModel can load ACE-Step checkpoints."""
|
| 425 |
+
for name in (
|
| 426 |
+
"acestep", "acestep.models", "acestep.models.common",
|
| 427 |
+
"acestep.models.xl_base", "acestep.models.xl_turbo", "acestep.models.xl_sft",
|
| 428 |
+
):
|
| 429 |
+
if name not in sys.modules:
|
| 430 |
+
stub = types.ModuleType(name)
|
| 431 |
+
stub.__path__ = []
|
| 432 |
+
sys.modules[name] = stub
|
| 433 |
+
|
| 434 |
+
# Try to load real modules from adjacent ACE-Step checkout
|
| 435 |
+
for name in (
|
| 436 |
+
"acestep.models.common.configuration_acestep_v15",
|
| 437 |
+
"acestep.models.common.apg_guidance",
|
| 438 |
+
):
|
| 439 |
+
if name not in sys.modules:
|
| 440 |
+
sys.modules[name] = types.ModuleType(name)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _attn_candidates(device: str) -> List[str]:
|
| 444 |
+
"""FA2 -> SDPA -> eager, filtered by availability."""
|
| 445 |
+
candidates = []
|
| 446 |
+
if device.startswith("cuda"):
|
| 447 |
+
try:
|
| 448 |
+
import flash_attn # noqa: F401
|
| 449 |
+
dev_idx = int(device.split(":")[1]) if ":" in device else 0
|
| 450 |
+
props = torch.cuda.get_device_properties(dev_idx)
|
| 451 |
+
if props.major >= 8:
|
| 452 |
+
candidates.append("flash_attention_2")
|
| 453 |
+
except (ImportError, Exception):
|
| 454 |
+
pass
|
| 455 |
+
candidates.extend(["sdpa", "eager"])
|
| 456 |
+
return candidates
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
def load_model_for_training(
    checkpoint_dir: str, variant: str = "base", device: str = "cpu",
) -> Any:
    from transformers import AutoModel

    model_dir = _resolve_model_dir(checkpoint_dir, variant)
    # CPU always uses float32
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    _ensure_acestep_imports()

    candidates = _attn_candidates(device)
    model = None
    last_err = None

    for attn in candidates:
        try:
            load_kwargs = dict(
                trust_remote_code=True,
                attn_implementation=attn,
                torch_dtype=dtype,
                low_cpu_mem_usage=False,
            )
            if device != "cpu":
                load_kwargs["device_map"] = {"": device}
            model = AutoModel.from_pretrained(str(model_dir), **load_kwargs)
            logger.info("Model loaded with attn_implementation=%s", attn)
            break
        except Exception as exc:
            err_text = str(exc)
            if "packages that were not found" in err_text or "No module named" in err_text:
                raise RuntimeError(
                    f"Model files in {model_dir} require a missing Python package.\n"
                    f"  Original error: {err_text}"
                ) from exc
            last_err = exc
            logger.warning("attn backend '%s' failed: %s", attn, exc)

    if model is None:
        raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err

    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model

def load_vae(checkpoint_dir: str, device: str = "cpu"):
    from diffusers.models import AutoencoderOobleck

    vae_path = Path(checkpoint_dir) / "vae"
    if not vae_path.is_dir():
        raise FileNotFoundError(f"VAE directory not found: {vae_path}")

    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=dtype)
    vae = vae.to(device=device)
    vae.eval()
    return vae


def load_text_encoder(checkpoint_dir: str, device: str = "cpu"):
    from transformers import AutoModel, AutoTokenizer

    text_path = Path(checkpoint_dir) / "Qwen3-Embedding-0.6B"
    if not text_path.is_dir():
        raise FileNotFoundError(f"Text encoder not found: {text_path}")

    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(str(text_path))
    encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=dtype)
    encoder = encoder.to(device=device)
    encoder.eval()
    return tokenizer, encoder


def load_silence_latent(
    checkpoint_dir: str, device: str = "cpu", variant: str = "base",
) -> torch.Tensor:
    ckpt = Path(checkpoint_dir)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    candidates = [ckpt / "silence_latent.pt"]
    subdir = _VARIANT_DIR.get(variant)
    if subdir:
        candidates.append(ckpt / subdir / "silence_latent.pt")
    for sd in _VARIANT_DIR.values():
        candidates.append(ckpt / sd / "silence_latent.pt")

    for c in candidates:
        if c.is_file():
            sl = torch.load(str(c), weights_only=True).transpose(1, 2)
            return sl.to(device=device, dtype=dtype)

    raise FileNotFoundError(f"silence_latent.pt not found under {ckpt}")


def unload_models(*models) -> None:
    # Best-effort teardown: move to CPU, drop our reference, then collect.
    # Memory is only reclaimed once callers drop their references too.
    for obj in models:
        if obj is None:
            continue
        if hasattr(obj, "to"):
            try:
                obj.to("cpu")
            except Exception:
                pass
        del obj
    gc.collect()

# ============================================================================
# AUDIO LOADING
# ============================================================================

def load_audio_stereo(
    audio_path: str, target_sr: int, max_duration: float,
) -> Tuple[torch.Tensor, int]:
    import numpy as np

    try:
        import soundfile as sf
        data, sr = sf.read(audio_path, dtype="float32", always_2d=True)
        audio_np = np.ascontiguousarray(data.T)
        sr = int(sr)
        if sr != target_sr:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr, axis=1)
            sr = target_sr
        audio = torch.from_numpy(np.ascontiguousarray(audio_np))
    except Exception:
        # Fall back to torchaudio if soundfile cannot decode the container
        import torchaudio
        audio, sr = torchaudio.load(audio_path)
        sr = int(sr)
        if sr != target_sr:
            audio = torchaudio.transforms.Resample(sr, target_sr)(audio)
            sr = target_sr

    if audio.shape[0] == 1:
        audio = audio.repeat(2, 1)  # mono -> duplicated stereo
    elif audio.shape[0] > 2:
        audio = audio[:2, :]  # keep the first two channels

    max_samples = int(max_duration * target_sr)
    if audio.shape[1] > max_samples:
        audio = audio[:, :max_samples]

    return audio, sr

# ============================================================================
# TEXT / LYRICS ENCODING
# ============================================================================

def encode_text(text_encoder, tokenizer, text_prompt: str, device, dtype):
    inputs = tokenizer(
        text_prompt, padding="max_length", max_length=256,
        truncation=True, return_tensors="pt",
    )
    ids = inputs.input_ids.to(device)
    mask = inputs.attention_mask.to(device).to(dtype)

    enc_dev = next(text_encoder.parameters()).device
    if ids.device != enc_dev:
        ids = ids.to(enc_dev)
        mask = mask.to(enc_dev)

    with torch.no_grad():
        hs = text_encoder(ids).last_hidden_state.to(dtype)
    return hs, mask


def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
    inputs = tokenizer(
        lyrics, padding="max_length", max_length=512,
        truncation=True, return_tensors="pt",
    )
    ids = inputs.input_ids.to(device)
    mask = inputs.attention_mask.to(device).to(dtype)

    enc_dev = next(text_encoder.parameters()).device
    if ids.device != enc_dev:
        ids = ids.to(enc_dev)
        mask = mask.to(enc_dev)

    with torch.no_grad():
        # Lyrics are conditioned on raw token embeddings only (no encoder pass)
        hs = text_encoder.embed_tokens(ids).to(dtype)
    return hs, mask

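# Shape note (illustrative): with the max_length values above, each sample's
# conditioning tensors are fixed-size regardless of caption/lyric length --
# text_hs is [1, 256, H] and lyric_hs is [1, 512, H], where H is the hidden
# width of the Qwen3 embedding model (not shown in this excerpt). The caption
# goes through the full encoder (last_hidden_state), while lyrics skip the
# transformer pass and use the raw embedding table (embed_tokens).
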
# ============================================================================
# VAE TILED ENCODING
# ============================================================================

def tiled_vae_encode(
    vae, audio: torch.Tensor, dtype: torch.dtype,
    chunk_size: Optional[int] = None, overlap: int = 96000,
) -> torch.Tensor:
    vae_device = next(vae.parameters()).device
    vae_dtype = vae.dtype

    if chunk_size is None:
        chunk_size = TARGET_SR * 30

    B, C, S = audio.shape

    # Short clips fit in a single encode call
    if S <= chunk_size:
        vae_input = audio.to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            latents = vae.encode(vae_input).latent_dist.sample()
        return latents.transpose(1, 2).to(dtype)

    stride = chunk_size - 2 * overlap
    if stride <= 0:
        raise ValueError(f"chunk_size ({chunk_size}) must be > 2 * overlap ({overlap})")

    num_steps = math.ceil(S / stride)
    ds_factor = None  # samples per latent frame, measured from the first chunk
    write_pos = 0
    final = None

    for i in range(num_steps):
        # Each window = core region plus overlap context on both sides
        core_start = i * stride
        core_end = min(core_start + stride, S)
        win_start = max(0, core_start - overlap)
        win_end = min(S, core_end + overlap)

        chunk = audio[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            lat = vae.encode(chunk).latent_dist.sample()

        if ds_factor is None:
            ds_factor = chunk.shape[-1] / lat.shape[-1]
            total_len = int(round(S / ds_factor))
            final = torch.zeros(B, lat.shape[1], total_len, dtype=lat.dtype, device="cpu")

        # Trim the overlap context, keep only the core latents
        trim_start = int(round((core_start - win_start) / ds_factor))
        trim_end = int(round((win_end - core_end) / ds_factor))
        end_idx = lat.shape[-1] - trim_end if trim_end > 0 else lat.shape[-1]
        core = lat[:, :, trim_start:end_idx]
        core_len = core.shape[-1]
        final[:, :, write_pos:write_pos + core_len] = core.cpu()
        write_pos += core_len
        del chunk, lat, core

    final = final[:, :, :write_pos]
    return final.transpose(1, 2).to(dtype)

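# Worked example (hypothetical numbers): if TARGET_SR were 48_000, the default
# window would be chunk_size = 48_000 * 30 = 1_440_000 samples (30 s) with
# overlap = 96_000 samples (2 s) on each side, giving
# stride = 1_440_000 - 2 * 96_000 = 1_248_000 samples (26 s) of "core" audio
# written per window. Overlap regions are encoded twice but trimmed away
# (trim_start / trim_end above), so every latent frame is produced with full
# acoustic context on both sides.
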
# ============================================================================
# ENCODER / CONTEXT HELPERS
# ============================================================================

def run_encoder(
    model, text_hs, text_mask, lyric_hs, lyric_mask, device, dtype,
):
    # No reference audio during training: feed a zeroed acoustic packet
    refer = torch.zeros(1, 1, 64, device=device, dtype=dtype)
    order_mask = torch.zeros(1, device=device, dtype=torch.long)

    with torch.no_grad():
        enc_hs, enc_mask = model.encoder(
            text_hidden_states=text_hs,
            text_attention_mask=text_mask,
            lyric_hidden_states=lyric_hs,
            lyric_attention_mask=lyric_mask,
            refer_audio_acoustic_hidden_states_packed=refer,
            refer_audio_order_mask=order_mask,
        )
    return enc_hs, enc_mask


def build_context_latents(silence_latent, latent_length: int, device, dtype):
    src = silence_latent[:, :latent_length, :].to(dtype)
    if src.shape[0] < 1:
        src = src.expand(1, -1, -1)
    if src.shape[1] < latent_length:
        # Pad with leading silence frames if the latent is shorter than target
        pad_len = latent_length - src.shape[1]
        src = torch.cat([src, silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype)], dim=1)
    elif src.shape[1] > latent_length:
        src = src[:, :latent_length, :]
    masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
    return torch.cat([src, masks], dim=-1)

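# Shape note (assuming the 64-dim acoustic latent implied by the mask width
# above): build_context_latents returns [1, latent_length, 128] -- the first
# 64 channels are silence latents standing in for an empty edit context, the
# last 64 are an all-ones mask block concatenated on the channel axis.
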
# ============================================================================
# AUDIO DISCOVERY
# ============================================================================

def _discover_audio_files(audio_dir: str) -> List[Path]:
    files = []
    for root, _, names in os.walk(audio_dir):
        for name in sorted(names):
            if Path(name).suffix.lower() in AUDIO_EXTENSIONS:
                files.append(Path(root) / name)
    return files


def _detect_max_duration(files: List[Path]) -> float:
    """Return the longest duration among the first 50 files (capped at MAX_AUDIO_DURATION)."""
    max_dur = 0.0
    try:
        import soundfile as sf
        for f in files[:50]:
            try:
                info = sf.info(str(f))
                max_dur = max(max_dur, info.duration)
            except Exception:
                pass
    except ImportError:
        pass
    return min(max_dur if max_dur > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION)

# ============================================================================
# PREPROCESSING (2-pass sequential)
# ============================================================================

def preprocess_audio(
    audio_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    device: str = "cpu",
    variant: str = "base",
    max_duration: float = 0,
    progress_callback: Optional[Callable] = None,
    cancel_check: Optional[Callable] = None,
) -> Dict[str, Any]:
    """2-pass sequential preprocessing.

    Pass 1: Load VAE + text encoder, encode audio + text, save intermediates.
    Pass 2: Load DIT model, run encoder, build context, save final .pt files.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Clean orphaned staging files
    for orphan in out.glob("*.__writing__"):
        try:
            orphan.unlink()
        except OSError:
            pass

    audio_files = _discover_audio_files(audio_dir)
    if not audio_files:
        return {"processed": 0, "failed": 0, "total": 0, "output_dir": str(out)}

    total = len(audio_files)

    if max_duration <= 0:
        max_duration = _detect_max_duration(audio_files)

    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    # ---- Pass 1: VAE + Text Encoder ----
    logger.info("Pass 1/2: Loading VAE + Text Encoder...")
    vae = load_vae(checkpoint_dir, device)
    tokenizer, text_enc = load_text_encoder(checkpoint_dir, device)
    silence_lat = load_silence_latent(checkpoint_dir, device, variant=variant)

    intermediates: List[Path] = []
    p1_failed = 0

    try:
        for i, af in enumerate(audio_files):
            if cancel_check and cancel_check():
                break

            stem = af.stem
            final_pt = out / f"{stem}.pt"
            if final_pt.exists():
                continue  # already preprocessed on a previous run

            try:
                audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration)
                audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)

                with torch.no_grad():
                    target_latents = tiled_vae_encode(vae, audio, dtype)
                del audio

                if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
                    p1_failed += 1
                    del target_latents
                    continue

                lat_len = target_latents.shape[1]
                att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)

                caption = af.stem
                lyrics = "[Instrumental]"
                text_prompt = caption

                with torch.no_grad():
                    text_hs, text_mask = encode_text(text_enc, tokenizer, text_prompt, device, dtype)
                    lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)

                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in [text_hs, lyric_hs]
                )
                if has_bad:
                    p1_failed += 1
                    del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                    continue

                tmp_path = out / f"{stem}.tmp.pt"
                torch.save({
                    "target_latents": target_latents.squeeze(0).cpu(),
                    "attention_mask": att_mask.squeeze(0).cpu(),
                    "text_hidden_states": text_hs.cpu(),
                    "text_attention_mask": text_mask.cpu(),
                    "lyric_hidden_states": lyric_hs.cpu(),
                    "lyric_attention_mask": lyric_mask.cpu(),
                    "silence_latent": silence_lat.cpu(),
                    "latent_length": lat_len,
                    "metadata": {
                        "audio_path": str(af),
                        "filename": af.name,
                        "caption": caption,
                        "lyrics": lyrics,
                    },
                }, tmp_path)

                del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                intermediates.append(tmp_path)

                if progress_callback:
                    progress_callback(i + 1, total, f"[Pass 1] {af.name}")

            except Exception as exc:
                p1_failed += 1
                logger.error("Pass 1 FAIL %s: %s", af.name, exc)
    finally:
        logger.info("Unloading VAE + Text Encoder...")
        unload_models(vae, text_enc, tokenizer, silence_lat)

    # ---- Pass 2: DIT Encoder ----
    if not intermediates:
        return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)}

    logger.info("Pass 2/2: Loading DIT model (variant=%s)...", variant)
    model = load_model_for_training(checkpoint_dir, variant, device)

    processed = 0
    p2_failed = 0
    p2_total = len(intermediates)

    try:
        for i, tmp_path in enumerate(intermediates):
            if cancel_check and cancel_check():
                break

            try:
                data = torch.load(str(tmp_path), weights_only=True)
                m_device = next(model.parameters()).device
                m_dtype = next(model.parameters()).dtype

                text_hs = data["text_hidden_states"].to(m_device, dtype=m_dtype)
                text_mask = data["text_attention_mask"].to(m_device, dtype=m_dtype)
                lyric_hs = data["lyric_hidden_states"].to(m_device, dtype=m_dtype)
                lyric_mask = data["lyric_attention_mask"].to(m_device, dtype=m_dtype)
                silence_lat = data["silence_latent"].to(m_device, dtype=m_dtype)
                lat_len = data["latent_length"]

                enc_hs, enc_mask = run_encoder(
                    model, text_hs, text_mask, lyric_hs, lyric_mask,
                    str(m_device), m_dtype,
                )
                del text_hs, text_mask, lyric_hs, lyric_mask

                if silence_lat.dim() == 2:
                    silence_lat = silence_lat.unsqueeze(0)
                ctx = build_context_latents(silence_lat, lat_len, str(m_device), m_dtype)
                del silence_lat

                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in [enc_hs, ctx]
                )
                if has_bad:
                    p2_failed += 1
                    del enc_hs, enc_mask, ctx, data
                    continue

                # Atomic write: stage then rename, so readers never see partials
                base_name = tmp_path.name.replace(".tmp.pt", ".pt")
                final_path = out / base_name
                staging_path = out / (base_name + ".__writing__")

                torch.save({
                    "target_latents": data["target_latents"],
                    "attention_mask": data["attention_mask"],
                    "encoder_hidden_states": enc_hs.squeeze(0).cpu(),
                    "encoder_attention_mask": enc_mask.squeeze(0).cpu(),
                    "context_latents": ctx.squeeze(0).cpu(),
                    "metadata": data.get("metadata", {}),
                }, staging_path)
                os.replace(staging_path, final_path)

                del enc_hs, enc_mask, ctx, data
                tmp_path.unlink(missing_ok=True)
                processed += 1

                if progress_callback:
                    progress_callback(i + 1, p2_total, f"[Pass 2] {tmp_path.stem}")

            except Exception as exc:
                p2_failed += 1
                logger.error("Pass 2 FAIL %s: %s", tmp_path.stem, exc)
    finally:
        logger.info("Unloading DIT model...")
        unload_models(model)

    failed = p1_failed + p2_failed
    return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}

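# Usage sketch (illustrative; paths are hypothetical): preprocessing runs once
# per dataset, and the resulting .pt files are what train_lora_generator
# consumes. Re-running skips any stem that already has a final .pt.
#
#     stats = preprocess_audio(
#         audio_dir="/app/datasets/raw",
#         output_dir="/app/datasets/preprocessed",
#         checkpoint_dir="/app/checkpoints",
#         device="cpu",
#         variant="base",
#         progress_callback=lambda i, n, msg: print(f"{i}/{n} {msg}"),
#     )
#     print(stats)  # {"processed": ..., "failed": ..., "total": ..., "output_dir": ...}
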
# ============================================================================
# TRAINING LOOP (generator for Gradio compatibility)
# ============================================================================

def train_lora_generator(
    dataset_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    epochs: int = 1000,
    lr: float = 3e-4,
    rank: int = 64,
    alpha: int = 128,
    dropout: float = 0.1,
    batch_size: int = 1,
    gradient_accumulation_steps: int = 4,
    warmup_steps: int = 100,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    save_every_n_epochs: int = 50,
    seed: int = 42,
    variant: str = "base",
    device: str = "cpu",
    cfg_ratio: float = 0.15,
    timestep_mu: float = -0.4,
    timestep_sigma: float = 1.0,
    target_modules: Optional[List[str]] = None,
    log_every: int = 10,
    resume_from: Optional[str] = None,
) -> Generator[str, None, None]:
    """Run LoRA training, yielding progress strings each epoch.

    This is a generator for Gradio live-update compatibility.
    Call cancel_training() to stop after the current epoch.
    """
    _training_cancel.clear()
    train_start = time.time()

    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    ds_path = Path(dataset_dir)
    if not ds_path.is_dir():
        yield f"[FAIL] Dataset directory not found: {ds_path}"
        return

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    yield "[INFO] Loading model..."

    try:
        model = load_model_for_training(checkpoint_dir, variant, device)
    except Exception as exc:
        yield f"[FAIL] Model load failed: {exc}"
        return

    # float32 on CPU (bfloat16 deadlocks)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    model = model.to(dtype=dtype)

    yield "[INFO] Injecting LoRA..."

    lora_cfg = LoRAConfig(
        r=rank, alpha=alpha, dropout=dropout,
        target_modules=target_modules, bias="none",
    )

    try:
        model, info = inject_lora(model, lora_cfg)
    except Exception as exc:
        yield f"[FAIL] LoRA injection failed: {exc}"
        unload_models(model)
        return

    yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params"

    # Gradient checkpointing + cache disable
    force_disable_cache(model.decoder)
    ckpt_ok = enable_gradient_checkpointing(model.decoder)
    force_input_grads = ckpt_ok
    if ckpt_ok:
        yield "[INFO] Gradient checkpointing enabled"

    # Dataset
    dataset = TensorDataset(dataset_dir)
    if len(dataset) == 0:
        yield "[FAIL] No valid .pt files found in dataset directory"
        unload_models(model)
        return

    yield f"[OK] Loaded {len(dataset)} preprocessed samples"

    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=0, collate_fn=_collate_batch, drop_last=False,
    )

    # Optimizer & scheduler
    torch.manual_seed(seed)
    random.seed(seed)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        yield "[FAIL] No trainable parameters found"
        unload_models(model)
        return

    optimizer = build_optimizer(trainable_params, lr=lr, weight_decay=weight_decay)
    steps_per_epoch = max(1, math.ceil(len(loader) / gradient_accumulation_steps))
    total_steps = steps_per_epoch * epochs
    scheduler = build_scheduler(optimizer, total_steps, warmup_steps, lr)

    yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
    yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"

    # Null condition embedding for CFG dropout
    null_cond = getattr(model, "null_condition_emb", None)

    # Resume checkpoint
    start_epoch = 0
    global_step = 0

    if resume_from and Path(resume_from).exists():
        try:
            yield f"[INFO] Resuming from {resume_from}"
            ckpt_dir = Path(resume_from)
            if ckpt_dir.is_file():
                ckpt_dir = ckpt_dir.parent

            # Load adapter weights
            aw = ckpt_dir / "adapter_model.safetensors"
            if aw.exists():
                from safetensors.torch import load_file
                state = load_file(str(aw))
                decoder = model.decoder
                while hasattr(decoder, "_forward_module"):
                    decoder = decoder._forward_module  # unwrap any wrapper module
                decoder.load_state_dict(state, strict=False)

            # Load training state
            ts = ckpt_dir / "training_state.pt"
            if ts.exists():
                tstate = torch.load(str(ts), map_location=device, weights_only=True)
                start_epoch = tstate.get("epoch", 0)
                global_step = tstate.get("global_step", 0)
                if "optimizer_state_dict" in tstate:
                    try:
                        optimizer.load_state_dict(tstate["optimizer_state_dict"])
                    except Exception:
                        pass
                if "scheduler_state_dict" in tstate:
                    try:
                        scheduler.load_state_dict(tstate["scheduler_state_dict"])
                    except Exception:
                        pass

            yield f"[OK] Resumed from epoch {start_epoch}, step {global_step}"
        except Exception as exc:
            yield f"[WARN] Checkpoint load failed: {exc}, starting fresh"
            start_epoch = 0
            global_step = 0

    # Training loop
    model.decoder.train()
    acc_step = 0
    acc_loss = 0.0
    optimizer.zero_grad(set_to_none=True)

    best_loss = float("inf")
    best_epoch = 0
    consecutive_nan = 0
    MAX_NAN = 10
    avg_epoch_loss = 0.0
    num_updates = 0  # guards the final-save path if the loop never runs

    for epoch in range(start_epoch, epochs):
        # Cancel check
        if _training_cancel.is_set():
            _training_cancel.clear()
            early_path = str(out_path / "early_exit")
            model.decoder.eval()
            save_lora_adapter(model, early_path)
            model.decoder.train()
            yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
            yield "[DONE]"
            unload_models(model)
            return

        # Timeout check
        elapsed = time.time() - train_start
        if elapsed > MAX_TRAINING_TIME:
            early_path = str(out_path / "timeout_exit")
            model.decoder.eval()
            save_lora_adapter(model, early_path)
            yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
            yield "[DONE]"
            unload_models(model)
            return

        epoch_loss = 0.0
        num_updates = 0
        epoch_start = time.time()

        for batch in loader:
            # Forward (non_blocking transfers only help off-CPU)
            nb = device != "cpu"
            tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb)
            att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb)
            enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb)
            enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
            ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)

            bsz = tgt.shape[0]

            # CFG dropout
            if null_cond is not None and cfg_ratio > 0:
                enc_hs = apply_cfg_dropout(enc_hs, null_cond, cfg_ratio)

            # Timestep sampling
            t, _r = sample_timesteps(bsz, torch.device(device), dtype, timestep_mu, timestep_sigma)

            # Flow matching: interpolate clean latents (x0) toward noise (x1);
            # the model is trained to predict the constant velocity x1 - x0
            x1 = torch.randn_like(tgt)
            x0 = tgt
            t_ = t.unsqueeze(-1).unsqueeze(-1)
            xt = t_ * x1 + (1.0 - t_) * x0

            if force_input_grads:
                xt = xt.requires_grad_(True)

            # Decoder forward
            dec_out = model.decoder(
                hidden_states=xt,
                timestep=t,
                timestep_r=t,
                attention_mask=att,
                encoder_hidden_states=enc_hs,
                encoder_attention_mask=enc_mask,
                context_latents=ctx,
            )

            flow = x1 - x0
            loss = F.mse_loss(dec_out[0], flow)
            loss = loss.float()  # fp32 for stable backward

            # NaN guard
            if torch.isnan(loss) or torch.isinf(loss):
                consecutive_nan += 1
                del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
                if consecutive_nan >= MAX_NAN:
                    yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
                    unload_models(model)
                    return
                if acc_step > 0:
                    # Discard the partially accumulated gradients
                    optimizer.zero_grad(set_to_none=True)
                    acc_loss = 0.0
                    acc_step = 0
                continue
            consecutive_nan = 0

            loss = loss / gradient_accumulation_steps
            loss.backward()
            acc_loss += loss.item()
            del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
            acc_step += 1

            if acc_step >= gradient_accumulation_steps:
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                optimizer.step()
                scheduler.step()
                global_step += 1

                avg_loss = acc_loss * gradient_accumulation_steps / acc_step

                if global_step % log_every == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    yield (
                        f"Epoch {epoch + 1}/{epochs}, "
                        f"Step {global_step}, "
                        f"Loss: {avg_loss:.4f}, "
                        f"LR: {current_lr:.2e}"
                    )

                optimizer.zero_grad(set_to_none=True)
                epoch_loss += avg_loss
                num_updates += 1
                acc_loss = 0.0
                acc_step = 0

        # Flush remainder
        if acc_step > 0:
            torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
            optimizer.step()
            scheduler.step()
            global_step += 1
            avg_loss = acc_loss * gradient_accumulation_steps / acc_step
            optimizer.zero_grad(set_to_none=True)
            epoch_loss += avg_loss
            num_updates += 1
            acc_loss = 0.0
            acc_step = 0

        epoch_time = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / max(num_updates, 1)

        is_best = avg_epoch_loss < best_loss - 0.001
        if is_best:
            best_loss = avg_epoch_loss
            best_epoch = epoch + 1

        best_str = f" (best: {best_loss:.4f} @ ep{best_epoch})" if best_epoch > 0 else ""
        yield (
            f"[OK] Epoch {epoch + 1}/{epochs} in {epoch_time:.1f}s, "
            f"Loss: {avg_epoch_loss:.4f}{best_str}"
        )

        # Save best
        if is_best and epoch + 1 >= 10:
            best_path = str(out_path / "best")
            model.decoder.eval()
            save_lora_adapter(model, best_path)
            model.decoder.train()
            yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})"

        # Periodic checkpoint
        if (epoch + 1) % save_every_n_epochs == 0:
            ckpt_path = str(out_path / "checkpoints" / f"epoch_{epoch + 1}")
            model.decoder.eval()
            save_lora_adapter(model, ckpt_path)

            tstate = {
                "epoch": epoch + 1,
                "global_step": global_step,
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }
            os.makedirs(ckpt_path, exist_ok=True)
            torch.save(tstate, str(Path(ckpt_path) / "training_state.pt"))
            model.decoder.train()
            yield f"[OK] Checkpoint saved at epoch {epoch + 1}"

    # Sanity check
    if global_step == 0:
        yield "[FAIL] Training completed 0 steps -- no batches processed"
        unload_models(model)
        return

    # Final save
    final_path = str(out_path / "final")
    model.decoder.eval()
    save_lora_adapter(model, final_path)

    final_loss = avg_epoch_loss if num_updates > 0 else 0.0
    best_note = ""
    if best_epoch > 0 and Path(out_path / "best").exists():
        best_note = f"\n  Best: {out_path / 'best'} (epoch {best_epoch}, loss: {best_loss:.4f})"
    yield (
        f"[OK] Training complete! Final loss: {final_loss:.4f}, LoRA saved to {final_path}{best_note}\n"
        f"  For inference, set your LoRA path to: {final_path}"
    )
    yield "[DONE]"
    unload_models(model)

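# Usage sketch (illustrative; paths are hypothetical): the generator yields
# one status string per event, so a Gradio callback or a plain CLI loop can
# stream progress and watch for the "[DONE]" sentinel. Setting the module's
# _training_cancel event (via cancel_training(), per the docstring) stops
# after the current epoch and still saves an "early_exit" adapter.
#
#     for line in train_lora_generator(
#         dataset_dir="/app/datasets/preprocessed",
#         output_dir="/app/adapters/my_lora",
#         checkpoint_dir="/app/checkpoints",
#         epochs=2, device="cpu",
#     ):
#         print(line)
#         if line.startswith("[FAIL]"):
#             break
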
# ============================================================================
# ADAPTER LISTING
# ============================================================================

def get_trained_loras(adapter_dir: str) -> List[str]:
    """List all saved LoRA adapter directories under adapter_dir."""
    result = []
    base = Path(adapter_dir)
    if not base.is_dir():
        return result

    for root, dirs, files in os.walk(str(base)):
        for f in files:
            if f in ("adapter_config.json", "adapter_model.safetensors", "lora_weights.pt"):
                result.append(root)
                break

    return sorted(set(result))

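# Usage sketch (illustrative; the path is hypothetical): any directory
# containing adapter_config.json, adapter_model.safetensors, or lora_weights.pt
# is treated as a selectable LoRA, which is presumably how the UI builds its
# adapter dropdown.
#
#     for path in get_trained_loras("/app/adapters"):
#         print(path)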