Spaces:
Running
Running
File size: 46,385 Bytes
72e4b69 dae27d4 72e4b69 dae27d4 72e4b69 ff239f5 dae27d4 a07b39d ff239f5 2dc2899 a07b39d 829ed0c a07b39d 32de701 a07b39d 829ed0c a07b39d 32de701 a07b39d 1549b91 829ed0c a07b39d dae27d4 72e4b69 2dc2899 dae27d4 829ed0c 625132a a07b39d 3c15b8b 829ed0c 3c15b8b a07b39d 625132a 72e4b69 2dc2899 dae27d4 72e4b69 32de701 92f884a dae27d4 2dc2899 32de701 92f884a 2dc2899 32de701 72e4b69 2dc2899 92f884a 2dc2899 32de701 92f884a 2dc2899 72e4b69 ff239f5 32de701 ff239f5 32de701 ff239f5 32de701 ff239f5 92f884a ff239f5 32de701 ff239f5 9c293be 92f884a ff239f5 72e4b69 a07b39d 72e4b69 2dc2899 dae27d4 72e4b69 2dc2899 72e4b69 dae27d4 4d9a556 72e4b69 a07b39d dae27d4 882ed5c 72e4b69 92f884a 2dc2899 72e4b69 dae27d4 72e4b69 b23b6b8 c0f2a13 72e4b69 92f884a 2dc2899 72e4b69 5c82a90 72e4b69 2dc2899 72e4b69 a07b39d b23b6b8 e62602f b23b6b8 e62602f a07b39d b23b6b8 e62602f a07b39d e62602f a07b39d e62602f b23b6b8 e62602f a07b39d d2ae079 a07b39d d2ae079 a07b39d d2ae079 a07b39d 829ed0c a07b39d 829ed0c a07b39d 829ed0c a07b39d 829ed0c a07b39d a4457c3 72e4b69 a07b39d 72e4b69 a07b39d 5c2e4e7 a07b39d 20382cb a07b39d 829ed0c 3c15b8b 829ed0c 72e4b69 20382cb 72e4b69 5c2e4e7 b23b6b8 e62602f 72e4b69 882ed5c 5c2e4e7 a07b39d 72e4b69 a07b39d d6a3e45 a07b39d 32de701 4d9a556 a4a86a8 a07b39d 625132a a07b39d 625132a a07b39d ff239f5 a07b39d 625132a a07b39d ff239f5 a07b39d a4a86a8 625132a a07b39d 57df0f6 625132a 4d9a556 625132a e62602f 32de701 e62602f a07b39d 1549b91 a07b39d ff239f5 a07b39d 956dc8c 1549b91 956dc8c bc97006 625132a bc97006 32de701 1549b91 956dc8c 1549b91 956dc8c 4d9a556 bc97006 956dc8c bc97006 956dc8c 625132a a07b39d ff239f5 d6a3e45 32de701 d6a3e45 9d04583 5dedf2e d6a3e45 5dedf2e 9d04583 5dedf2e 9d04583 5dedf2e 9d04583 32de701 5dedf2e 9d04583 d6a3e45 5dedf2e d6a3e45 ff239f5 d6a3e45 a4a86a8 32de701 a07b39d 829ed0c a07b39d ff239f5 a07b39d a4a86a8 32de701 a07b39d 32de701 e62602f 32de701 a07b39d 32de701 ff239f5 a07b39d 32de701 059d153 32de701 a07b39d d3618ec ff239f5 a07b39d ff239f5 a07b39d d3618ec a07b39d ff239f5 
a07b39d e62602f a07b39d d3618ec a07b39d 04c031f a07b39d 0e27e49 a07b39d c0f2a13 a07b39d d3618ec a07b39d ff239f5 a07b39d ff239f5 a07b39d 32de701 a07b39d ff239f5 a07b39d 32de701 89af747 32de701 89af747 32de701 a07b39d 32de701 a07b39d e62602f 72e4b69 20382cb 72e4b69 ff239f5 72e4b69 20382cb 72e4b69 4d42fae 20382cb 4d42fae 72e4b69 20382cb 72e4b69 4d42fae 20382cb 4d42fae 20382cb 4d42fae 72e4b69 20382cb a07b39d 72e4b69 20382cb 72e4b69 35fbf3e 20382cb 35fbf3e 72e4b69 35fbf3e 8d2b494 9151fbf 8d2b494 35fbf3e 72e4b69 a07b39d 20382cb 9151fbf a07b39d 04c031f a07b39d 04c031f 81f54b1 a07b39d d6a3e45 72e4b69 ff239f5 2d3c27c a07b39d d6a3e45 2d3c27c 72e4b69 a07b39d 72e4b69 a07b39d ff239f5 a07b39d ff239f5 a07b39d 72e4b69 ff239f5 72e4b69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 
845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 | """ACE-Step 1.5 XL (CPU) - Gradio frontend + CLI for ace-server GGUF inference"""
import os
import sys
import time
import json
import argparse
import base64
import tempfile
import subprocess
import shutil
import string
import random
import requests
import logging
import threading
from train_engine import (
preprocess_audio,
train_lora_generator,
cancel_training,
_training_cancel,
get_trained_loras as _get_trained_loras_engine,
MAX_TRAINING_TIME,
)
# Log bare messages (format "%(message)s", no level/timestamp prefix) to
# stdout at INFO level and above.
logging.basicConfig(level=logging.INFO, format="%(message)s", stream=sys.stdout)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configurable limits (edit here, not buried in code)
# ---------------------------------------------------------------------------
# Cap on the summed duration of uploaded training audio. The training UI
# truncates the file that crosses the cap and skips the ones after it; the
# value is also passed to preprocess_audio() as max_duration.
MAX_TOTAL_AUDIO = 1800 # seconds total across all uploaded files (30 min)
# MAX_TRAINING_TIME is imported from train_engine (single source of truth)
MAX_AUDIO_FILES = 50 # max number of training audio files per run
# ---------------------------------------------------------------------------
# Paths & constants
# ---------------------------------------------------------------------------
# Base URL of the ace-server HTTP API; overridable through the $ACE_SERVER
# environment variable or the CLI --server flag.
ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
# Directory for generated audio temp files and training workspaces.
OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
# Created at import time; later code writes files under it unconditionally.
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Clean up old inference temp files (older than 1 hour) at startup
_CLEANUP_MAX_AGE = 3600  # seconds


def _purge_stale_outputs():
    """Delete ``.wav``/``.mp3`` files in OUTPUT_DIR older than _CLEANUP_MAX_AGE.

    Age is measured from each file's mtime. An ``OSError`` on an individual
    entry (e.g. it vanished after listing) is ignored so the sweep continues;
    any other failure propagates to the caller.
    """
    reference = time.time()
    for entry in os.listdir(OUTPUT_DIR):
        if not entry.lower().endswith((".wav", ".mp3")):
            continue
        candidate = os.path.join(OUTPUT_DIR, entry)
        try:
            if not os.path.isfile(candidate):
                continue
            if (reference - os.path.getmtime(candidate)) > _CLEANUP_MAX_AGE:
                os.remove(candidate)
        except OSError:
            pass


# Startup cleanup is best effort: no failure here may stop the app loading.
try:
    _purge_stale_outputs()
except Exception:
    pass
# Model checkpoints handed to train_engine (preprocess_audio and
# train_lora_generator receive it as checkpoint_dir).
ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
# NOTE(review): ACE_SOURCE_DIR and ACE_HF_MODEL are not referenced in the
# visible part of this file — confirm they are still used elsewhere.
ACE_SOURCE_DIR = "/app/ace-step-source"
ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
# LoRA adapters, one subdirectory each: passed to ace-server as --adapters,
# listed in the UI dropdown, and used as the training output location.
ADAPTER_DIR = os.environ.get("ACE_ADAPTER_DIR", "/app/adapters")
# GGUF model directory: passed to ace-server as --models; LM GGUFs are
# scanned for and downloaded into it.
MODELS_DIR = os.environ.get("ACE_MODELS_DIR", "/app/models")
# ace-server executable launched by _start_ace_server().
ACE_SERVER_BIN = "/app/ace-server"
# Detect if running on HF Space (ace-server available) vs locally (PyTorch only)
_is_space = os.path.isfile(ACE_SERVER_BIN) or os.environ.get("SPACE_ID") is not None
# Acquired when LoRA training begins; generate_music() probes it with a
# non-blocking acquire and refuses inference while it is held.
_training_lock = threading.Lock()
# HF repo for on-demand GGUF downloads
GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
# ---------------------------------------------------------------------------
# ace-server helpers
# ---------------------------------------------------------------------------
def _server_ok():
    """Report whether ace-server answers ``GET /health`` with HTTP 200.

    Any request failure (refused connection, 5 s timeout, ...) counts as
    "not healthy" and yields False rather than raising.
    """
    try:
        response = requests.get(f"{ACE_SERVER}/health", timeout=5)
    except Exception:
        return False
    return response.status_code == 200
def _get_props():
    """Fetch server properties (models, adapters) from ace-server ``/props``.

    Returns the decoded JSON on HTTP 200, otherwise (non-200, network error,
    undecodable body) an empty dict.
    """
    props = {}
    try:
        response = requests.get(f"{ACE_SERVER}/props", timeout=10)
        if response.status_code == 200:
            props = response.json()
    except Exception:
        pass
    return props
def _poll_job(job_id, timeout=600, progress_cb=None, cancel_check=None):
    """Poll ace-server ``/job`` about once per second until the job settles.

    Args:
        job_id: ace-server job identifier.
        timeout: give up after this many seconds.
        progress_cb: optional ``callable(status, data)`` invoked per poll.
        cancel_check: optional callable; polling stops when it returns true.

    Returns:
        ``(status, elapsed_seconds, data)``. ``status`` is "done" or "error"
        (``data`` is the last /job JSON), or "cancelled"/"timeout" (``data``
        is None). Request, decode and callback errors are swallowed and the
        poll is retried.
    """
    started = time.time()

    def since_start():
        return time.time() - started

    while since_start() < timeout:
        if cancel_check and cancel_check():
            return "cancelled", since_start(), None
        try:
            payload = requests.get(
                f"{ACE_SERVER}/job", params={"id": job_id}, timeout=5
            ).json()
            state = payload.get("status", "unknown")
            if progress_cb:
                progress_cb(state, payload)
            if state in ("done", "error"):
                return state, since_start(), payload
        except Exception:
            pass
        time.sleep(1)
    return "timeout", since_start(), None
def _fetch_result(job_id, timeout=60):
    """Request the result of a finished job (``GET /job?id=...&result=1``).

    Returns the raw ``requests.Response``; status and body checks are left
    to the caller.
    """
    query = {"id": job_id, "result": 1}
    return requests.get(f"{ACE_SERVER}/job", params=query, timeout=timeout)
def _parse_understand_result(content_type, body):
    """Extract the caption metadata dict from an ``/understand`` result.

    ace-server returns ``multipart/mixed`` (a JSON part with the caption
    metadata plus other parts such as latents); a plain JSON body is also
    accepted.

    Args:
        content_type: the response ``Content-Type`` header value.
        body: the raw response bytes.

    Returns:
        The first JSON object with a non-empty ``"caption"`` key, or None.

    Raises:
        ValueError: a non-multipart, non-empty body is not valid JSON.
    """
    if "multipart" in content_type:
        boundary = None
        for param in content_type.split(";"):
            param = param.strip()
            if param.startswith("boundary="):
                boundary = param.split("=", 1)[1].strip('"')
        if not boundary:
            return None
        for part in body.split(f"--{boundary}".encode()):
            if b"application/json" not in part:
                continue
            json_start = part.find(b"{")
            json_end = part.rfind(b"}") + 1
            if json_start < 0 or json_end <= json_start:
                continue
            try:
                data = json.loads(part[json_start:json_end])
            except ValueError:
                # A malformed part must not hide a valid one further on.
                continue
            if isinstance(data, dict) and data.get("caption"):
                return data
        # A multipart body is never valid JSON, so there is nothing to fall
        # back to.
        return None
    if not body.strip():
        return None
    data = json.loads(body)
    if isinstance(data, dict) and data.get("caption"):
        return data
    return None


def _caption_via_understand(audio_path, timeout=600, cancel_check=None):
    """Call ace-server /understand for a rich caption. Returns dict or None.

    Uploads *audio_path*, polls the job via ``_poll_job`` and parses the
    result with ``_parse_understand_result``. Every failure (submit, poll,
    fetch, parse) is logged and turned into ``None``; nothing is raised.

    Args:
        audio_path: local audio file to caption.
        timeout: maximum seconds to wait for the job.
        cancel_check: optional callable; polling stops when it returns true.
    """
    fname = os.path.basename(audio_path)
    try:
        with open(audio_path, "rb") as f:
            # NOTE(review): the part is always labelled audio/mpeg, even for
            # WAV/FLAC uploads — confirm ace-server sniffs the real format.
            r = requests.post(
                f"{ACE_SERVER}/understand",
                files={"audio": (fname, f, "audio/mpeg")},
                timeout=30,
            )
        if r.status_code != 200:
            logger.warning("[Caption] %s: /understand %d: %s", fname, r.status_code, r.text[:200])
            return None
        job_id = r.json().get("id")
        if not job_id:
            logger.warning("[Caption] %s: /understand returned no job id", fname)
            return None
    except Exception as exc:
        logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
        return None
    status, elapsed, _ = _poll_job(job_id, timeout=timeout, cancel_check=cancel_check)
    if status != "done":
        logger.warning("[Caption] %s: /understand -> %s (%.0fs)", fname, status, elapsed)
        return None
    try:
        r = _fetch_result(job_id, timeout=120)
        if r.status_code != 200:
            logger.warning("[Caption] %s: result fetch HTTP %d", fname, r.status_code)
            return None
        data = _parse_understand_result(r.headers.get("Content-Type", ""), r.content)
    except Exception as exc:
        logger.warning("[Caption] %s: result parse failed: %s", fname, exc)
        data = None
    if data is None:
        logger.warning("[Caption] %s: no caption extracted from result", fname)
        return None
    logger.info("[Caption] %s: got caption (%d chars)", fname, len(str(data["caption"])))
    return data
def _pipeline_parse_number(value, cast):
    """Return ``cast(float(value))``, or None if *value* is not a usable number.

    Each numeric field is parsed on its own, so one malformed UI/CLI value
    (e.g. bpm="abc") no longer discards the other, valid fields.
    """
    try:
        return cast(float(value))
    except (ValueError, TypeError, OverflowError):
        return None


def _pipeline_build_lm_request(caption, lyrics, bpm, duration, seed, steps,
                               adapter=None, lm_model=None):
    """Build the ace-server ``/lm`` payload from raw UI/CLI inputs.

    Empty caption/lyrics get defaults. Numeric fields are included only when
    they parse and are in range: bpm > 0, 0 < duration (capped at 300 s),
    seed >= 0 (negative or None means "random"), steps > 0.
    """
    req = {"caption": caption or "upbeat electronic dance music"}
    req["lyrics"] = lyrics if lyrics and lyrics.strip() else "[Instrumental]"
    bpm_val = _pipeline_parse_number(bpm, int)
    if bpm_val is not None and bpm_val > 0:
        req["bpm"] = bpm_val
    duration_val = _pipeline_parse_number(duration, float)
    if duration_val is not None and duration_val > 0:
        req["duration"] = min(duration_val, 300)
    seed_val = _pipeline_parse_number(seed, int)
    if seed_val is not None and seed_val >= 0:
        req["seed"] = seed_val
    steps_val = _pipeline_parse_number(steps, int)
    if steps_val is not None and steps_val > 0:
        req["inference_steps"] = steps_val
    if adapter:
        req["adapter"] = adapter
    if lm_model:
        req["model"] = lm_model
    return req


def _pipeline_submit_job(endpoint, payload, label):
    """POST *payload* to ace-server ``/<endpoint>`` and return the job id.

    Raises:
        RuntimeError: non-200 status, a non-JSON body, or no ``id`` in the
            reply (polling a missing id would only hit the poll timeout).
    """
    r = requests.post(f"{ACE_SERVER}/{endpoint}", json=payload, timeout=30)
    if r.status_code != 200:
        raise RuntimeError(f"{label} submit failed: {r.status_code} {r.text}")
    try:
        job_id = r.json().get("id")
    except (ValueError, AttributeError):
        job_id = None
    if not job_id:
        raise RuntimeError(f"{label} submit returned no job id: {r.text[:200]}")
    return job_id


def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
                  adapter=None, lm_model=None, progress_cb=None,
                  synth_model="acestep-v15-turbo-Q4_K_M.gguf"):
    """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises.

    Args:
        caption, lyrics, bpm, duration, seed, steps: raw UI/CLI values,
            validated by ``_pipeline_build_lm_request``.
        output_format: "wav" or "mp3"; anything else falls back to "mp3".
        adapter: optional LoRA adapter name, sent to both LM and synth.
        lm_model: optional LM GGUF file name.
        progress_cb: optional ``callable(phase, data)``; phases are
            "lm_submit", "lm_poll", "synth_submit", "synth_poll", "fetch".
        synth_model: synth GGUF requested from ace-server.

    Returns:
        ``(audio_path, status_msg)``; the file is created in OUTPUT_DIR and is
        not deleted here.

    Raises:
        RuntimeError: any submit/poll/fetch failure or malformed LM result.
            Network errors from ``requests`` propagate unchanged.
    """
    t0 = time.time()
    req = _pipeline_build_lm_request(caption, lyrics, bpm, duration, seed, steps,
                                     adapter=adapter, lm_model=lm_model)
    fmt = output_format if output_format in ("wav", "mp3") else "mp3"
    synth_fmt = "wav16" if fmt == "wav" else "mp3"
    suffix = f".{fmt}"

    # -- LM phase --
    if progress_cb:
        progress_cb("lm_submit", None)
    lm_job_id = _pipeline_submit_job("lm", req, "LM")
    if progress_cb:
        progress_cb("lm_poll", {"job_id": lm_job_id})
    lm_status, lm_elapsed, _ = _poll_job(lm_job_id, timeout=900)
    if lm_status != "done":
        raise RuntimeError(f"LM {lm_status} after {lm_elapsed:.0f}s")
    r = _fetch_result(lm_job_id)
    if r.status_code != 200:
        raise RuntimeError(f"LM result fetch failed: {r.status_code}")
    try:
        lm_results = r.json()
    except ValueError as exc:
        raise RuntimeError(f"LM result is not valid JSON: {exc}") from exc
    if not isinstance(lm_results, list) or not lm_results:
        raise RuntimeError(f"LM returned no results: {lm_results}")
    synth_request = lm_results[0]
    if not isinstance(synth_request, dict):
        raise RuntimeError(f"LM returned an unexpected result: {synth_request!r}")

    # -- Synth phase --
    synth_request["output_format"] = synth_fmt
    if adapter:
        synth_request["adapter"] = adapter
    synth_request["synth_model"] = synth_model
    if progress_cb:
        progress_cb("synth_submit", None)
    synth_job_id = _pipeline_submit_job("synth", synth_request, "Synth")
    if progress_cb:
        progress_cb("synth_poll", {"job_id": synth_job_id})
    synth_status, synth_elapsed, _ = _poll_job(synth_job_id, timeout=600)
    if synth_status != "done":
        raise RuntimeError(f"Synth {synth_status} after {synth_elapsed:.0f}s")

    # -- Fetch audio --
    if progress_cb:
        progress_cb("fetch", None)
    r = _fetch_result(synth_job_id, timeout=60)
    if r.status_code != 200:
        raise RuntimeError(f"Audio fetch failed: {r.status_code}")
    if not r.content:
        raise RuntimeError("Audio fetch returned an empty body")
    # delete=False: the caller owns the file; the context manager only
    # guarantees it is closed even if the write fails.
    with tempfile.NamedTemporaryFile(suffix=suffix, dir=OUTPUT_DIR, delete=False) as tmp:
        tmp.write(r.content)
    elapsed = time.time() - t0
    msg = f"Done in {elapsed:.0f}s | {duration}s audio, {steps} steps, {fmt}"
    return tmp.name, msg
# ---------------------------------------------------------------------------
# LM model scanning & on-demand download
# ---------------------------------------------------------------------------
# Default LM GGUF file name. NOTE(review): not referenced in the visible part
# of this file; presumably the UI dropdown default — confirm.
DEFAULT_LM = "acestep-5Hz-lm-1.7B-Q8_0.gguf"
# LM GGUF files offered in the UI, in dropdown order (see _scan_lm_models).
# Entries not present in MODELS_DIR are downloaded from GGUF_HF_REPO when the
# user selects them.
AVAILABLE_LM_MODELS = [
    "acestep-5Hz-lm-1.7B-Q8_0.gguf",
    "acestep-5Hz-lm-0.6B-Q8_0.gguf",
    "acestep-5Hz-lm-4B-Q5_K_M.gguf",
]
def _scan_lm_models():
    """List LM dropdown choices in AVAILABLE_LM_MODELS order.

    A model is "installed" when MODELS_DIR holds a file of that exact name
    (only names containing "-lm-" and ending in ".gguf" are considered).
    Anything else is shown with a " [not installed]" suffix.
    """
    present = (
        {name for name in os.listdir(MODELS_DIR)
         if "-lm-" in name and name.endswith(".gguf")}
        if os.path.isdir(MODELS_DIR)
        else set()
    )
    return [
        model if model in present else f"{model} [not installed]"
        for model in AVAILABLE_LM_MODELS
    ]
def _download_lm_model(filename):
    """Make sure the GGUF LM *filename* exists locally in MODELS_DIR.

    An existing file is returned as-is. Otherwise it is fetched from
    GGUF_HF_REPO with ``huggingface_hub.hf_hub_download``. Returns the local
    path, or None if the download fails (the error is logged).
    """
    local_path = os.path.join(MODELS_DIR, filename)
    if os.path.isfile(local_path):
        return local_path
    try:
        from huggingface_hub import hf_hub_download

        return hf_hub_download(
            repo_id=GGUF_HF_REPO,
            filename=filename,
            local_dir=MODELS_DIR,
        )
    except Exception as exc:
        logger.error("Failed to download %s: %s", filename, exc)
        return None
# ---------------------------------------------------------------------------
# LoRA listing for UI dropdowns
# ---------------------------------------------------------------------------
def _list_lora_choices():
    """Return list of LoRA choices for dropdown, including 'None'.

    The first entry is always "None (no LoRA)", followed by the name of every
    subdirectory of ADAPTER_DIR (one directory per adapter). Names are sorted
    because ``os.listdir`` returns entries in arbitrary order, which made the
    dropdown order unstable. Plain files are ignored, and a missing
    ADAPTER_DIR yields just the "None" entry.
    """
    choices = ["None (no LoRA)"]
    if os.path.isdir(ADAPTER_DIR):
        choices.extend(
            name
            for name in sorted(os.listdir(ADAPTER_DIR))
            if os.path.isdir(os.path.join(ADAPTER_DIR, name))
        )
    return choices
# ---------------------------------------------------------------------------
# ace-server stop/start helpers
# ---------------------------------------------------------------------------
# Popen handle of the ace-server launched by _start_ace_server(); None when
# this process did not start it or it has been stopped.
_ace_proc = None


def _stop_ace_server():
    """Stop ace-server process.

    If this module started the server and it is still running, it gets
    ``terminate()``, then ``kill()`` after 10 s. It is always waited for, so
    the process is gone (and its memory released) before training starts.
    Otherwise (e.g. the server was started externally) ``pkill ace-server``
    is used as a best-effort fallback. Always sleeps 1 s before returning.
    """
    global _ace_proc
    logger.info("[ace-server] Stopping...")
    proc = _ace_proc
    if proc is not None and proc.poll() is None:
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            # Reap the killed process so it has really exited before we
            # return; kill() alone only delivers the signal.
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                logger.warning("[ace-server] PID %s still alive after kill", proc.pid)
        _ace_proc = None
        logger.info("[ace-server] Stopped (tracked PID)")
    else:
        try:
            subprocess.run(["pkill", "ace-server"],
                           stderr=subprocess.DEVNULL, stdout=subprocess.DEVNULL,
                           timeout=10)
            logger.info("[ace-server] Stopped (pkill)")
        except (OSError, subprocess.SubprocessError) as exc:
            # pkill missing or hung: stopping is best effort, so just report.
            logger.warning("[ace-server] pkill failed: %s", exc)
    time.sleep(1)
def _start_ace_server(max_retries: int = 3, retry_delay: float = 5.0):
    """Start ace-server in background and wait for health.

    Each attempt launches ACE_SERVER_BIN on 127.0.0.1:8085 with MODELS_DIR,
    ADAPTER_DIR and ``--max-batch 1``, then polls /health every 2 s (30 polls,
    about 60 s). An attempt is abandoned as soon as the process exits instead
    of waiting out the whole health window. A failed attempt's process is
    killed before the next attempt, which starts after ``retry_delay`` s.

    Args:
        max_retries: number of launch attempts.
        retry_delay: seconds to sleep between attempts.

    Returns:
        True once /health answers 200, False after all attempts fail.
    """
    global _ace_proc
    for attempt in range(1, max_retries + 1):
        logger.info(
            "[ace-server] Starting (attempt %d/%d) with --adapters %s",
            attempt, max_retries, ADAPTER_DIR,
        )
        try:
            _ace_proc = subprocess.Popen(
                [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
                 "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
            )
        except Exception as exc:
            logger.error("[ace-server] Failed to start: %s", exc)
            if attempt < max_retries:
                time.sleep(retry_delay)
                continue
            return False
        for _ in range(30):
            # Health first: a healthy endpoint wins even if our child exited.
            if _server_ok():
                logger.info("[ace-server] Healthy")
                return True
            exit_code = _ace_proc.poll()
            if exit_code is not None:
                # The process is gone; waiting for /health is pointless.
                logger.warning(
                    "[ace-server] Process exited with code %s on attempt %d/%d",
                    exit_code, attempt, max_retries,
                )
                break
            time.sleep(2)
        else:
            logger.warning("[ace-server] Health check timeout on attempt %d/%d",
                           attempt, max_retries)
        # Kill the failed process (if still running) before retrying
        if _ace_proc.poll() is None:
            _ace_proc.kill()
            try:
                _ace_proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                pass
        if attempt < max_retries:
            time.sleep(retry_delay)
    logger.error("[ace-server] Failed to start after %d attempts", max_retries)
    return False
# ---------------------------------------------------------------------------
# CLI mode
# ---------------------------------------------------------------------------
def cli_main(argv=None):
    """Command-line entry point: generate one track through ace-server.

    Runs the LM -> synth pipeline with the parsed options, printing progress,
    the status line and the output path. Exits with status 1 when ace-server
    is unreachable, generation fails (including network errors), or the
    result cannot be moved to ``--output``.

    Args:
        argv: argument list to parse; None (the default) means
            ``sys.argv[1:]``, exactly as before.
    """
    global ACE_SERVER
    parser = argparse.ArgumentParser(
        description="ACE-Step 1.5 XL (CPU) - CLI inference via ace-server",
    )
    parser.add_argument("caption", nargs="?", default="upbeat electronic dance music",
                        help="Music description / caption")
    parser.add_argument("--lyrics", "-l", default="[Instrumental]",
                        help="Lyrics text (use '[Instrumental]' for no vocals)")
    parser.add_argument("--bpm", type=int, default=120, help="Beats per minute")
    parser.add_argument("--duration", "-d", type=float, default=10,
                        help="Duration in seconds (max 300)")
    parser.add_argument("--steps", "-s", type=int, default=8,
                        help="Inference steps (1-32)")
    parser.add_argument("--seed", type=int, default=-1,
                        help="Random seed (-1 for random)")
    parser.add_argument("--format", "-f", choices=["wav", "mp3"], default="wav",
                        help="Output audio format")
    parser.add_argument("--adapter", "-a", default=None,
                        help="LoRA adapter name")
    parser.add_argument("-o", "--output", default=None,
                        help="Output file path (default: auto in outputs dir)")
    # Without --server the module-level ACE_SERVER is used, which honours the
    # $ACE_SERVER environment variable.
    parser.add_argument("--server", default=None,
                        help="ace-server URL (default: $ACE_SERVER or http://127.0.0.1:8085)")
    args = parser.parse_args(argv)
    if args.server:
        ACE_SERVER = args.server
    if not _server_ok():
        print(f"ERROR: ace-server not reachable at {ACE_SERVER}", file=sys.stderr)
        sys.exit(1)
    seed = args.seed if args.seed >= 0 else None

    def cli_progress(phase, data):
        """Print one human-readable line per pipeline phase."""
        phases = {
            "lm_submit": "Submitting LM job...",
            "lm_poll": f"LM generating (job {data['job_id']})..." if data else "LM generating...",
            "synth_submit": "Submitting synth job...",
            "synth_poll": f"Synthesizing (job {data['job_id']})..." if data else "Synthesizing...",
            "fetch": "Fetching audio...",
        }
        msg = phases.get(phase, phase)
        print(f" [{phase}] {msg}")

    print(f"ACE-Step CLI | caption: {args.caption}")
    print(f" lyrics: {args.lyrics} | bpm: {args.bpm} | duration: {args.duration}s "
          f"| steps: {args.steps} | seed: {args.seed} | format: {args.format}")
    try:
        audio_path, status = _run_pipeline(
            caption=args.caption,
            lyrics=args.lyrics,
            bpm=args.bpm,
            duration=args.duration,
            seed=seed,
            steps=args.steps,
            output_format=args.format,
            adapter=args.adapter,
            progress_cb=cli_progress,
        )
    except (RuntimeError, requests.RequestException) as e:
        # RequestException covers connection drops/timeouts mid-pipeline,
        # which would otherwise escape as a raw traceback.
        print(f"ERROR: {e}", file=sys.stderr)
        sys.exit(1)
    # Move to requested output path if specified
    if args.output:
        try:
            out_dir = os.path.dirname(os.path.abspath(args.output))
            os.makedirs(out_dir, exist_ok=True)
            shutil.move(audio_path, args.output)
        except OSError as e:
            print(f"ERROR: could not write {args.output}: {e} (audio kept at {audio_path})",
                  file=sys.stderr)
            sys.exit(1)
        audio_path = args.output
    print(f" {status}")
    print(f" Output: {audio_path}")
# ---------------------------------------------------------------------------
# Gradio UI mode
# ---------------------------------------------------------------------------
def gradio_main():
import gradio as gr
import gc
# -- Persistent training log buffer (survives across yields) --
_train_log_lines = []
# -- Generate tab handler --
def generate_music(caption, lyrics, bpm, duration, seed,
steps, lora_select, lm_model_select,
progress=gr.Progress(track_tqdm=True)):
if not _training_lock.acquire(blocking=False):
return None, "Training in progress. Inference unavailable until training completes. Press Cancel to stop training."
_training_lock.release()
if not _server_ok():
return None, "ace-server not running. Check logs."
if not lyrics or lyrics.strip() == "":
lyrics = "[Instrumental]"
actual_seed = None if seed is None or int(seed) < 0 else int(seed)
adapter = None if lora_select == "None (no LoRA)" else lora_select
lm_model_file = lm_model_select.replace(" [not installed]", "") if lm_model_select else None
if lm_model_file and "[not installed]" in (lm_model_select or ""):
_download_lm_model(lm_model_file)
lm_model = lm_model_file
progress_map = {
"lm_submit": (0.05, "Submitting LM job..."),
"lm_poll": (0.10, "LM generating..."),
"synth_submit": (0.40, "Submitting synth job..."),
"synth_poll": (0.50, "Synthesizing audio..."),
"fetch": (0.90, "Fetching audio..."),
}
def gr_progress(phase, data):
pct, desc = progress_map.get(phase, (0.5, phase))
if data and "job_id" in data:
desc += f" (job {data['job_id']})"
progress(pct, desc=desc)
try:
audio_path, status = _run_pipeline(
caption=caption,
lyrics=lyrics,
bpm=bpm,
duration=duration,
seed=actual_seed,
steps=steps,
output_format="mp3",
adapter=adapter,
lm_model=lm_model,
progress_cb=gr_progress,
)
return audio_path, status
except RuntimeError as e:
return None, str(e)
except Exception as e:
return None, f"Unexpected error: {e}"
# -- Server info helper --
def get_server_status():
if not _server_ok():
return "ace-server: OFFLINE"
props = _get_props()
lines = ["ace-server: ONLINE"]
if props:
lines.append(json.dumps(props, indent=2))
return "\n".join(lines)
# -- Training generator (direct integration, no subprocess) --
def train_lora_ui(audio_files, lora_name, epochs, lr, rank, use_lm_caption):
"""Generator that yields (train_log, train_btn_update, cancel_btn_update)."""
import gc as _gc
_train_log_lines.clear()
train_start = time.time()
def _log(msg):
elapsed = int(time.time() - train_start)
m, s = divmod(elapsed, 60)
h, m = divmod(m, 60)
ts = f"+{h}:{m:02d}:{s:02d}" if h else f"+{m:02d}:{s:02d}"
line = f"[{ts}] {msg}"
_train_log_lines.append(line)
logger.info(msg)
if len(_train_log_lines) > 2000:
_train_log_lines[:] = _train_log_lines[-1000:]
def _log_text():
return "\n".join(_train_log_lines)
# -- Validation --
if not audio_files:
_log("[FAIL] No audio files uploaded.")
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
return
if len(audio_files) > MAX_AUDIO_FILES:
_log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
return
lora_name = (lora_name or "").strip() or "my-lora"
lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name)
epochs = max(1, min(int(epochs), 1000))
lr = float(lr)
rank = max(1, min(int(rank), 128))
work_dir = os.path.join(OUTPUT_DIR, "train_workspace", lora_name)
os.makedirs(work_dir, exist_ok=True)
audio_dir = os.path.join(work_dir, "audio_input")
if os.path.exists(audio_dir):
shutil.rmtree(audio_dir)
os.makedirs(audio_dir)
adapter_out = os.path.join(ADAPTER_DIR, lora_name)
os.makedirs(adapter_out, exist_ok=True)
# Copy uploaded audio files + check total duration
_log(f"[INFO] Preparing {len(audio_files)} audio files...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
import librosa as _lr
total_dur = 0.0
accepted = 0
skipped_names = []
truncated_names = []
for f in audio_files:
src = f.name if hasattr(f, "name") else str(f)
fname = os.path.basename(src)
# .txt/.json sidecars: copy as caption files, skip duration check
if fname.lower().endswith((".txt", ".json")):
shutil.copy2(src, os.path.join(audio_dir, fname))
continue
try:
dur = _lr.get_duration(path=src)
except Exception:
dur = 0.0
if dur <= 0:
skipped_names.append(f"{fname} (invalid/empty)")
continue
remaining = MAX_TOTAL_AUDIO - total_dur
if remaining <= 0:
skipped_names.append(fname)
continue
if dur > remaining:
# Truncate this file to fit
import soundfile as _sf
y, sr = _lr.load(src, sr=None, mono=False)
max_samples = int(remaining * sr)
if y.ndim == 1:
y = y[:max_samples]
else:
y = y[:, :max_samples]
dst = os.path.join(audio_dir, fname)
_sf.write(dst, y.T if y.ndim > 1 else y, sr)
truncated_names.append(f"{fname} ({dur:.0f}s -> {remaining:.0f}s)")
total_dur += remaining
accepted += 1
else:
shutil.copy2(src, os.path.join(audio_dir, fname))
total_dur += dur
accepted += 1
if truncated_names:
_log(f"[WARN] Truncated: {', '.join(truncated_names)}")
if skipped_names:
_log(f"[WARN] Skipped (over {MAX_TOTAL_AUDIO/60:.0f} min cap): {', '.join(skipped_names)}")
_log(f"[INFO] Total audio: {total_dur:.0f}s ({total_dur/60:.1f} min), {accepted} files")
_log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
# Caption audio files without user-provided sidecars
audio_to_caption = []
for audio_fname in sorted(os.listdir(audio_dir)):
full_path = os.path.join(audio_dir, audio_fname)
if not os.path.isfile(full_path):
continue
ext = audio_fname.lower().rsplit(".", 1)[-1] if "." in audio_fname else ""
if ext in ("json", "txt"):
continue
stem = audio_fname.rsplit(".", 1)[0] if "." in audio_fname else audio_fname
sidecar_json = os.path.join(audio_dir, stem + ".json")
sidecar_txt = os.path.join(audio_dir, stem + ".txt")
if os.path.isfile(sidecar_json) or os.path.isfile(sidecar_txt):
_log(f" {audio_fname}: using caption file")
continue
audio_to_caption.append((audio_fname, full_path, sidecar_json))
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
if audio_to_caption and use_lm_caption and _server_ok():
# --- Mode: GGUF LM captioning (best quality, 5h timeout per file) ---
LM_TIMEOUT = 18000 # 5h per file
est_total = int(total_dur * 7 + len(audio_to_caption) * 600)
if est_total > LM_TIMEOUT:
_log(f"[WARN] Estimated {est_total // 60} min exceeds 5h, switching to fast captioning")
use_lm_caption = False
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
else:
_log(f"[INFO] LM captioning {len(audio_to_caption)} files (5h timeout per file)...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
for audio_fname, full_path, sidecar_json in audio_to_caption:
if _training_cancel.is_set():
break
_log(f" {audio_fname}: LM captioning...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
caption_data = _caption_via_understand(
full_path, timeout=LM_TIMEOUT,
cancel_check=lambda: _training_cancel.is_set(),
)
if caption_data:
bpm_s = caption_data.get("bpm", "?")
key_s = caption_data.get("keyscale", caption_data.get("key", "?"))
_log(f" {audio_fname}: OK (BPM={bpm_s}, key={key_s})")
with open(sidecar_json, "w") as cj:
json.dump(caption_data, cj)
else:
_log(f" {audio_fname}: LM failed")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
if audio_to_caption and not use_lm_caption:
# --- Mode: Fast captioning (CLAP + Whisper + librosa) ---
_log(f"[INFO] Fast captioning {len(audio_to_caption)} files "
f"(CLAP tags + lyrics + BPM)...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
try:
from caption_fast import caption_audio, unload_caption_models
for audio_fname, full_path, sidecar_json in audio_to_caption:
if _training_cancel.is_set():
break
_log(f" {audio_fname}: analyzing...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
try:
result = caption_audio(full_path)
_log(f" {audio_fname}: {result.get('caption', '')[:60]}")
if result.get("lyrics") and result["lyrics"] != "[Instrumental]":
_log(f" {audio_fname}: lyrics extracted ({len(result['lyrics'])} chars)")
with open(sidecar_json, "w") as cj:
json.dump(result, cj)
except Exception as cap_exc:
_log(f" {audio_fname}: fast caption failed: {cap_exc}")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
unload_caption_models()
_gc.collect()
except ImportError:
_log("[WARN] Fast captioning not available, using librosa fallback")
for audio_fname, full_path, sidecar_json in audio_to_caption:
try:
y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
tempo_arr, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
bpm_val = int(round(float(
tempo_arr.item() if hasattr(tempo_arr, 'item') else tempo_arr)))
fallback = {"caption": audio_fname.rsplit(".", 1)[0],
"bpm": str(bpm_val), "key": "", "signature": "4/4",
"lyrics": "[Instrumental]"}
with open(sidecar_json, "w") as cj:
json.dump(fallback, cj)
_log(f" {audio_fname}: librosa BPM={bpm_val}")
except Exception as exc:
_log(f" {audio_fname}: failed: {exc}")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
if _training_cancel.is_set():
_training_cancel.clear()
_log("[CANCELLED] Stopped")
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
shutil.rmtree(work_dir, ignore_errors=True)
return
# Stop ace-server before training (frees memory)
_training_lock.acquire()
_log("[INFO] Stopping ace-server for training...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
_stop_ace_server()
_gc.collect()
_cleanup_done = False
try:
# -- Phase 1: Preprocessing (runs in thread for live progress) --
preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
_preprocess_log_len = len(_train_log_lines)
def preprocess_progress(current, total, desc):
_log(f" {desc} ({current}/{total})")
_preprocess_result = [None]
_preprocess_error = [None]
def _run_preprocess():
try:
_preprocess_result[0] = preprocess_audio(
audio_dir=audio_dir,
output_dir=preprocessed_dir,
checkpoint_dir=ACE_CHECKPOINT_DIR,
device="cpu",
variant="turbo",
max_duration=float(MAX_TOTAL_AUDIO),
progress_callback=preprocess_progress,
cancel_check=lambda: _training_cancel.is_set(),
)
except Exception as exc:
_preprocess_error[0] = exc
_log("[Step 1/2] Encoding audio → training data (VAE + text encoder)...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
t = threading.Thread(target=_run_preprocess, daemon=True)
t.start()
while t.is_alive():
t.join(timeout=3)
if _training_cancel.is_set():
_training_cancel.clear()
_log("[CANCELLED] Stopped during preprocessing")
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
return
if len(_train_log_lines) > _preprocess_log_len:
_preprocess_log_len = len(_train_log_lines)
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
if _preprocess_error[0]:
raise _preprocess_error[0]
result = _preprocess_result[0]
processed = result.get("processed", 0)
failed = result.get("failed", 0)
total = result.get("total", 0)
_log(f"[OK] Preprocessed: {processed}/{total} files (failed: {failed})")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
if processed == 0:
_log("[FAIL] No files preprocessed successfully. Cannot train.")
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
return
_gc.collect()
# -- Phase 2: Training (random 60s crops for speed + augmentation) --
_log("[Step 2/2] Training LoRA...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
for msg in train_lora_generator(
dataset_dir=preprocessed_dir,
output_dir=adapter_out,
checkpoint_dir=ACE_CHECKPOINT_DIR,
epochs=epochs,
lr=lr,
rank=rank,
alpha=rank * 2,
dropout=0.0,
batch_size=1,
gradient_accumulation_steps=4,
warmup_steps=100,
weight_decay=0.01,
max_grad_norm=1.0,
save_every_n_epochs=0,
seed=42,
variant="turbo",
device="cpu",
chunk_duration=60,
log_every=5,
):
elapsed = time.time() - train_start
if elapsed > MAX_TRAINING_TIME:
_log(f"[WARN] Training timed out after {int(elapsed)}s")
cancel_training()
break
_log(msg)
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
if msg.strip() == "[DONE]":
break
_log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
except GeneratorExit:
_training_cancel.set()
logger.info("Generator closed by Gradio, cleaning up")
_cleanup_done = True
_training_lock.release()
_gc.collect()
_start_ace_server()
shutil.rmtree(work_dir, ignore_errors=True)
return
except Exception as exc:
_log(f"[FAIL] Training error: {exc}")
import traceback
_log(traceback.format_exc())
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
finally:
if not _cleanup_done:
_training_lock.release()
_log("[INFO] Restarting ace-server...")
yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
_gc.collect()
ok = _start_ace_server()
if ok:
_log("[OK] ace-server restarted successfully")
else:
_log("[WARN] ace-server may not have restarted -- check logs")
adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
if os.path.isfile(adapter_safetensors):
import zipfile
tmp_zip = tempfile.NamedTemporaryFile(
suffix=".zip",
prefix=f"{lora_name}_",
delete=False,
)
tmp_zip.close()
with zipfile.ZipFile(tmp_zip.name, "w", zipfile.ZIP_DEFLATED) as zf:
zf.write(adapter_safetensors, f"{lora_name}/adapter_model.safetensors")
adapter_config = os.path.join(adapter_out, "adapter_config.json")
if os.path.isfile(adapter_config):
zf.write(adapter_config, f"{lora_name}/adapter_config.json")
# Include generated captions if they exist
caption_count = 0
if os.path.isdir(audio_dir):
for cf in sorted(os.listdir(audio_dir)):
if cf.endswith(".json"):
zf.write(os.path.join(audio_dir, cf),
f"{lora_name}/captions/{cf}")
caption_count += 1
_log(f"[OK] LoRA saved: {lora_name}" +
(f" ({caption_count} captions included)" if caption_count else ""))
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_zip.name, visible=True)
else:
yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
shutil.rmtree(work_dir, ignore_errors=True)
# -- Cancel handler --
def _on_cancel():
    """Cancel-button callback: ask the running training job to stop.

    Signals cancellation via ``cancel_training()``, logs the request, and
    returns a short status string for the Gradio output component.
    """
    message = "Cancelling..."
    cancel_training()
    logger.info("Cancel requested by user")
    return message
# -- Build LM model choices --
def _lm_model_choices():
    """Return the LM model options for the "LM Model" dropdown.

    Thin wrapper that delegates to ``_scan_lm_models()``.
    """
    available = _scan_lm_models()
    return available
# -- Build UI --
# Extra stylesheet for the UI. "compact-row" tightens the gap between the
# children of rows tagged with that class; "status-box" renders textareas in a
# scrollable monospace font (applied to the status and training-log textboxes).
CSS = """
.compact-row { gap: 8px !important; }
.status-box textarea { font-family: monospace; font-size: 13px; overflow-y: auto !important; }
"""
# Build the two-tab Gradio app ("Generate Music" and "Train LoRA"), wire the
# event handlers, then launch the server. The order in which components are
# created inside the nested Blocks/Tabs/Row/Column contexts determines the
# on-screen layout, so statement order here is significant.
with gr.Blocks(title="ACE-Step 1.5 XL (CPU)") as demo:
    with gr.Tabs():
        # ============================================================
        # Tab 1: Generate Music
        # ============================================================
        with gr.Tab("Generate Music"):
            gr.Markdown("**[ACE-Step 1.5 XL](https://github.com/ace-step/ACE-Step-1.5)** GGUF Q4_K_M via [acestep.cpp](https://github.com/ServeurpersoCom/acestep.cpp) | ~5 min for 10s audio")
            with gr.Row(elem_classes="compact-row"):
                # Left column: generated audio, status line, text prompts.
                with gr.Column(scale=3):
                    audio_out = gr.Audio(label="Output", type="filepath")
                    status = gr.Textbox(
                        label="Status",
                        interactive=False,
                        lines=1,
                        elem_classes="status-box",
                    )
                    caption = gr.Textbox(
                        label="Music Description",
                        lines=2,
                        value="upbeat electronic dance music, energetic synth leads",
                    )
                    lyrics = gr.Textbox(
                        label="Lyrics",
                        lines=3,
                        value="[Instrumental]",
                        placeholder="Enter lyrics here, or leave empty for instrumental (no vocals)",
                    )
                # Right column: generate button plus numeric/model settings.
                with gr.Column(scale=2):
                    gen_btn = gr.Button("Generate Music", variant="primary")
                    with gr.Row(elem_classes="compact-row"):
                        # NOTE(review): gr.Number without precision= delivers
                        # floats; presumably generate_music casts bpm/seed --
                        # confirm.
                        bpm = gr.Number(label="BPM", value=120, minimum=0, maximum=300)
                        seed = gr.Number(label="Seed (-1=random)", value=-1)
                    with gr.Row(elem_classes="compact-row"):
                        duration = gr.Slider(
                            label="Duration (s)", minimum=10, maximum=120,
                            value=10, step=5,
                        )
                        # Pinned to 8 and not user-editable (interactive=False).
                        steps = gr.Slider(
                            label="Steps (8 for turbo)", minimum=1, maximum=32,
                            value=8, step=1, interactive=False,
                        )
                    with gr.Row(elem_classes="compact-row"):
                        # Choices are read once at build time; _post_training
                        # (further down) refreshes them after a training run.
                        lora_select = gr.Dropdown(
                            label="LoRA", choices=_list_lora_choices(),
                            value="None (no LoRA)",
                            allow_custom_value=True,
                        )
                        lm_model_select = gr.Dropdown(
                            label="LM Model", choices=_lm_model_choices(),
                            value=DEFAULT_LM,
                        )
            # Component values are passed to generate_music positionally in
            # this order; the handler is also exposed as the "generate" API.
            gen_btn.click(
                fn=generate_music,
                inputs=[caption, lyrics, bpm, duration,
                        seed, steps, lora_select, lm_model_select],
                outputs=[audio_out, status],
                api_name="generate",
            )
        # ============================================================
        # Tab 2: Train LoRA
        # ============================================================
        with gr.Tab("Train LoRA"):
            gr.Markdown("LoRA training ported from [Side-Step](https://github.com/koda-dernet/Side-Step) | Model: [ACE-Step 1.5](https://github.com/ace-step/ACE-Step-1.5) | ~8h for 3 files @ 200 epochs")
            with gr.Row(elem_classes="compact-row"):
                # Left column: live log, downloadable result, audio upload.
                with gr.Column(scale=3):
                    train_log = gr.Textbox(
                        label="Training Log",
                        interactive=False,
                        lines=12,
                        max_lines=50,
                        autoscroll=True,
                        elem_classes="status-box",
                    )
                    # Hidden until the training generator yields a file for it.
                    train_output_file = gr.File(label="Trained LoRA (download)", visible=False)
                    train_audio = gr.File(
                        label="Training Audio — max 30 min total, ~2 min/epoch on CPU (optional caption .txt)",
                        file_count="multiple",
                        file_types=["audio", ".txt", ".json"],
                        height=120,
                    )
                # Right column: Train/Cancel buttons and hyperparameters.
                with gr.Column(scale=2):
                    with gr.Row(elem_classes="compact-row"):
                        train_btn = gr.Button("Train", variant="primary", scale=2)
                        # Starts hidden; the handlers below toggle it.
                        cancel_btn = gr.Button("Cancel Training", variant="stop", visible=False, scale=1)
                    lora_name = gr.Textbox(label="LoRA Name", value="my-lora")
                    train_epochs = gr.Slider(
                        label="Epochs (200 recommended ~6h on CPU, best 500)",
                        minimum=1, maximum=1000,
                        value=200, step=1,
                    )
                    train_lr = gr.Number(label="Learning Rate", value=3e-4)
                    train_rank = gr.Slider(
                        label="Rank (r)", minimum=1, maximum=128,
                        value=16, step=1,
                    )
                    use_lm_caption = gr.Checkbox(
                        label="Use LM captioning (best quality, ~30 min/file)",
                        value=False,
                    )
            # Button swap on click (separate handler, like rvc-beatrice)
            # This fires immediately so user sees Cancel even if training
            # queues behind concurrency_limit=1
            # NOTE(review): with mcp_server=True, handlers without an explicit
            # api_name may still be exposed as API/MCP endpoints -- confirm
            # whether that is intended.
            train_btn.click(
                lambda: (gr.Button(visible=False), gr.Button(visible=True)),
                outputs=[train_btn, cancel_btn],
            )
            # Training generator -- yields (log, train_btn, cancel_btn, output_file)
            # concurrency_limit=1 allows at most one training run at a time;
            # also exposed as the "train_lora" API endpoint.
            train_event = train_btn.click(
                train_lora_ui,
                inputs=[train_audio, lora_name, train_epochs, train_lr, train_rank, use_lm_caption],
                outputs=[train_log, train_btn, cancel_btn, train_output_file],
                api_name="train_lora",
                concurrency_limit=1,
            )
            # After training completes, restore buttons and refresh LoRA dropdown
            # This ensures cleanup even if the user navigated away
            def _post_training():
                """Show Train, hide Cancel, and reload the LoRA dropdown choices."""
                return (
                    gr.Button(visible=True),
                    gr.Button(visible=False),
                    gr.Dropdown(choices=_list_lora_choices()),
                )
            train_event.then(
                _post_training,
                outputs=[train_btn, cancel_btn, lora_select],
            )
            # Cancel: set the flag, update status
            # NOTE(review): _on_cancel's return string replaces the entire
            # train_log contents until the training generator's next yield.
            cancel_btn.click(
                _on_cancel,
                outputs=[train_log],
            )
# Serve on all interfaces at port 7860, with the MCP server enabled and the
# custom stylesheet applied.
# NOTE(review): css is passed to launch() rather than gr.Blocks(); this assumes
# a Gradio version whose launch() accepts css -- confirm against the pinned version.
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    mcp_server=True,
    css=CSS,
)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Mode selection: a bare invocation (how Gradio and start.sh run
    # `python3 /app/app.py`) serves the web UI; any extra command-line
    # argument switches to CLI mode instead.
    if len(sys.argv) <= 1:
        gradio_main()
    else:
        cli_main()
|