Nekochu commited on
Commit
a07b39d
·
1 Parent(s): 5e95353

Side-Step training engine, tested locally on CPU

Browse files

- train_engine.py: 1347-line standalone LoRA training (ported from Side-Step)
- Adafactor optimizer, 2-pass sequential preprocessing, FA->SDPA->eager fallback
- Generator-based progress (yield per epoch), cancel button, checkpoint saves
- Pin transformers<4.58.0 (meta tensor fix), skip device_map on CPU
- Tested locally: preprocessing 174s + training 51s (2 epochs, 10s audio)

Files changed (3) hide show
  1. Dockerfile +1 -1
  2. app.py +333 -168
  3. train_engine.py +1353 -0
Dockerfile CHANGED
@@ -72,7 +72,7 @@ RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
72
  # Install Python deps for Gradio UI + training
73
  RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
74
  "gradio[mcp]==5.29.0" requests torch safetensors \
75
- transformers>=4.51.0 peft>=0.18.0 \
76
  loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
77
  einops vector_quantize_pytorch
78
 
 
72
  # Install Python deps for Gradio UI + training
73
  RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
74
  "gradio[mcp]==5.29.0" requests torch safetensors \
75
+ "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
76
  loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
77
  einops vector_quantize_pytorch
78
 
app.py CHANGED
@@ -6,7 +6,31 @@ import time
6
  import json
7
  import argparse
8
  import tempfile
 
 
9
  import requests
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
12
  OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
@@ -16,6 +40,12 @@ ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
16
  ACE_SOURCE_DIR = "/app/ace-step-source"
17
  ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
18
  ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
 
 
 
 
 
 
19
 
20
  # ---------------------------------------------------------------------------
21
  # ace-server helpers
@@ -68,7 +98,7 @@ def _fetch_result(job_id, timeout=60):
68
 
69
 
70
  def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
71
- adapter=None, progress_cb=None):
72
  """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises."""
73
  t0 = time.time()
74
 
@@ -86,6 +116,8 @@ def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
86
  req["inference_steps"] = int(steps)
87
  if adapter:
88
  req["adapter"] = adapter
 
 
89
 
90
  fmt = output_format if output_format in ("wav", "mp3") else "mp3"
91
  synth_fmt = "wav16" if fmt == "wav" else "mp3"
@@ -143,6 +175,98 @@ def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
143
  return tmp.name, msg
144
 
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  # ---------------------------------------------------------------------------
147
  # CLI mode
148
  # ---------------------------------------------------------------------------
@@ -216,7 +340,6 @@ def cli_main():
216
 
217
  # Move to requested output path if specified
218
  if args.output:
219
- import shutil
220
  out_dir = os.path.dirname(os.path.abspath(args.output))
221
  os.makedirs(out_dir, exist_ok=True)
222
  shutil.move(audio_path, args.output)
@@ -232,18 +355,15 @@ def cli_main():
232
 
233
  def gradio_main():
234
  import gradio as gr
 
235
 
236
- # -- Generate tab handler --
237
- def get_trained_loras():
238
- loras = ["None (no LoRA)"]
239
- if os.path.isdir(ADAPTER_DIR):
240
- for d in os.listdir(ADAPTER_DIR):
241
- if os.path.isdir(os.path.join(ADAPTER_DIR, d)):
242
- loras.append(d)
243
- return loras
244
 
 
245
  def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
246
- steps, lora_select, progress=gr.Progress(track_tqdm=True)):
 
247
  if not _server_ok():
248
  return None, "ace-server not running. Check logs."
249
 
@@ -252,6 +372,7 @@ def gradio_main():
252
 
253
  actual_seed = None if seed is None or int(seed) < 0 else int(seed)
254
  adapter = None if lora_select == "None (no LoRA)" else lora_select
 
255
 
256
  progress_map = {
257
  "lm_submit": (0.05, "Submitting LM job..."),
@@ -277,6 +398,7 @@ def gradio_main():
277
  steps=steps,
278
  output_format="mp3",
279
  adapter=adapter,
 
280
  progress_cb=gr_progress,
281
  )
282
  return audio_path, status
@@ -295,22 +417,35 @@ def gradio_main():
295
  lines.append(json.dumps(props, indent=2))
296
  return "\n".join(lines)
297
 
298
- # -- Training (runs as detached subprocess to survive Gradio session timeout) --
299
- TRAIN_LOG = "/app/outputs/train.log"
 
 
 
 
 
 
 
 
300
 
301
- def train_lora(audio_files, lora_name, epochs, lr, rank,
302
- progress=gr.Progress(track_tqdm=True)):
303
- import shutil, subprocess
304
 
 
305
  if not audio_files:
306
- return "No audio files uploaded."
 
 
307
 
308
- if os.path.exists(TRAIN_LOG):
309
- last_line = open(TRAIN_LOG).readlines()[-1] if os.path.getsize(TRAIN_LOG) > 0 else ""
310
- if "DONE" not in last_line and "ERROR" not in last_line and last_line.strip():
311
- return f"Training already in progress. Click 'Check Log' to monitor.\n\nLast: {last_line.strip()}"
312
 
313
  lora_name = (lora_name or "").strip() or "my-lora"
 
 
 
314
  epochs = max(1, min(int(epochs), 10))
315
  lr = float(lr)
316
  rank = max(1, min(int(rank), 64))
@@ -319,145 +454,136 @@ def gradio_main():
319
  os.makedirs(output_dir, exist_ok=True)
320
  audio_dir = os.path.join(output_dir, "audio_input")
321
  os.makedirs(audio_dir, exist_ok=True)
 
 
 
 
 
322
  for f in audio_files:
323
  src = f.name if hasattr(f, "name") else str(f)
324
  shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
325
 
326
- train_script = f"""
327
- import os, sys, time, gc
328
- sys.path.insert(0, "{ACE_SOURCE_DIR}")
329
- os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "1"
330
 
331
- LOG = "{TRAIN_LOG}"
332
- def log(msg):
333
- print(f"[train] {{msg}}", flush=True)
334
- with open(LOG, "a") as f:
335
- f.write(msg + "\\n")
336
- f.flush()
337
 
338
- open(LOG, "w").close()
339
- log("LoRA Training: '{lora_name}' | files={len(audio_files)} | epochs={epochs} lr={lr} rank={rank}")
340
-
341
- import subprocess
342
- log("Stopping ace-server...")
343
- subprocess.run(["pkill", "-f", "ace-server"], stderr=subprocess.DEVNULL)
344
- time.sleep(2)
345
- gc.collect()
346
-
347
- try:
348
- import torch
349
- torch.backends.cuda.enable_flash_sdp(False)
350
- os.environ["ATTN_BACKEND"] = "sdpa"
351
-
352
- import torchaudio
353
- _orig = torchaudio.load
354
- def _sf(p, *a, **kw):
355
- kw.setdefault("backend", "soundfile")
356
- return _orig(p, *a, **kw)
357
- torchaudio.load = _sf
358
-
359
- log("[Step 1/2] Preprocessing audio...")
360
- log(" importing preprocess module...")
361
- from acestep.training_v2.preprocess import preprocess_audio_files
362
- log(" import done, calling preprocess_audio_files...")
363
- import resource
364
- mem_before = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // 1024
365
- log(f" RAM before preprocess: {{mem_before}} MB")
366
- result = preprocess_audio_files(
367
- audio_dir="{audio_dir}",
368
- output_dir="{output_dir}/preprocessed_tensors",
369
- checkpoint_dir="{ACE_CHECKPOINT_DIR}",
370
- variant="turbo", max_duration=60.0,
371
- device="cpu", precision="float32",
372
- )
373
- mem_after = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss // 1024
374
- log(f" RAM after preprocess: {{mem_after}} MB")
375
- processed = result.get("processed", 0)
376
- failed = result.get("failed", 0)
377
- log(f" Preprocessed: {{processed}}/{{result.get('total',0)}} (failed: {{failed}})")
378
- if processed == 0:
379
- log("ERROR: No files preprocessed. DONE")
380
- raise SystemExit(1)
381
-
382
- gc.collect()
383
- log("[Step 2/2] Training LoRA...")
384
- from acestep.training_v2.model_loader import load_decoder_for_training
385
- from acestep.training_v2.trainer_fixed import FixedLoRATrainer
386
- from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
387
-
388
- log(" Loading decoder (attn_implementation=sdpa)...")
389
- model = load_decoder_for_training(
390
- checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
391
- device="cpu", precision="float32",
392
- ).float()
393
- for m in model.modules():
394
- if hasattr(m, 'config') and hasattr(m.config, '_attn_implementation'):
395
- m.config._attn_implementation = "sdpa"
396
- log(" Decoder loaded, applying LoRA...")
397
-
398
- trainer = FixedLoRATrainer(model,
399
- LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),
400
- TrainingConfigV2(
401
- checkpoint_dir="{ACE_CHECKPOINT_DIR}", model_variant="turbo",
402
- dataset_dir="{output_dir}/preprocessed_tensors",
403
- output_dir="{output_dir}",
404
- max_epochs={epochs}, batch_size=1, learning_rate={lr},
405
- device="cpu", precision="float32", seed=42,
406
- num_workers=0, pin_memory=False,
407
- ))
408
-
409
- step_count, last_loss = 0, 0.0
410
- for update in trainer.train():
411
- if hasattr(update, "step"):
412
- step_count, last_loss = update.step, update.loss
413
- elif isinstance(update, tuple) and len(update) >= 2:
414
- step_count, last_loss = update[0], update[1]
415
- if step_count % 5 == 0:
416
- log(f" Step {{step_count}}: loss={{last_loss:.4f}}")
417
-
418
- log(f"Training complete! step={{step_count}} loss={{last_loss:.4f}}")
419
- log(f"LoRA saved to: {output_dir}")
420
- del model, trainer
421
- gc.collect()
422
- log("DONE")
423
-
424
- except Exception as e:
425
- import traceback
426
- log(f"ERROR: {{e}}")
427
- log(traceback.format_exc())
428
- log("DONE")
429
- finally:
430
- log("Restarting ace-server...")
431
- subprocess.Popen(["/app/ace-server", "--host", "127.0.0.1", "--port", "8085",
432
- "--models", "/app/models", "--adapters", "/app/adapters", "--max-batch", "1"],
433
- stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
434
- """
435
- script_path = os.path.join(output_dir, "_train.py")
436
- with open(script_path, "w") as f:
437
- f.write(train_script)
438
-
439
- train_stderr = os.path.join(output_dir, "train_stderr.log")
440
- subprocess.Popen(
441
- ["python3", "-u", script_path],
442
- stdout=open(TRAIN_LOG, "a"),
443
- stderr=open(train_stderr, "w"),
444
- start_new_session=True,
445
- )
446
-
447
- return (f"Training started in background for '{lora_name}'.\n"
448
- f"Audio: {len(audio_files)} files, Epochs: {epochs}, Rank: {rank}\n\n"
449
- f"Click 'Check Log' to monitor progress.\n"
450
- f"Inference will be unavailable until training completes (ace-server stopped).")
451
-
452
- def check_train_log():
453
- parts = []
454
- if os.path.exists(TRAIN_LOG):
455
- parts.append(open(TRAIN_LOG).read())
456
- stderr_log = os.path.join(ADAPTER_DIR, "test-lora", "train_stderr.log")
457
- if os.path.exists(stderr_log) and os.path.getsize(stderr_log) > 0:
458
- stderr = open(stderr_log).read()[-8000:]
459
- parts.append(f"\n--- stderr (last 8000 chars, {os.path.getsize(stderr_log)} bytes total) ---\n{stderr}")
460
- return "\n".join(parts) if parts else "No training log found."
461
 
462
  # -- Build UI --
463
  CSS = """
@@ -512,7 +638,16 @@ finally:
512
  value=8, step=1, scale=1,
513
  )
514
  seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
515
- lora_select = gr.Dropdown(label="LoRA", choices=get_trained_loras(), value="None (no LoRA)", scale=1)
 
 
 
 
 
 
 
 
 
516
 
517
  with gr.Row(elem_classes="compact-row"):
518
  gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
@@ -521,7 +656,7 @@ finally:
521
  gen_btn.click(
522
  fn=generate_music,
523
  inputs=[caption, lyrics, instrumental, bpm, duration,
524
- seed, steps, lora_select],
525
  outputs=[audio_out, status],
526
  api_name="generate",
527
  )
@@ -539,8 +674,8 @@ finally:
539
  with gr.Tab("Train LoRA"):
540
  gr.Markdown(
541
  "### LoRA Training\n"
542
- "Fine-tune ACE-Step on your own audio data. "
543
- "CPU training is very slow. Checkpoints downloaded on first run (~10GB)."
544
  )
545
 
546
  with gr.Row(elem_classes="compact-row"):
@@ -552,13 +687,21 @@ finally:
552
  )
553
  with gr.Column(scale=1):
554
  lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
555
- epochs = gr.Number(label="Epochs", value=5, minimum=1, maximum=10)
556
- lr = gr.Number(label="Learning Rate", value=1e-4)
557
- rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
 
 
 
 
 
 
558
 
559
  with gr.Row(elem_classes="compact-row"):
560
  train_btn = gr.Button("Train", variant="primary", scale=2)
 
561
  log_btn = gr.Button("Check Log", scale=1)
 
562
  train_log = gr.Textbox(
563
  label="Training Log",
564
  interactive=False,
@@ -566,17 +709,39 @@ finally:
566
  elem_classes="status-box",
567
  )
568
 
569
- train_btn.click(
570
- fn=train_lora,
571
- inputs=[train_audio, lora_name, epochs, lr, rank],
572
- outputs=[train_log],
 
573
  api_name="train_lora",
 
574
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
575
  log_btn.click(
576
- fn=check_train_log,
577
- inputs=[],
578
  outputs=[train_log],
579
- api_name="check_train_log",
580
  )
581
 
582
  demo.launch(
 
6
  import json
7
  import argparse
8
  import tempfile
9
+ import subprocess
10
+ import shutil
11
  import requests
12
+ import logging
13
+
14
+ from train_engine import (
15
+ preprocess_audio,
16
+ train_lora_generator,
17
+ cancel_training,
18
+ get_trained_loras as _get_trained_loras_engine,
19
+ )
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ---------------------------------------------------------------------------
24
+ # Configurable limits (edit here, not buried in code)
25
+ # ---------------------------------------------------------------------------
26
+
27
+ MAX_AUDIO_DURATION = 240 # seconds, cap per audio file for training
28
+ MAX_TRAINING_TIME = 28800 # 8 hours hard training timeout (seconds)
29
+ MAX_AUDIO_FILES = 50 # max number of training audio files per run
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Paths & constants
33
+ # ---------------------------------------------------------------------------
34
 
35
  ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
36
  OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
 
40
  ACE_SOURCE_DIR = "/app/ace-step-source"
41
  ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
42
  ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
43
+ MODELS_DIR = os.environ.get("ACE_MODELS_DIR", "/app/models")
44
+
45
+ ACE_SERVER_BIN = "/app/ace-server"
46
+
47
+ # HF repo for on-demand GGUF downloads
48
+ GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
49
 
50
  # ---------------------------------------------------------------------------
51
  # ace-server helpers
 
98
 
99
 
100
  def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
101
+ adapter=None, lm_model=None, progress_cb=None):
102
  """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises."""
103
  t0 = time.time()
104
 
 
116
  req["inference_steps"] = int(steps)
117
  if adapter:
118
  req["adapter"] = adapter
119
+ if lm_model:
120
+ req["model"] = lm_model
121
 
122
  fmt = output_format if output_format in ("wav", "mp3") else "mp3"
123
  synth_fmt = "wav16" if fmt == "wav" else "mp3"
 
175
  return tmp.name, msg
176
 
177
 
178
+ # ---------------------------------------------------------------------------
179
+ # LM model scanning & on-demand download
180
+ # ---------------------------------------------------------------------------
181
+
182
+ def _scan_lm_models():
183
+ """Scan /app/models for *-lm-*.gguf files, return list of filenames."""
184
+ models = []
185
+ if os.path.isdir(MODELS_DIR):
186
+ for f in sorted(os.listdir(MODELS_DIR)):
187
+ if "-lm-" in f and f.endswith(".gguf"):
188
+ models.append(f)
189
+ return models
190
+
191
+
192
+ def _download_lm_model(filename):
193
+ """Download a GGUF LM model from HF if not already present."""
194
+ dest = os.path.join(MODELS_DIR, filename)
195
+ if os.path.isfile(dest):
196
+ return dest
197
+ try:
198
+ from huggingface_hub import hf_hub_download
199
+ path = hf_hub_download(
200
+ repo_id=GGUF_HF_REPO,
201
+ filename=filename,
202
+ local_dir=MODELS_DIR,
203
+ )
204
+ return path
205
+ except Exception as exc:
206
+ logger.error("Failed to download %s: %s", filename, exc)
207
+ return None
208
+
209
+
210
+ # ---------------------------------------------------------------------------
211
+ # LoRA listing for UI dropdowns
212
+ # ---------------------------------------------------------------------------
213
+
214
+ def _list_lora_choices():
215
+ """Return list of LoRA choices for dropdown, including 'None'."""
216
+ choices = ["None (no LoRA)"]
217
+ if os.path.isdir(ADAPTER_DIR):
218
+ for d in os.listdir(ADAPTER_DIR):
219
+ if os.path.isdir(os.path.join(ADAPTER_DIR, d)):
220
+ choices.append(d)
221
+ return choices
222
+
223
+
224
+ # ---------------------------------------------------------------------------
225
+ # ace-server stop/start helpers
226
+ # ---------------------------------------------------------------------------
227
+
228
+ _ace_proc = None
229
+
230
+ def _stop_ace_server():
231
+ """Stop ace-server process."""
232
+ global _ace_proc
233
+ if _ace_proc and _ace_proc.poll() is None:
234
+ _ace_proc.terminate()
235
+ try:
236
+ _ace_proc.wait(timeout=10)
237
+ except subprocess.TimeoutExpired:
238
+ _ace_proc.kill()
239
+ _ace_proc = None
240
+ else:
241
+ try:
242
+ subprocess.run(["pkill", "ace-server"],
243
+ stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
244
+ timeout=10)
245
+ except Exception:
246
+ pass
247
+ time.sleep(1)
248
+
249
+
250
+ def _start_ace_server():
251
+ """Start ace-server in background and wait for health."""
252
+ global _ace_proc
253
+ try:
254
+ _ace_proc = subprocess.Popen(
255
+ [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
256
+ "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
257
+ stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
258
+ )
259
+ except Exception as exc:
260
+ logger.error("Failed to start ace-server: %s", exc)
261
+ return False
262
+
263
+ for _ in range(30):
264
+ if _server_ok():
265
+ return True
266
+ time.sleep(2)
267
+ return False
268
+
269
+
270
  # ---------------------------------------------------------------------------
271
  # CLI mode
272
  # ---------------------------------------------------------------------------
 
340
 
341
  # Move to requested output path if specified
342
  if args.output:
 
343
  out_dir = os.path.dirname(os.path.abspath(args.output))
344
  os.makedirs(out_dir, exist_ok=True)
345
  shutil.move(audio_path, args.output)
 
355
 
356
  def gradio_main():
357
  import gradio as gr
358
+ import gc
359
 
360
+ # -- Persistent training log buffer (survives across yields) --
361
+ _train_log_lines = []
 
 
 
 
 
 
362
 
363
+ # -- Generate tab handler --
364
  def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
365
+ steps, lora_select, lm_model_select,
366
+ progress=gr.Progress(track_tqdm=True)):
367
  if not _server_ok():
368
  return None, "ace-server not running. Check logs."
369
 
 
372
 
373
  actual_seed = None if seed is None or int(seed) < 0 else int(seed)
374
  adapter = None if lora_select == "None (no LoRA)" else lora_select
375
+ lm_model = None if not lm_model_select or lm_model_select == "Default" else lm_model_select
376
 
377
  progress_map = {
378
  "lm_submit": (0.05, "Submitting LM job..."),
 
398
  steps=steps,
399
  output_format="mp3",
400
  adapter=adapter,
401
+ lm_model=lm_model,
402
  progress_cb=gr_progress,
403
  )
404
  return audio_path, status
 
417
  lines.append(json.dumps(props, indent=2))
418
  return "\n".join(lines)
419
 
420
+ # -- Training generator (direct integration, no subprocess) --
421
+ def train_lora_ui(audio_files, lora_name, epochs, lr, rank):
422
+ """Generator that yields (train_log, train_btn_update, cancel_btn_update)."""
423
+ import gc as _gc
424
+
425
+ _train_log_lines.clear()
426
+ train_start = time.time()
427
+
428
+ def _log(msg):
429
+ _train_log_lines.append(msg)
430
 
431
+ def _log_text():
432
+ return "\n".join(_train_log_lines)
 
433
 
434
+ # -- Validation --
435
  if not audio_files:
436
+ _log("[FAIL] No audio files uploaded.")
437
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False)
438
+ return
439
 
440
+ if len(audio_files) > MAX_AUDIO_FILES:
441
+ _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
442
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False)
443
+ return
444
 
445
  lora_name = (lora_name or "").strip() or "my-lora"
446
+ # Sanitize: alphanumeric, dash, underscore only
447
+ lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name)
448
+
449
  epochs = max(1, min(int(epochs), 10))
450
  lr = float(lr)
451
  rank = max(1, min(int(rank), 64))
 
454
  os.makedirs(output_dir, exist_ok=True)
455
  audio_dir = os.path.join(output_dir, "audio_input")
456
  os.makedirs(audio_dir, exist_ok=True)
457
+
458
+ # Copy uploaded audio files
459
+ _log(f"[INFO] Preparing {len(audio_files)} audio files...")
460
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
461
+
462
  for f in audio_files:
463
  src = f.name if hasattr(f, "name") else str(f)
464
  shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
465
 
466
+ _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
467
+ f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
468
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
 
469
 
470
+ # Stop ace-server before training (frees memory)
471
+ _log("[INFO] Stopping ace-server for training...")
472
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
473
+ _stop_ace_server()
474
+ _gc.collect()
 
475
 
476
+ try:
477
+ # -- Phase 1: Preprocessing --
478
+ _log("[Step 1/2] Preprocessing audio...")
479
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
480
+
481
+ preprocessed_dir = os.path.join(output_dir, "preprocessed_tensors")
482
+
483
+ def preprocess_progress(current, total, desc):
484
+ _log(f" {desc} ({current}/{total})")
485
+
486
+ result = preprocess_audio(
487
+ audio_dir=audio_dir,
488
+ output_dir=preprocessed_dir,
489
+ checkpoint_dir=ACE_CHECKPOINT_DIR,
490
+ device="cpu",
491
+ variant="turbo",
492
+ max_duration=float(MAX_AUDIO_DURATION),
493
+ progress_callback=preprocess_progress,
494
+ cancel_check=lambda: False,
495
+ )
496
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
497
+
498
+ processed = result.get("processed", 0)
499
+ failed = result.get("failed", 0)
500
+ total = result.get("total", 0)
501
+ _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
502
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
503
+
504
+ if processed == 0:
505
+ _log("[FAIL] No files preprocessed successfully. Cannot train.")
506
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False)
507
+ return
508
+
509
+ _gc.collect()
510
+
511
+ # -- Phase 2: Training --
512
+ _log("[Step 2/2] Training LoRA...")
513
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
514
+
515
+ for msg in train_lora_generator(
516
+ dataset_dir=preprocessed_dir,
517
+ output_dir=output_dir,
518
+ checkpoint_dir=ACE_CHECKPOINT_DIR,
519
+ epochs=epochs,
520
+ lr=lr,
521
+ rank=rank,
522
+ alpha=rank,
523
+ dropout=0.0,
524
+ batch_size=1,
525
+ gradient_accumulation_steps=4,
526
+ warmup_steps=100,
527
+ weight_decay=0.01,
528
+ max_grad_norm=1.0,
529
+ save_every_n_epochs=max(1, epochs // 2),
530
+ seed=42,
531
+ variant="turbo",
532
+ device="cpu",
533
+ log_every=5,
534
+ ):
535
+ # Timeout check
536
+ elapsed = time.time() - train_start
537
+ if elapsed > MAX_TRAINING_TIME:
538
+ _log(f"[WARN] Training timed out after {int(elapsed)}s")
539
+ cancel_training()
540
+ break
541
+
542
+ _log(msg)
543
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
544
+
545
+ if msg.strip() == "[DONE]":
546
+ break
547
+
548
+ _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
549
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
550
+
551
+ except Exception as exc:
552
+ _log(f"[FAIL] Training error: {exc}")
553
+ import traceback
554
+ _log(traceback.format_exc())
555
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False)
556
+
557
+ finally:
558
+ # Always restart ace-server
559
+ _log("[INFO] Restarting ace-server...")
560
+ yield _log_text(), gr.update(visible=False), gr.update(visible=True)
561
+ _gc.collect()
562
+ ok = _start_ace_server()
563
+ if ok:
564
+ _log("[OK] ace-server restarted successfully")
565
+ else:
566
+ _log("[WARN] ace-server may not have restarted -- check logs")
567
+ yield _log_text(), gr.update(visible=True), gr.update(visible=False)
568
+
569
+ # -- Cancel handler --
570
+ def _on_cancel():
571
+ cancel_training()
572
+ logger.info("Cancel requested by user")
573
+ return "Cancelling after current epoch... please wait"
574
+
575
+ # -- Check log handler --
576
+ def _check_log():
577
+ if _train_log_lines:
578
+ return "\n".join(_train_log_lines)
579
+ return "No training log available."
580
+
581
+ # -- Build LM model choices --
582
+ def _lm_model_choices():
583
+ models = _scan_lm_models()
584
+ choices = ["Default"]
585
+ choices.extend(models)
586
+ return choices
 
 
 
 
 
 
 
 
 
 
 
 
587
 
588
  # -- Build UI --
589
  CSS = """
 
638
  value=8, step=1, scale=1,
639
  )
640
  seed = gr.Number(label="Seed (-1=random)", value=-1, scale=1)
641
+
642
+ with gr.Row(elem_classes="compact-row"):
643
+ lora_select = gr.Dropdown(
644
+ label="LoRA", choices=_list_lora_choices(),
645
+ value="None (no LoRA)", scale=1,
646
+ )
647
+ lm_model_select = gr.Dropdown(
648
+ label="LM Model", choices=_lm_model_choices(),
649
+ value="Default", scale=1,
650
+ )
651
 
652
  with gr.Row(elem_classes="compact-row"):
653
  gen_btn = gr.Button("Generate Music", variant="primary", scale=2)
 
656
  gen_btn.click(
657
  fn=generate_music,
658
  inputs=[caption, lyrics, instrumental, bpm, duration,
659
+ seed, steps, lora_select, lm_model_select],
660
  outputs=[audio_out, status],
661
  api_name="generate",
662
  )
 
674
  with gr.Tab("Train LoRA"):
675
  gr.Markdown(
676
  "### LoRA Training\n"
677
+ "Fine-tune ACE-Step on your audio. "
678
+ "CPU training is slow -- ace-server stops during training."
679
  )
680
 
681
  with gr.Row(elem_classes="compact-row"):
 
687
  )
688
  with gr.Column(scale=1):
689
  lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
690
+ train_epochs = gr.Slider(
691
+ label="Epochs", minimum=1, maximum=10,
692
+ value=3, step=1,
693
+ )
694
+ train_lr = gr.Number(label="Learning Rate", value=1e-4)
695
+ train_rank = gr.Slider(
696
+ label="Rank (r)", minimum=1, maximum=64,
697
+ value=16, step=1,
698
+ )
699
 
700
  with gr.Row(elem_classes="compact-row"):
701
  train_btn = gr.Button("Train", variant="primary", scale=2)
702
+ cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
703
  log_btn = gr.Button("Check Log", scale=1)
704
+
705
  train_log = gr.Textbox(
706
  label="Training Log",
707
  interactive=False,
 
709
  elem_classes="status-box",
710
  )
711
 
712
+ # Training generator -- yields (log, train_btn, cancel_btn)
713
+ train_event = train_btn.click(
714
+ train_lora_ui,
715
+ inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank],
716
+ outputs=[train_log, train_btn, cancel_btn],
717
  api_name="train_lora",
718
+ concurrency_limit=1,
719
  )
720
+
721
+ # After training completes, restore buttons and refresh LoRA dropdown
722
+ def _post_training():
723
+ return (
724
+ gr.update(visible=True),
725
+ gr.update(visible=False),
726
+ gr.update(choices=_list_lora_choices()),
727
+ )
728
+
729
+ train_event.then(
730
+ _post_training,
731
+ outputs=[train_btn, cancel_btn, lora_select],
732
+ )
733
+
734
+ # Cancel: set the flag, update status
735
+ cancel_btn.click(
736
+ _on_cancel,
737
+ outputs=[train_log],
738
+ )
739
+
740
+ # Check log: show last training output
741
  log_btn.click(
742
+ _check_log,
 
743
  outputs=[train_log],
744
+ api_name="check_log",
745
  )
746
 
747
  demo.launch(
train_engine.py ADDED
@@ -0,0 +1,1353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Standalone ACE-Step CPU LoRA Training Engine.
3
+
4
+ Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
5
+ module. No external Side-Step dependency required.
6
+
7
+ Exports:
8
+ preprocess_audio() - 2-pass sequential preprocessing
9
+ train_lora_generator() - Generator-based LoRA training loop
10
+ cancel_training() - Set the cancel flag
11
+ get_trained_loras() - List saved adapters
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import gc
17
+ import json
18
+ import logging
19
+ import math
20
+ import os
21
+ import random
22
+ import sys
23
+ import time
24
+ import types
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ import torch.nn.functional as F
32
+ from torch.optim import AdamW
33
+ from torch.optim.lr_scheduler import (
34
+ CosineAnnealingLR,
35
+ LinearLR,
36
+ SequentialLR,
37
+ )
38
+ from torch.utils.data import DataLoader, Dataset
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+ # ---------------------------------------------------------------------------
43
+ # Configurable caps (edit these at the top of the file)
44
+ # ---------------------------------------------------------------------------
45
+
46
+ MAX_AUDIO_DURATION = 240.0 # seconds, cap per audio file
47
+ MAX_TRAINING_TIME = 28800 # 8 hours hard timeout
48
+ TARGET_SR = 48000
49
+ AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"})
50
+
51
+ # bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32
52
+ CPU_DTYPE = torch.float32
53
+
54
+ import threading
55
+ _training_cancel = threading.Event()
56
+
57
+
58
def cancel_training() -> None:
    """Set the module-level cancel flag (a threading.Event) so a running
    training loop can stop cooperatively."""
    _training_cancel.set()
60
+
61
+
62
+ # ============================================================================
63
+ # CONFIGS
64
+ # ============================================================================
65
+
66
@dataclass
class LoRAConfig:
    """Hyper-parameters for LoRA injection into the decoder."""

    # LoRA rank (dimension of the low-rank update matrices).
    r: int = 64
    # Scaling numerator; PEFT applies alpha / r as the effective scale.
    alpha: int = 128
    # Dropout applied to the LoRA branch during training.
    dropout: float = 0.1
    # Module-name suffixes to wrap with LoRA (attention projections).
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
    ])
    # PEFT bias handling: "none", "all", or "lora_only".
    bias: str = "none"
    # NOTE(review): attention_type and target_mlp are carried in the config
    # but not consumed by inject_lora() in this module -- presumably read by
    # a caller; confirm before removing.
    attention_type: str = "both"
    target_mlp: bool = True
77
+
78
+
79
+ # ============================================================================
80
+ # TIMESTEP SAMPLING & CFG DROPOUT
81
+ # ============================================================================
82
+
83
def sample_timesteps(
    batch_size: int,
    device: torch.device,
    dtype: torch.dtype,
    timestep_mu: float = -0.4,
    timestep_sigma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw flow-matching timesteps from a logit-normal distribution.

    Two independent logit-normal samples are drawn and element-wise maxed,
    exactly as the original (t, r) ordering scheme produced for t. Because
    use_meanflow=False forces r == t (ACE-Step convention), the same tensor
    is returned twice.

    Args:
        batch_size: Number of timesteps to draw.
        device: Device for the returned tensors.
        dtype: Dtype for the returned tensors.
        timestep_mu: Mean of the underlying normal before the sigmoid.
        timestep_sigma: Std-dev of the underlying normal.

    Returns:
        (t, r) with r identical to t; shape (batch_size,), values in (0, 1).
    """
    first = torch.sigmoid(
        torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu
    )
    second = torch.sigmoid(
        torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu
    )
    # Fix: the original also computed torch.minimum (the would-be `r`) and
    # immediately discarded it. Only the element-wise max is observable, so
    # the dead computation is dropped; RNG consumption is unchanged.
    t = torch.maximum(first, second)
    return t, t
99
+
100
+
101
def apply_cfg_dropout(
    encoder_hidden_states: torch.Tensor,
    null_condition_emb: torch.Tensor,
    cfg_ratio: float = 0.15,
) -> torch.Tensor:
    """Randomly replace conditioning with the null embedding (CFG dropout).

    Each batch element is independently swapped for the null condition with
    probability ``cfg_ratio``; otherwise it is left untouched.
    """
    batch = encoder_hidden_states.shape[0]
    dev = encoder_hidden_states.device
    dt = encoder_hidden_states.dtype
    draws = torch.rand(size=(batch,), device=dev, dtype=dt)
    # 0 where dropped, 1 where kept; broadcast over (seq, dim).
    keep = torch.where(
        draws < cfg_ratio,
        torch.zeros(size=(batch,), device=dev, dtype=dt),
        torch.ones(size=(batch,), device=dev, dtype=dt),
    ).view(-1, 1, 1)
    null = null_condition_emb.expand_as(encoder_hidden_states)
    return torch.where(keep > 0, encoder_hidden_states, null)
119
+
120
+
121
+ # ============================================================================
122
+ # OPTIMIZER (Adafactor preferred for CPU -- 1.5 bytes/param)
123
+ # ============================================================================
124
+
125
def build_optimizer(
    params, lr: float = 1e-4, weight_decay: float = 0.01,
) -> torch.optim.Optimizer:
    """Build the training optimizer.

    Prefers Adafactor (minimal optimizer state, important on CPU); falls
    back to AdamW when transformers is not installed.
    """
    try:
        from transformers.optimization import Adafactor
    except ImportError:
        logger.warning("transformers not installed, falling back to AdamW")
        return AdamW(params, lr=lr, weight_decay=weight_decay)
    logger.info("Using Adafactor optimizer (minimal state memory)")
    return Adafactor(
        params, lr=lr, weight_decay=weight_decay,
        scale_parameter=False, relative_step=False,
    )
138
+
139
+
140
def build_scheduler(
    optimizer, total_steps: int, warmup_steps: int, lr: float,
):
    """Linear warmup followed by cosine annealing down to 1% of `lr`.

    Warmup is clamped to at most 10% of the total steps so short runs are
    not dominated by warmup.
    """
    warmup_cap = max(1, total_steps // 10)
    effective_warmup = min(warmup_steps, warmup_cap)

    ramp = LinearLR(
        optimizer, start_factor=0.1, end_factor=1.0, total_iters=effective_warmup
    )
    cosine_steps = max(1, total_steps - effective_warmup)
    cosine = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=lr * 0.01)
    return SequentialLR(optimizer, [ramp, cosine], milestones=[effective_warmup])
151
+
152
+
153
+ # ============================================================================
154
+ # DATASET
155
+ # ============================================================================
156
+
157
+ def _collate_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]:
158
+ max_t = max(s["target_latents"].shape[0] for s in batch)
159
+ max_e = max(s["encoder_hidden_states"].shape[0] for s in batch)
160
+
161
+ def pad(t, max_len, dim=0):
162
+ diff = max_len - t.shape[dim]
163
+ if diff <= 0:
164
+ return t
165
+ shape = list(t.shape)
166
+ shape[dim] = diff
167
+ return torch.cat([t, t.new_zeros(*shape)], dim=dim)
168
+
169
+ return {
170
+ "target_latents": torch.stack([pad(s["target_latents"], max_t) for s in batch]),
171
+ "attention_mask": torch.stack([pad(s["attention_mask"], max_t) for s in batch]),
172
+ "encoder_hidden_states": torch.stack([pad(s["encoder_hidden_states"], max_e) for s in batch]),
173
+ "encoder_attention_mask": torch.stack([pad(s["encoder_attention_mask"], max_e) for s in batch]),
174
+ "context_latents": torch.stack([pad(s["context_latents"], max_t) for s in batch]),
175
+ }
176
+
177
+
178
class TensorDataset(Dataset):
    """Dataset over preprocessed ``.pt`` tensor files in a directory.

    Skips intermediate ``.tmp.pt`` files and ``manifest.json``. Each item is
    a dict holding the five tensors the training loop consumes.
    """

    # Keys every preprocessed sample must provide.
    _REQUIRED = frozenset([
        "target_latents", "attention_mask", "encoder_hidden_states",
        "encoder_attention_mask", "context_latents",
    ])

    def __init__(self, tensor_dir: str):
        root = Path(tensor_dir)
        self.paths: List[str] = [
            str(root / name)
            for name in sorted(os.listdir(tensor_dir))
            if name.endswith(".pt")
            and not name.endswith(".tmp.pt")
            and name != "manifest.json"
        ]

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        sample = torch.load(self.paths[idx], map_location="cpu", weights_only=True)
        missing = self._REQUIRED - sample.keys()
        if missing:
            raise KeyError(f"Missing keys {sorted(missing)} in {self.paths[idx]}")
        # Scrub NaN/Inf in place so one corrupt sample cannot poison a step.
        for key in ("target_latents", "encoder_hidden_states", "context_latents"):
            tensor = sample[key]
            if torch.isnan(tensor).any() or torch.isinf(tensor).any():
                tensor.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
        return {key: sample[key] for key in self._REQUIRED}
203
+
204
+
205
+ # ============================================================================
206
+ # GRADIENT CHECKPOINTING
207
+ # ============================================================================
208
+
209
+ def _find_decoder_layers(decoder: nn.Module) -> Optional[nn.ModuleList]:
210
+ for attr in ("layers", "blocks", "transformer_blocks"):
211
+ c = getattr(decoder, attr, None)
212
+ if isinstance(c, nn.ModuleList) and len(c) > 0:
213
+ return c
214
+ for child in decoder.children():
215
+ for attr in ("layers", "blocks", "transformer_blocks"):
216
+ c = getattr(child, attr, None)
217
+ if isinstance(c, nn.ModuleList) and len(c) > 0:
218
+ return c
219
+ return None
220
+
221
+
222
def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
    """Enable gradient checkpointing on the decoder to save memory.

    Walks the wrapper chain (Lightning/compile/PEFT/HF containers) and, at
    each level, tries the HF ``gradient_checkpointing_enable()`` API first,
    falling back to a bare ``gradient_checkpointing`` flag. Also enables
    input-require-grads (needed for checkpointing with frozen weights) and
    disables the KV cache, which is incompatible with checkpointing.

    Returns:
        True if checkpointing was switched on anywhere in the chain.
    """
    enabled = False

    # Walk wrapper chain
    stack = [decoder]
    visited = set()  # dedup by id() to guard against wrapper cycles
    while stack:
        mod = stack.pop()
        if not isinstance(mod, nn.Module):
            continue
        mid = id(mod)
        if mid in visited:
            continue
        visited.add(mid)

        # Prefer the official API; only fall back to the raw attribute.
        if hasattr(mod, "gradient_checkpointing_enable"):
            try:
                mod.gradient_checkpointing_enable()
                enabled = True
            except Exception:
                pass
        elif hasattr(mod, "gradient_checkpointing"):
            try:
                mod.gradient_checkpointing = True
                enabled = True
            except Exception:
                pass

        # With frozen base weights, checkpointed inputs must carry grads.
        if hasattr(mod, "enable_input_require_grads"):
            try:
                mod.enable_input_require_grads()
            except Exception:
                pass

        # KV caching is an inference feature; it must be off here.
        cfg = getattr(mod, "config", None)
        if cfg is not None and hasattr(cfg, "use_cache"):
            try:
                cfg.use_cache = False
            except Exception:
                pass

        # Descend through known wrapper attribute names.
        for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
            child = getattr(mod, a, None)
            if isinstance(child, nn.Module):
                stack.append(child)

    return enabled
270
+
271
+
272
def force_disable_cache(decoder: nn.Module) -> None:
    """Set ``config.use_cache = False`` on the decoder and every wrapper level.

    KV caching is an inference-time feature and must stay off during
    training; this walks the same wrapper attributes as the checkpointing
    helper and flips the flag wherever a config exposes it.
    """
    pending = [decoder]
    seen = set()  # dedup by id() to guard against wrapper cycles
    while pending:
        current = pending.pop()
        if not isinstance(current, nn.Module):
            continue
        key = id(current)
        if key in seen:
            continue
        seen.add(key)
        config = getattr(current, "config", None)
        if config is not None and hasattr(config, "use_cache"):
            try:
                config.use_cache = False
            except Exception:
                pass
        for attr in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
            nested = getattr(current, attr, None)
            if isinstance(nested, nn.Module):
                pending.append(nested)
293
+
294
+
295
+ # ============================================================================
296
+ # LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT)
297
+ # ============================================================================
298
+
299
+ def _unwrap_decoder(model):
300
+ decoder = model.decoder if hasattr(model, "decoder") else model
301
+ while hasattr(decoder, "_forward_module"):
302
+ decoder = decoder._forward_module
303
+ if hasattr(decoder, "base_model"):
304
+ bm = decoder.base_model
305
+ decoder = bm.model if hasattr(bm, "model") else bm
306
+ if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
307
+ decoder = decoder.model
308
+ return decoder
309
+
310
+
311
def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]:
    """Wrap the model's decoder with PEFT LoRA adapters.

    Unwraps any Lightning/PEFT/HF containers first, patches a DiT quirk
    (PEFT calls ``enable_input_require_grads`` which raises on models with
    no input embeddings), applies the PEFT config, and freezes every
    non-LoRA parameter.

    Args:
        model: Object exposing a ``.decoder`` nn.Module (or itself a decoder).
        lora_cfg: LoRA hyper-parameters to apply.

    Returns:
        (model, stats) where stats reports total/trainable parameter counts
        and their ratio.
    """
    from peft import get_peft_model, LoraConfig as PeftLoraConfig, TaskType

    decoder = _unwrap_decoder(model)
    model.decoder = decoder

    # Guard enable_input_require_grads for DiT (no get_input_embeddings)
    if hasattr(decoder, "enable_input_require_grads"):
        orig = decoder.enable_input_require_grads

        def _safe(self):
            try:
                return orig()
            except NotImplementedError:
                # DiT has no token embeddings; swallow PEFT's blind call.
                return None

        decoder.enable_input_require_grads = types.MethodType(_safe, decoder)

    # Reset any stale checkpointing flag before PEFT wraps the module.
    if hasattr(decoder, "is_gradient_checkpointing"):
        try:
            decoder.is_gradient_checkpointing = False
        except Exception:
            pass

    peft_cfg = PeftLoraConfig(
        r=lora_cfg.r,
        lora_alpha=lora_cfg.alpha,
        lora_dropout=lora_cfg.dropout,
        target_modules=lora_cfg.target_modules,
        bias=lora_cfg.bias,
        task_type=TaskType.FEATURE_EXTRACTION,
    )

    model.decoder = get_peft_model(decoder, peft_cfg)

    # Freeze everything except the LoRA branches.
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = False

    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

    return model, {
        "total_params": total,
        "trainable_params": trainable,
        "trainable_ratio": trainable / total if total > 0 else 0,
    }
358
+
359
+
360
def save_lora_adapter(model, output_dir: str) -> None:
    """Persist the LoRA adapter weights under `output_dir`.

    Prefers PEFT's ``save_pretrained`` (and blanks the absolute base-model
    path in adapter_config.json for portability); otherwise manually
    extracts all ``lora_`` parameters to safetensors or a raw .pt file.
    """
    os.makedirs(output_dir, exist_ok=True)
    decoder = model.decoder if hasattr(model, "decoder") else model
    while hasattr(decoder, "_forward_module"):
        decoder = decoder._forward_module

    if hasattr(decoder, "save_pretrained"):
        decoder.save_pretrained(output_dir)
        # Scrub the machine-specific base model path so the adapter is portable.
        cfg_path = os.path.join(output_dir, "adapter_config.json")
        if os.path.isfile(cfg_path):
            try:
                with open(cfg_path, "r") as fh:
                    adapter_cfg = json.load(fh)
                if adapter_cfg.get("base_model_name_or_path"):
                    adapter_cfg["base_model_name_or_path"] = ""
                with open(cfg_path, "w") as fh:
                    json.dump(adapter_cfg, fh, indent=2)
            except Exception:
                pass
        logger.info("LoRA adapter saved to %s", output_dir)
        return

    # Fallback: collect raw lora_* tensors ourselves.
    state = {
        name: param.data.clone()
        for name, param in decoder.named_parameters()
        if "lora_" in name
    }
    if not state:
        return
    try:
        from safetensors.torch import save_file
        save_file(state, str(Path(output_dir) / "adapter_model.safetensors"))
    except ImportError:
        torch.save(state, str(Path(output_dir) / "lora_weights.pt"))
    logger.info("LoRA adapter saved (fallback) to %s", output_dir)
394
+
395
+
396
+ # ============================================================================
397
+ # MODEL LOADING (FA2 -> SDPA -> eager fallback)
398
+ # ============================================================================
399
+
400
+ _VARIANT_DIR = {
401
+ "turbo": "acestep-v15-turbo",
402
+ "base": "acestep-v15-base",
403
+ "sft": "acestep-v15-sft",
404
+ }
405
+
406
+
407
def _resolve_model_dir(checkpoint_dir: str, variant: str) -> Path:
    """Map a variant name to its on-disk model directory.

    Tries the canonical ``acestep-v15-<variant>`` subdirectory first, then
    a subdirectory literally named after the variant.

    Raises:
        FileNotFoundError: If neither candidate directory exists.
    """
    canonical = _VARIANT_DIR.get(variant)
    if canonical:
        candidate = (Path(checkpoint_dir) / canonical).resolve()
        if candidate.is_dir():
            return candidate
    literal = (Path(checkpoint_dir) / variant).resolve()
    if literal.is_dir():
        return literal
    raise FileNotFoundError(
        f"Model directory not found: tried {_VARIANT_DIR.get(variant, variant)!r} "
        f"and {variant!r} under {checkpoint_dir}"
    )
421
+
422
+
423
+ def _ensure_acestep_imports():
424
+ """Register stub modules so AutoModel can load ACE-Step checkpoints."""
425
+ for name in (
426
+ "acestep", "acestep.models", "acestep.models.common",
427
+ "acestep.models.xl_base", "acestep.models.xl_turbo", "acestep.models.xl_sft",
428
+ ):
429
+ if name not in sys.modules:
430
+ stub = types.ModuleType(name)
431
+ stub.__path__ = []
432
+ sys.modules[name] = stub
433
+
434
+ # Try to load real modules from adjacent ACE-Step checkout
435
+ for name in (
436
+ "acestep.models.common.configuration_acestep_v15",
437
+ "acestep.models.common.apg_guidance",
438
+ ):
439
+ if name not in sys.modules:
440
+ sys.modules[name] = types.ModuleType(name)
441
+
442
+
443
+ def _attn_candidates(device: str) -> List[str]:
444
+ """FA2 -> SDPA -> eager, filtered by availability."""
445
+ candidates = []
446
+ if device.startswith("cuda"):
447
+ try:
448
+ import flash_attn # noqa: F401
449
+ dev_idx = int(device.split(":")[1]) if ":" in device else 0
450
+ props = torch.cuda.get_device_properties(dev_idx)
451
+ if props.major >= 8:
452
+ candidates.append("flash_attention_2")
453
+ except (ImportError, Exception):
454
+ pass
455
+ candidates.extend(["sdpa", "eager"])
456
+ return candidates
457
+
458
+
459
def load_model_for_training(
    checkpoint_dir: str, variant: str = "base", device: str = "cpu",
) -> Any:
    """Load the ACE-Step model via AutoModel with attention-backend fallback.

    Tries each backend from _attn_candidates() (FA2 -> SDPA -> eager) until
    one loads. The result is frozen (requires_grad=False) and in eval mode;
    LoRA injection later re-enables the trainable parts.

    Raises:
        RuntimeError: If the checkpoint needs a missing package, or no
            attention backend could load the model.
    """
    from transformers import AutoModel

    model_dir = _resolve_model_dir(checkpoint_dir, variant)
    # CPU always uses float32
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    _ensure_acestep_imports()

    candidates = _attn_candidates(device)
    model = None
    last_err = None

    for idx, attn in enumerate(candidates):
        try:
            load_kwargs = dict(
                trust_remote_code=True,
                attn_implementation=attn,
                torch_dtype=dtype,
                low_cpu_mem_usage=False,
            )
            # Skip device_map on CPU (per commit: avoids meta-tensor issues).
            if device != "cpu":
                load_kwargs["device_map"] = {"": device}
            model = AutoModel.from_pretrained(str(model_dir), **load_kwargs)
            logger.info("Model loaded with attn_implementation=%s", attn)
            break
        except Exception as exc:
            err_text = str(exc)
            # A missing package would fail for every backend -- abort early
            # with an actionable message instead of trying the rest.
            if "packages that were not found" in err_text or "No module named" in err_text:
                raise RuntimeError(
                    f"Model files in {model_dir} require a missing Python package.\n"
                    f" Original error: {err_text}"
                ) from exc
            last_err = exc
            logger.warning("attn backend '%s' failed: %s", attn, exc)

    if model is None:
        raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err

    # Freeze the whole base model; LoRA params are unfrozen by inject_lora().
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model
504
+
505
+
506
def load_vae(checkpoint_dir: str, device: str = "cpu"):
    """Load the Oobleck VAE from ``<checkpoint_dir>/vae`` in eval mode.

    Raises:
        FileNotFoundError: If the vae subdirectory is missing.
    """
    from diffusers.models import AutoencoderOobleck

    vae_path = Path(checkpoint_dir) / "vae"
    if not vae_path.is_dir():
        raise FileNotFoundError(f"VAE directory not found: {vae_path}")

    # CPU must stay float32 (bf16 deadlocks on CPU); GPU uses bf16.
    load_dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=load_dtype)
    return vae.to(device=device).eval()
518
+
519
+
520
def load_text_encoder(checkpoint_dir: str, device: str = "cpu"):
    """Load the Qwen3 embedding tokenizer + encoder used for prompts/lyrics.

    Raises:
        FileNotFoundError: If the Qwen3-Embedding-0.6B subdirectory is missing.
    """
    from transformers import AutoModel, AutoTokenizer

    text_path = Path(checkpoint_dir) / "Qwen3-Embedding-0.6B"
    if not text_path.is_dir():
        raise FileNotFoundError(f"Text encoder not found: {text_path}")

    # CPU must stay float32 (bf16 deadlocks on CPU); GPU uses bf16.
    load_dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(str(text_path))
    encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=load_dtype)
    encoder = encoder.to(device=device)
    encoder.eval()
    return tokenizer, encoder
533
+
534
+
535
def load_silence_latent(
    checkpoint_dir: str, device: str = "cpu", variant: str = "base",
) -> torch.Tensor:
    """Load silence_latent.pt, searching root then variant subdirectories.

    Returns the latent transposed via ``transpose(1, 2)`` on the requested
    device and dtype.

    Raises:
        FileNotFoundError: If no silence_latent.pt exists anywhere.
    """
    ckpt = Path(checkpoint_dir)
    target_dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    # Search order: checkpoint root, this variant's dir, then every variant.
    search: List[Path] = [ckpt / "silence_latent.pt"]
    preferred = _VARIANT_DIR.get(variant)
    if preferred:
        search.append(ckpt / preferred / "silence_latent.pt")
    search.extend(ckpt / sub / "silence_latent.pt" for sub in _VARIANT_DIR.values())

    for candidate in search:
        if candidate.is_file():
            latent = torch.load(str(candidate), weights_only=True).transpose(1, 2)
            return latent.to(device=device, dtype=target_dtype)

    raise FileNotFoundError(f"silence_latent.pt not found under {ckpt}")
554
+
555
+
556
def unload_models(*models) -> None:
    """Best-effort release of model references, then force a GC pass.

    Moving to CPU first drops any accelerator memory; `None` entries are
    skipped so callers can pass optional handles unconditionally.
    """
    for candidate in models:
        if candidate is None:
            continue
        if hasattr(candidate, "to"):
            try:
                candidate.to("cpu")
            except Exception:
                pass
        # Drops only the local reference; the caller must release theirs too.
        del candidate
    gc.collect()
567
+
568
+
569
+ # ============================================================================
570
+ # AUDIO LOADING
571
+ # ============================================================================
572
+
573
def load_audio_stereo(
    audio_path: str, target_sr: int, max_duration: float,
) -> Tuple[torch.Tensor, int]:
    """Load an audio file as a (2, samples) float tensor at `target_sr`.

    Decoding tries soundfile (with librosa resampling) first and falls back
    to torchaudio on any failure. Mono is duplicated to stereo, extra
    channels are dropped, and length is capped at `max_duration` seconds.

    Returns:
        (audio, sample_rate) where audio has shape (2, N).
    """
    import numpy as np

    try:
        import soundfile as sf
        data, sr = sf.read(audio_path, dtype="float32", always_2d=True)
        audio_np = np.ascontiguousarray(data.T)  # (channels, samples)
        sr = int(sr)
        if sr != target_sr:
            import librosa
            audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr, axis=1)
            sr = target_sr
        audio = torch.from_numpy(np.ascontiguousarray(audio_np))
    except Exception:
        # Broad on purpose: any decode/resample failure (including a missing
        # soundfile/librosa) retries with the torchaudio backend.
        import torchaudio
        audio, sr = torchaudio.load(audio_path)
        sr = int(sr)
        if sr != target_sr:
            audio = torchaudio.transforms.Resample(sr, target_sr)(audio)
            sr = target_sr

    # Normalize channel count to exactly 2.
    if audio.shape[0] == 1:
        audio = audio.repeat(2, 1)
    elif audio.shape[0] > 2:
        audio = audio[:2, :]

    # Hard cap on clip length.
    max_samples = int(max_duration * target_sr)
    if audio.shape[1] > max_samples:
        audio = audio[:, :max_samples]

    return audio, sr
606
+
607
+
608
+ # ============================================================================
609
+ # TEXT / LYRICS ENCODING
610
+ # ============================================================================
611
+
612
def encode_text(text_encoder, tokenizer, text_prompt: str, device, dtype):
    """Tokenize and encode a text prompt.

    Returns:
        (hidden_states, attention_mask) with hidden states cast to `dtype`.
    """
    tokenized = tokenizer(
        text_prompt, padding="max_length", max_length=256,
        truncation=True, return_tensors="pt",
    )
    token_ids = tokenized.input_ids.to(device)
    attn_mask = tokenized.attention_mask.to(device).to(dtype)

    # The encoder may live on a different device than the one requested.
    encoder_device = next(text_encoder.parameters()).device
    if token_ids.device != encoder_device:
        token_ids = token_ids.to(encoder_device)
        attn_mask = attn_mask.to(encoder_device)

    with torch.no_grad():
        hidden = text_encoder(token_ids).last_hidden_state.to(dtype)
    return hidden, attn_mask
628
+
629
+
630
def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype):
    """Tokenize lyrics and return raw token embeddings + attention mask.

    Unlike encode_text(), only the embedding table (`embed_tokens`) is
    applied -- the full encoder forward pass is not run here.
    """
    tokenized = tokenizer(
        lyrics, padding="max_length", max_length=512,
        truncation=True, return_tensors="pt",
    )
    token_ids = tokenized.input_ids.to(device)
    attn_mask = tokenized.attention_mask.to(device).to(dtype)

    # The encoder may live on a different device than the one requested.
    encoder_device = next(text_encoder.parameters()).device
    if token_ids.device != encoder_device:
        token_ids = token_ids.to(encoder_device)
        attn_mask = attn_mask.to(encoder_device)

    with torch.no_grad():
        hidden = text_encoder.embed_tokens(token_ids).to(dtype)
    return hidden, attn_mask
646
+
647
+
648
+ # ============================================================================
649
+ # VAE TILED ENCODING
650
+ # ============================================================================
651
+
652
def tiled_vae_encode(
    vae, audio: torch.Tensor, dtype: torch.dtype,
    chunk_size: Optional[int] = None, overlap: int = 96000,
) -> torch.Tensor:
    """Encode audio to VAE latents in overlapping chunks to bound memory.

    Each chunk is encoded with `overlap` samples of context on both sides;
    the context latent frames are trimmed before stitching so chunk
    boundaries do not introduce edge artifacts.

    Args:
        vae: VAE whose ``encode()`` returns an object with a sampleable
            ``latent_dist``.
        audio: (B, C, S) waveform tensor.
        dtype: Output dtype for the stitched latents.
        chunk_size: Samples per window; defaults to 30 s at TARGET_SR.
        overlap: Context samples on each side of a chunk's core region.

    Returns:
        Latents with time as dim 1 (via ``transpose(1, 2)``) in `dtype`.

    Raises:
        ValueError: If chunk_size is not greater than 2 * overlap.
    """
    vae_device = next(vae.parameters()).device
    vae_dtype = vae.dtype

    if chunk_size is None:
        chunk_size = TARGET_SR * 30

    B, C, S = audio.shape

    # Short input: single-shot encode, no tiling needed.
    if S <= chunk_size:
        vae_input = audio.to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            latents = vae.encode(vae_input).latent_dist.sample()
        return latents.transpose(1, 2).to(dtype)

    stride = chunk_size - 2 * overlap
    if stride <= 0:
        raise ValueError(f"chunk_size ({chunk_size}) must be > 2 * overlap ({overlap})")

    num_steps = math.ceil(S / stride)
    ds_factor = None  # samples per latent frame; measured from the first chunk
    write_pos = 0     # next latent frame index to fill in `final`
    final = None      # CPU output buffer, allocated once ds_factor is known

    for i in range(num_steps):
        # Core region this step is responsible for, plus side context.
        core_start = i * stride
        core_end = min(core_start + stride, S)
        win_start = max(0, core_start - overlap)
        win_end = min(S, core_end + overlap)

        chunk = audio[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype)
        with torch.inference_mode():
            lat = vae.encode(chunk).latent_dist.sample()

        if ds_factor is None:
            ds_factor = chunk.shape[-1] / lat.shape[-1]
            total_len = int(round(S / ds_factor))
            final = torch.zeros(B, lat.shape[1], total_len, dtype=lat.dtype, device="cpu")

        # Trim the context frames so only the core region is written.
        trim_start = int(round((core_start - win_start) / ds_factor))
        trim_end = int(round((win_end - core_end) / ds_factor))
        end_idx = lat.shape[-1] - trim_end if trim_end > 0 else lat.shape[-1]
        core = lat[:, :, trim_start:end_idx]
        core_len = core.shape[-1]
        final[:, :, write_pos:write_pos + core_len] = core.cpu()
        write_pos += core_len
        del chunk, lat, core

    # Rounding can leave a frame or two unwritten at the tail; drop them.
    final = final[:, :, :write_pos]
    return final.transpose(1, 2).to(dtype)
705
+
706
+
707
+ # ============================================================================
708
+ # ENCODER / CONTEXT HELPERS
709
+ # ============================================================================
710
+
711
def run_encoder(
    model, text_hs, text_mask, lyric_hs, lyric_mask, device, dtype,
):
    """Run the model's condition encoder with an empty reference audio.

    No reference audio is used during training, so a single zero acoustic
    frame and a zero order mask are passed as placeholders.
    """
    dummy_refer = torch.zeros(1, 1, 64, device=device, dtype=dtype)
    dummy_order = torch.zeros(1, device=device, dtype=torch.long)

    with torch.no_grad():
        hidden, mask = model.encoder(
            text_hidden_states=text_hs,
            text_attention_mask=text_mask,
            lyric_hidden_states=lyric_hs,
            lyric_attention_mask=lyric_mask,
            refer_audio_acoustic_hidden_states_packed=dummy_refer,
            refer_audio_order_mask=dummy_order,
        )
    return hidden, mask
727
+
728
+
729
def build_context_latents(silence_latent, latent_length: int, device, dtype):
    """Build context latents: silence latents concatenated with a ones mask.

    Returns a (1, latent_length, 128) tensor -- 64 silence channels plus a
    64-channel all-ones mask.
    """
    src = silence_latent[:, :latent_length, :].to(dtype)
    if src.shape[0] < 1:  # defensive: force a batch dimension of 1
        src = src.expand(1, -1, -1)
    current = src.shape[1]
    if current < latent_length:
        # Silence latent shorter than required: tile again from its start.
        shortfall = latent_length - current
        filler = silence_latent[:, :shortfall, :].expand(1, -1, -1).to(dtype)
        src = torch.cat([src, filler], dim=1)
    elif current > latent_length:
        src = src[:, :latent_length, :]
    ones_mask = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
    return torch.cat([src, ones_mask], dim=-1)
740
+
741
+
742
+ # ============================================================================
743
+ # AUDIO DISCOVERY
744
+ # ============================================================================
745
+
746
def _discover_audio_files(audio_dir: str) -> List[Path]:
    """Recursively collect supported audio files under `audio_dir`.

    Files are ordered by directory walk, sorted by name within each
    directory.
    """
    return [
        Path(root) / name
        for root, _, names in os.walk(audio_dir)
        for name in sorted(names)
        if Path(name).suffix.lower() in AUDIO_EXTENSIONS
    ]
753
+
754
+
755
def _detect_max_duration(files: List[Path]) -> float:
    """Return the longest audio file duration (capped at MAX_AUDIO_DURATION).

    Probes at most the first 50 files; falls back to the cap when soundfile
    is unavailable or no duration could be read.
    """
    try:
        import soundfile as sf
    except ImportError:
        return MAX_AUDIO_DURATION
    longest = 0.0
    for path in files[:50]:
        try:
            longest = max(longest, sf.info(str(path)).duration)
        except Exception:
            pass
    return min(longest if longest > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION)
769
+
770
+
771
+ # ============================================================================
772
+ # PREPROCESSING (2-pass sequential)
773
+ # ============================================================================
774
+
775
def preprocess_audio(
    audio_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    device: str = "cpu",
    variant: str = "base",
    max_duration: float = 0,
    progress_callback: Optional[Callable] = None,
    cancel_check: Optional[Callable] = None,
) -> Dict[str, Any]:
    """2-pass sequential preprocessing.

    Pass 1: Load VAE + text encoder, encode audio + text, save intermediates.
    Pass 2: Load DIT model, run encoder, build context, save final .pt files.

    Only one heavy model set is resident at a time (pass 1 models are
    unloaded before the DiT loads), which keeps peak memory low on CPU.

    Args:
        audio_dir: Directory tree containing the training audio files.
        output_dir: Destination for final per-sample ``.pt`` tensors.
        checkpoint_dir: Root directory of the ACE-Step checkpoints.
        device: Torch device string ("cpu" forces float32 throughout).
        variant: Model variant name (see _VARIANT_DIR).
        max_duration: Per-file cap in seconds; <= 0 means auto-detect.
        progress_callback: Optional ``(done, total, message)`` hook.
        cancel_check: Optional zero-arg callable; truthy aborts the pass.

    Returns:
        Summary dict: processed / failed / total counts and output_dir.
    """
    out = Path(output_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Clean orphaned staging files
    for orphan in out.glob("*.__writing__"):
        try:
            orphan.unlink()
        except OSError:
            pass

    audio_files = _discover_audio_files(audio_dir)
    if not audio_files:
        return {"processed": 0, "failed": 0, "total": 0, "output_dir": str(out)}

    total = len(audio_files)

    if max_duration <= 0:
        max_duration = _detect_max_duration(audio_files)

    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16

    # ---- Pass 1: VAE + Text Encoder ----
    logger.info("Pass 1/2: Loading VAE + Text Encoder...")
    vae = load_vae(checkpoint_dir, device)
    tokenizer, text_enc = load_text_encoder(checkpoint_dir, device)
    silence_lat = load_silence_latent(checkpoint_dir, device, variant=variant)

    intermediates: List[Path] = []
    p1_failed = 0

    try:
        for i, af in enumerate(audio_files):
            if cancel_check and cancel_check():
                break

            stem = af.stem
            final_pt = out / f"{stem}.pt"
            # Resumability: skip files that already have a final output.
            if final_pt.exists():
                continue

            try:
                audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration)
                audio = audio.unsqueeze(0).to(device=device, dtype=vae.dtype)

                with torch.no_grad():
                    target_latents = tiled_vae_encode(vae, audio, dtype)
                del audio  # free the waveform as soon as latents exist

                # Reject samples the VAE turned into NaN/Inf.
                if torch.isnan(target_latents).any() or torch.isinf(target_latents).any():
                    p1_failed += 1
                    del target_latents
                    continue

                lat_len = target_latents.shape[1]
                att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)

                # The filename stem doubles as the text caption;
                # lyrics default to instrumental.
                caption = af.stem
                lyrics = "[Instrumental]"
                text_prompt = caption

                with torch.no_grad():
                    text_hs, text_mask = encode_text(text_enc, tokenizer, text_prompt, device, dtype)
                    lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype)

                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in [text_hs, lyric_hs]
                )
                if has_bad:
                    p1_failed += 1
                    del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                    continue

                # Intermediate file consumed (and deleted) by pass 2.
                tmp_path = out / f"{stem}.tmp.pt"
                torch.save({
                    "target_latents": target_latents.squeeze(0).cpu(),
                    "attention_mask": att_mask.squeeze(0).cpu(),
                    "text_hidden_states": text_hs.cpu(),
                    "text_attention_mask": text_mask.cpu(),
                    "lyric_hidden_states": lyric_hs.cpu(),
                    "lyric_attention_mask": lyric_mask.cpu(),
                    "silence_latent": silence_lat.cpu(),
                    "latent_length": lat_len,
                    "metadata": {
                        "audio_path": str(af),
                        "filename": af.name,
                        "caption": caption,
                        "lyrics": lyrics,
                    },
                }, tmp_path)

                del target_latents, att_mask, text_hs, text_mask, lyric_hs, lyric_mask
                intermediates.append(tmp_path)

                if progress_callback:
                    progress_callback(i + 1, total, f"[Pass 1] {af.name}")

            except Exception as exc:
                # Per-file isolation: one bad file must not stop the batch.
                p1_failed += 1
                logger.error("Pass 1 FAIL %s: %s", af.name, exc)
    finally:
        # Free pass-1 models before the DiT loads (peak-memory control).
        logger.info("Unloading VAE + Text Encoder...")
        unload_models(vae, text_enc, tokenizer, silence_lat)

    # ---- Pass 2: DIT Encoder ----
    if not intermediates:
        return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)}

    logger.info("Pass 2/2: Loading DIT model (variant=%s)...", variant)
    model = load_model_for_training(checkpoint_dir, variant, device)

    processed = 0
    p2_failed = 0
    p2_total = len(intermediates)

    try:
        for i, tmp_path in enumerate(intermediates):
            if cancel_check and cancel_check():
                break

            try:
                data = torch.load(str(tmp_path), weights_only=True)
                m_device = next(model.parameters()).device
                m_dtype = next(model.parameters()).dtype

                text_hs = data["text_hidden_states"].to(m_device, dtype=m_dtype)
                text_mask = data["text_attention_mask"].to(m_device, dtype=m_dtype)
                lyric_hs = data["lyric_hidden_states"].to(m_device, dtype=m_dtype)
                lyric_mask = data["lyric_attention_mask"].to(m_device, dtype=m_dtype)
                silence_lat = data["silence_latent"].to(m_device, dtype=m_dtype)
                lat_len = data["latent_length"]

                enc_hs, enc_mask = run_encoder(
                    model, text_hs, text_mask, lyric_hs, lyric_mask,
                    str(m_device), m_dtype,
                )
                del text_hs, text_mask, lyric_hs, lyric_mask

                # Restore the batch dim dropped when saving in pass 1.
                if silence_lat.dim() == 2:
                    silence_lat = silence_lat.unsqueeze(0)
                ctx = build_context_latents(silence_lat, lat_len, str(m_device), m_dtype)
                del silence_lat

                has_bad = any(
                    torch.isnan(t).any() or torch.isinf(t).any()
                    for t in [enc_hs, ctx]
                )
                if has_bad:
                    p2_failed += 1
                    del enc_hs, enc_mask, ctx, data
                    continue

                # Write to a staging name, then atomically rename so readers
                # never observe a half-written final file.
                base_name = tmp_path.name.replace(".tmp.pt", ".pt")
                final_path = out / base_name
                staging_path = out / (base_name + ".__writing__")

                torch.save({
                    "target_latents": data["target_latents"],
                    "attention_mask": data["attention_mask"],
                    "encoder_hidden_states": enc_hs.squeeze(0).cpu(),
                    "encoder_attention_mask": enc_mask.squeeze(0).cpu(),
                    "context_latents": ctx.squeeze(0).cpu(),
                    "metadata": data.get("metadata", {}),
                }, staging_path)
                os.replace(staging_path, final_path)

                del enc_hs, enc_mask, ctx, data
                tmp_path.unlink(missing_ok=True)
                processed += 1

                if progress_callback:
                    progress_callback(i + 1, p2_total, f"[Pass 2] {tmp_path.stem}")

            except Exception as exc:
                # Per-file isolation, mirroring pass 1.
                p2_failed += 1
                logger.error("Pass 2 FAIL %s: %s", tmp_path.stem, exc)
    finally:
        logger.info("Unloading DIT model...")
        unload_models(model)

    failed = p1_failed + p2_failed
    return {"processed": processed, "failed": failed, "total": total, "output_dir": str(out)}
972
+
973
+
974
+ # ============================================================================
975
+ # TRAINING LOOP (generator for Gradio compatibility)
976
+ # ============================================================================
977
+
978
def train_lora_generator(
    dataset_dir: str,
    output_dir: str,
    checkpoint_dir: str,
    epochs: int = 1000,
    lr: float = 3e-4,
    rank: int = 64,
    alpha: int = 128,
    dropout: float = 0.1,
    batch_size: int = 1,
    gradient_accumulation_steps: int = 4,
    warmup_steps: int = 100,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    save_every_n_epochs: int = 50,
    seed: int = 42,
    variant: str = "base",
    device: str = "cpu",
    cfg_ratio: float = 0.15,
    timestep_mu: float = -0.4,
    timestep_sigma: float = 1.0,
    target_modules: Optional[List[str]] = None,
    log_every: int = 10,
    resume_from: Optional[str] = None,
) -> Generator[str, None, None]:
    """Run LoRA training, yielding progress strings each epoch.

    This is a generator for Gradio live-update compatibility.
    Call cancel_training() to stop after the current epoch.

    Args:
        dataset_dir: Directory of preprocessed ``.pt`` tensor samples.
        output_dir: Root for saved adapters: ``final``, ``best``,
            ``checkpoints/epoch_N``, plus ``early_exit``/``timeout_exit``
            on cancel/timeout.
        checkpoint_dir: Base-model checkpoint directory.
        epochs / lr / rank / alpha / dropout: Standard LoRA hyperparameters.
        batch_size, gradient_accumulation_steps: Effective batch size is
            their product; optimizer steps once per accumulation window.
        warmup_steps, weight_decay, max_grad_norm: Optimizer/scheduler knobs.
        save_every_n_epochs: Periodic checkpoint cadence (also saves
            optimizer/scheduler state for resuming).
        seed: Seeds both torch and random for reproducible sampling.
        variant: Model variant passed to the loader.
        device: "cpu" forces float32 (see note below); otherwise bfloat16.
        cfg_ratio: Probability of replacing conditioning with the null
            embedding (classifier-free-guidance dropout).
        timestep_mu, timestep_sigma: Log-normal timestep sampling params.
        target_modules: LoRA target projections; defaults to q/k/v/o_proj.
        log_every: Emit a per-step loss line every N optimizer steps.
        resume_from: Optional checkpoint dir (or file within one) holding
            adapter weights and training_state.pt.

    Yields:
        Human-readable progress strings. "[DONE]" marks a clean finish
        (including cancel/timeout saves); "[FAIL] ..." marks an abort.
    """
    _training_cancel.clear()
    train_start = time.time()

    if target_modules is None:
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    ds_path = Path(dataset_dir)
    if not ds_path.is_dir():
        yield f"[FAIL] Dataset directory not found: {ds_path}"
        return

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    yield "[INFO] Loading model..."

    try:
        model = load_model_for_training(checkpoint_dir, variant, device)
    except Exception as exc:
        yield f"[FAIL] Model load failed: {exc}"
        return

    # float32 on CPU (bfloat16 deadlocks)
    dtype = CPU_DTYPE if device == "cpu" else torch.bfloat16
    model = model.to(dtype=dtype)

    yield "[INFO] Injecting LoRA..."

    lora_cfg = LoRAConfig(
        r=rank, alpha=alpha, dropout=dropout,
        target_modules=target_modules, bias="none",
    )

    try:
        model, info = inject_lora(model, lora_cfg)
    except Exception as exc:
        yield f"[FAIL] LoRA injection failed: {exc}"
        unload_models(model)
        return

    yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params"

    # Gradient checkpointing + cache disable. When checkpointing is active,
    # inputs must require grad or the checkpointed segments produce no grads.
    force_disable_cache(model.decoder)
    ckpt_ok = enable_gradient_checkpointing(model.decoder)
    force_input_grads = ckpt_ok
    if ckpt_ok:
        yield "[INFO] Gradient checkpointing enabled"

    # Dataset
    dataset = TensorDataset(dataset_dir)
    if len(dataset) == 0:
        yield "[FAIL] No valid .pt files found in dataset directory"
        unload_models(model)
        return

    yield f"[OK] Loaded {len(dataset)} preprocessed samples"

    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=0, collate_fn=_collate_batch, drop_last=False,
    )

    # Optimizer & scheduler
    torch.manual_seed(seed)
    random.seed(seed)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        yield "[FAIL] No trainable parameters found"
        unload_models(model)
        return

    optimizer = build_optimizer(trainable_params, lr=lr, weight_decay=weight_decay)
    steps_per_epoch = max(1, math.ceil(len(loader) / gradient_accumulation_steps))
    total_steps = steps_per_epoch * epochs
    scheduler = build_scheduler(optimizer, total_steps, warmup_steps, lr)

    yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs"
    yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}"

    # Null condition embedding for CFG dropout
    null_cond = getattr(model, "null_condition_emb", None)

    # Resume checkpoint
    start_epoch = 0
    global_step = 0

    if resume_from and Path(resume_from).exists():
        try:
            yield f"[INFO] Resuming from {resume_from}"
            ckpt_dir = Path(resume_from)
            if ckpt_dir.is_file():
                ckpt_dir = ckpt_dir.parent

            # Load adapter weights (strict=False: base weights live elsewhere)
            aw = ckpt_dir / "adapter_model.safetensors"
            if aw.exists():
                from safetensors.torch import load_file
                state = load_file(str(aw))
                decoder = model.decoder
                # Unwrap gradient-checkpointing / Lightning wrappers.
                while hasattr(decoder, "_forward_module"):
                    decoder = decoder._forward_module
                decoder.load_state_dict(state, strict=False)

            # Load training state; optimizer/scheduler restore is best-effort
            # since hyperparameters may have changed between runs.
            ts = ckpt_dir / "training_state.pt"
            if ts.exists():
                tstate = torch.load(str(ts), map_location=device, weights_only=True)
                start_epoch = tstate.get("epoch", 0)
                global_step = tstate.get("global_step", 0)
                if "optimizer_state_dict" in tstate:
                    try:
                        optimizer.load_state_dict(tstate["optimizer_state_dict"])
                    except Exception:
                        pass
                if "scheduler_state_dict" in tstate:
                    try:
                        scheduler.load_state_dict(tstate["scheduler_state_dict"])
                    except Exception:
                        pass

            yield f"[OK] Resumed from epoch {start_epoch}, step {global_step}"
        except Exception as exc:
            yield f"[WARN] Checkpoint load failed: {exc}, starting fresh"
            start_epoch = 0
            global_step = 0

    # Training loop
    model.decoder.train()
    acc_step = 0
    acc_loss = 0.0
    optimizer.zero_grad(set_to_none=True)

    best_loss = float("inf")
    best_epoch = 0
    consecutive_nan = 0
    MAX_NAN = 10

    for epoch in range(start_epoch, epochs):
        # Cancel check (set by cancel_training())
        if _training_cancel.is_set():
            _training_cancel.clear()
            early_path = str(out_path / "early_exit")
            model.decoder.eval()
            save_lora_adapter(model, early_path)
            model.decoder.train()
            yield f"[OK] Cancelled at epoch {epoch + 1}, saved to {early_path}"
            yield "[DONE]"
            unload_models(model)
            return

        # Timeout check -- save whatever we have rather than losing the run
        elapsed = time.time() - train_start
        if elapsed > MAX_TRAINING_TIME:
            early_path = str(out_path / "timeout_exit")
            model.decoder.eval()
            save_lora_adapter(model, early_path)
            yield f"[WARN] Training timed out after {int(elapsed)}s, saved to {early_path}"
            yield "[DONE]"
            unload_models(model)
            return

        epoch_loss = 0.0
        num_updates = 0
        epoch_start = time.time()

        for batch in loader:
            # Move batch to device (non_blocking only meaningful off-CPU)
            nb = device != "cpu"
            tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb)
            att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb)
            enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb)
            enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb)
            ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb)

            bsz = tgt.shape[0]

            # CFG dropout: randomly replace conditioning with null embedding
            if null_cond is not None and cfg_ratio > 0:
                enc_hs = apply_cfg_dropout(enc_hs, null_cond, cfg_ratio)

            # Timestep sampling
            t, _r = sample_timesteps(bsz, torch.device(device), dtype, timestep_mu, timestep_sigma)

            # Flow matching: interpolate clean latents x0 toward noise x1 at t
            x1 = torch.randn_like(tgt)
            x0 = tgt
            t_ = t.unsqueeze(-1).unsqueeze(-1)
            xt = t_ * x1 + (1.0 - t_) * x0

            if force_input_grads:
                xt = xt.requires_grad_(True)

            # Decoder forward
            dec_out = model.decoder(
                hidden_states=xt,
                timestep=t,
                timestep_r=t,
                attention_mask=att,
                encoder_hidden_states=enc_hs,
                encoder_attention_mask=enc_mask,
                context_latents=ctx,
            )

            # Target velocity for flow matching is x1 - x0
            flow = x1 - x0
            loss = F.mse_loss(dec_out[0], flow)
            loss = loss.float()  # fp32 for stable backward

            # NaN guard: drop the window, halt after MAX_NAN in a row
            if torch.isnan(loss) or torch.isinf(loss):
                consecutive_nan += 1
                del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
                if consecutive_nan >= MAX_NAN:
                    yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting"
                    unload_models(model)
                    return
                if acc_step > 0:
                    optimizer.zero_grad(set_to_none=True)
                    acc_loss = 0.0
                    acc_step = 0
                continue
            consecutive_nan = 0

            loss = loss / gradient_accumulation_steps
            loss.backward()
            acc_loss += loss.item()
            # Free activations eagerly -- memory headroom matters on CPU/Spaces
            del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow
            acc_step += 1

            if acc_step >= gradient_accumulation_steps:
                torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
                optimizer.step()
                scheduler.step()
                global_step += 1

                # Undo the /grad_accum scaling to report the true mean loss
                avg_loss = acc_loss * gradient_accumulation_steps / acc_step

                if global_step % log_every == 0:
                    current_lr = scheduler.get_last_lr()[0]
                    yield (
                        f"Epoch {epoch + 1}/{epochs}, "
                        f"Step {global_step}, "
                        f"Loss: {avg_loss:.4f}, "
                        f"LR: {current_lr:.2e}"
                    )

                optimizer.zero_grad(set_to_none=True)
                epoch_loss += avg_loss
                num_updates += 1
                acc_loss = 0.0
                acc_step = 0

        # Flush remainder of a partial accumulation window at epoch end
        if acc_step > 0:
            torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm)
            optimizer.step()
            scheduler.step()
            global_step += 1
            avg_loss = acc_loss * gradient_accumulation_steps / acc_step
            optimizer.zero_grad(set_to_none=True)
            epoch_loss += avg_loss
            num_updates += 1
            acc_loss = 0.0
            acc_step = 0

        epoch_time = time.time() - epoch_start
        avg_epoch_loss = epoch_loss / max(num_updates, 1)

        # 0.001 margin avoids churning "best" saves on noise-level improvements
        is_best = avg_epoch_loss < best_loss - 0.001
        if is_best:
            best_loss = avg_epoch_loss
            best_epoch = epoch + 1

        best_str = f" (best: {best_loss:.4f} @ ep{best_epoch})" if best_epoch > 0 else ""
        yield (
            f"[OK] Epoch {epoch + 1}/{epochs} in {epoch_time:.1f}s, "
            f"Loss: {avg_epoch_loss:.4f}{best_str}"
        )

        # Save best (skip the first few epochs where loss is still settling)
        if is_best and epoch + 1 >= 10:
            best_path = str(out_path / "best")
            model.decoder.eval()
            save_lora_adapter(model, best_path)
            model.decoder.train()
            yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})"

        # Periodic checkpoint with full optimizer/scheduler state for resume
        if (epoch + 1) % save_every_n_epochs == 0:
            ckpt_path = str(out_path / "checkpoints" / f"epoch_{epoch + 1}")
            model.decoder.eval()
            save_lora_adapter(model, ckpt_path)

            tstate = {
                "epoch": epoch + 1,
                "global_step": global_step,
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }
            os.makedirs(ckpt_path, exist_ok=True)
            torch.save(tstate, str(Path(ckpt_path) / "training_state.pt"))
            model.decoder.train()
            yield f"[OK] Checkpoint saved at epoch {epoch + 1}"

    # Sanity check
    if global_step == 0:
        yield "[FAIL] Training completed 0 steps -- no batches processed"
        unload_models(model)
        return

    # Final save
    final_path = str(out_path / "final")
    model.decoder.eval()
    save_lora_adapter(model, final_path)

    # NOTE: the previous version computed an unused `final_loss` from
    # `avg_epoch_loss`/`num_updates` here, which raised NameError when the
    # epoch loop never ran (resume with start_epoch >= epochs). Removed.
    best_note = ""
    if best_epoch > 0 and (out_path / "best").exists():
        best_note = f"\n    Best: {out_path / 'best'} (epoch {best_epoch}, loss: {best_loss:.4f})"
    yield (
        f"[OK] Training complete! LoRA saved to {final_path}{best_note}\n"
        f"    For inference, set your LoRA path to: {final_path}"
    )
    yield "[DONE]"
    unload_models(model)
1334
+
1335
+
1336
+ # ============================================================================
1337
+ # ADAPTER LISTING
1338
+ # ============================================================================
1339
+
1340
def get_trained_loras(adapter_dir: str) -> List[str]:
    """List all saved LoRA adapter directories under adapter_dir.

    A directory counts as an adapter when it directly contains any known
    adapter artifact file. Returns a sorted, de-duplicated list of
    directory paths; empty when adapter_dir does not exist.
    """
    base = Path(adapter_dir)
    if not base.is_dir():
        return []

    marker_names = {"adapter_config.json", "adapter_model.safetensors", "lora_weights.pt"}
    adapter_roots = {
        walk_root
        for walk_root, _subdirs, filenames in os.walk(str(base))
        if marker_names.intersection(filenames)
    }
    return sorted(adapter_roots)