Spaces:
Running
Running
add auto-captioning (BPM/key/signature via librosa), add librosa+mutagen deps
Browse files- Dockerfile +1 -1
- train_engine.py +271 -2
Dockerfile
CHANGED
|
@@ -78,7 +78,7 @@ RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/w
|
|
| 78 |
"gradio[mcp]==5.29.0" requests torch safetensors \
|
| 79 |
"transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
|
| 80 |
loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
|
| 81 |
-
einops vector_quantize_pytorch
|
| 82 |
|
| 83 |
# Clone ACE-Step repo for training module
|
| 84 |
RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
|
|
|
|
| 78 |
"gradio[mcp]==5.29.0" requests torch safetensors \
|
| 79 |
"transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
|
| 80 |
loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
|
| 81 |
+
einops vector_quantize_pytorch librosa mutagen
|
| 82 |
|
| 83 |
# Clone ACE-Step repo for training module
|
| 84 |
RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
|
train_engine.py
CHANGED
|
@@ -19,9 +19,11 @@ import logging
|
|
| 19 |
import math
|
| 20 |
import os
|
| 21 |
import random
|
|
|
|
| 22 |
import sys
|
| 23 |
import time
|
| 24 |
import types
|
|
|
|
| 25 |
from dataclasses import dataclass, field
|
| 26 |
from pathlib import Path
|
| 27 |
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
|
@@ -771,6 +773,259 @@ def _detect_max_duration(files: List[Path]) -> float:
|
|
| 771 |
return min(max_dur if max_dur > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION)
|
| 772 |
|
| 773 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
# ============================================================================
|
| 775 |
# PREPROCESSING (2-pass sequential)
|
| 776 |
# ============================================================================
|
|
@@ -846,8 +1101,22 @@ def preprocess_audio(
|
|
| 846 |
lat_len = target_latents.shape[1]
|
| 847 |
att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
|
| 848 |
|
| 849 |
-
caption
|
| 850 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 851 |
text_prompt = caption
|
| 852 |
|
| 853 |
with torch.no_grad():
|
|
|
|
| 19 |
import math
|
| 20 |
import os
|
| 21 |
import random
|
| 22 |
+
import re
|
| 23 |
import sys
|
| 24 |
import time
|
| 25 |
import types
|
| 26 |
+
import unicodedata
|
| 27 |
from dataclasses import dataclass, field
|
| 28 |
from pathlib import Path
|
| 29 |
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
|
|
|
| 773 |
return min(max_dur if max_dur > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION)
|
| 774 |
|
| 775 |
|
| 776 |
+
# ============================================================================
# AUDIO ANALYSIS (ported from Side-Step, faf mode only -- CPU, ~2-3s/file)
# ============================================================================

# Krumhansl key profile for single-profile key detection.
# Index 0 is the tonic, indices 1-11 the semitones above it; the key
# detector correlates rotated chroma vectors against these templates.
_KEY_PROFILE_MAJOR = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
                      2.52, 5.19, 2.39, 3.66, 2.29, 2.88]
_KEY_PROFILE_MINOR = [6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
                      2.54, 4.75, 3.98, 2.69, 3.34, 3.17]
# Pitch-class names in chroma-bin order (C = bin 0), sharp spellings only.
_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
                  "F#", "G", "G#", "A", "A#", "B"]

# Filename pattern: "Artist - Title" (separator may be -, – or —,
# with optional surrounding whitespace; artist match is non-greedy).
_FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$")
|
| 790 |
+
|
| 791 |
+
|
| 792 |
+
def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float:
|
| 793 |
+
"""Fold BPM into the musical sweet-spot range [lo, hi]."""
|
| 794 |
+
if bpm <= 0:
|
| 795 |
+
return bpm
|
| 796 |
+
candidate = bpm
|
| 797 |
+
while candidate > hi:
|
| 798 |
+
candidate /= 2.0
|
| 799 |
+
while candidate < lo:
|
| 800 |
+
candidate *= 2.0
|
| 801 |
+
if candidate < lo or candidate > hi:
|
| 802 |
+
return bpm
|
| 803 |
+
return candidate
|
| 804 |
+
|
| 805 |
+
|
| 806 |
+
def _detect_bpm_faf(y, sr) -> Optional[int]:
    """Estimate tempo via librosa's beat tracker with octave correction (faf mode).

    Args:
        y: Mono audio samples.
        sr: Sample rate in Hz.

    Returns:
        Rounded integer BPM folded into the musical range, or None when
        tracking fails or reports a non-positive tempo (best-effort).
    """
    import librosa
    import numpy as np

    try:
        raw_tempo, _beats = librosa.beat.beat_track(y=y, sr=sr)
        # Newer librosa returns tempo as a 0-d/1-element array; coerce to float.
        tempo_value = float(np.atleast_1d(raw_tempo)[0])
        if tempo_value <= 0:
            return None
        return int(round(_octave_correct_bpm(tempo_value)))
    except Exception:
        # Analysis is best-effort; report "unknown" rather than raise.
        return None
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def _detect_key_faf(y, sr) -> Optional[str]:
    """Detect key using Krumhansl profile on chroma_cens (faf mode).

    Pipeline: harmonic separation -> CENS chroma -> RMS-energy-weighted
    average chroma -> correlation against all 24 rotations of the
    Krumhansl major/minor profiles; the best-correlated rotation wins.

    Args:
        y: Mono audio samples (1-D array).
        sr: Sample rate of ``y`` in Hz.

    Returns:
        Key name such as "C major" or "A minor", or None on failure.
    """
    import librosa
    import numpy as np

    try:
        # Suppress percussive content so chroma reflects pitched material.
        y_harmonic = librosa.effects.harmonic(y, margin=2.0)
        chroma = librosa.feature.chroma_cens(y=y_harmonic, sr=sr)

        # Energy-weighted average chroma: louder frames contribute more.
        # NOTE(review): assumes chroma_cens frames align with RMS frames at
        # hop_length=512 -- the clip to min_len below papers over any
        # off-by-a-few mismatch; confirm against librosa defaults.
        rms = librosa.feature.rms(y=y_harmonic, frame_length=2048, hop_length=512)
        rms_vec = rms[0]
        min_len = min(chroma.shape[1], len(rms_vec))
        chroma = chroma[:, :min_len]
        rms_vec = rms_vec[:min_len]
        weights = rms_vec / (rms_vec.sum() + 1e-10)  # epsilon guards 0/0 on silence
        chroma_avg = (chroma * weights[None, :]).sum(axis=1)
        s = chroma_avg.sum()
        if s == 0:
            # All-zero chroma (pure silence): no key to report.
            return None
        chroma_avg = chroma_avg / s

        # Normalize both reference profiles to sum to 1, matching chroma_avg.
        major_norm = np.array(_KEY_PROFILE_MAJOR)
        major_norm = major_norm / major_norm.sum()
        minor_norm = np.array(_KEY_PROFILE_MINOR)
        minor_norm = minor_norm / minor_norm.sum()

        # Try every candidate tonic: rolling by -shift aligns pitch class
        # `shift` with the profile's tonic slot (index 0).
        best_corr = -2.0  # below the minimum possible correlation (-1)
        best_key = "C major"
        for shift in range(12):
            rotated = np.roll(chroma_avg, -shift)
            corr_maj = float(np.corrcoef(rotated, major_norm)[0, 1])
            if corr_maj > best_corr:
                best_corr = corr_maj
                best_key = f"{_PITCH_CLASSES[shift]} major"
            corr_min = float(np.corrcoef(rotated, minor_norm)[0, 1])
            if corr_min > best_corr:
                best_corr = corr_min
                best_key = f"{_PITCH_CLASSES[shift]} minor"
        return best_key
    except Exception:
        # Best-effort: any librosa/numpy failure yields "unknown key".
        return None
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
def _detect_time_signature_faf() -> str:
|
| 866 |
+
"""Faf mode returns hardcoded 4/4 (correct ~80%+ of the time)."""
|
| 867 |
+
return "4/4"
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def _sanitize_tag(value: str) -> str:
|
| 871 |
+
"""Normalize a tag value: NFKC normalize, strip invisible chars."""
|
| 872 |
+
value = unicodedata.normalize("NFKC", value)
|
| 873 |
+
value = (
|
| 874 |
+
value
|
| 875 |
+
.replace("", "").replace("", "")
|
| 876 |
+
.replace("", "").replace("", "")
|
| 877 |
+
.replace("", "").replace("", "")
|
| 878 |
+
.replace("", "").replace("", "")
|
| 879 |
+
.replace("", "")
|
| 880 |
+
)
|
| 881 |
+
value = "".join(
|
| 882 |
+
c for c in value
|
| 883 |
+
if c in ("\n", "\r", "\t", " ") or unicodedata.category(c)[0] != "C"
|
| 884 |
+
)
|
| 885 |
+
return value.strip()
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
def _extract_metadata_from_tags(audio_path: Path) -> Tuple[str, str]:
    """Extract (title, artist) from audio tags via mutagen, fallback to filename.

    Tag lookup order: ID3 frames (TIT2 / TPE1, TPE2), then Vorbis comments
    and MP4 atoms (title / \xa9nam, artist / \xa9ART / albumartist / aART).
    If no title tag is found, the filename stem is parsed as
    "Artist - Title"; failing that, the bare stem becomes the title.

    Args:
        audio_path: Path to the audio file.

    Returns:
        (title, artist); title falls back to the filename stem, artist to "".
    """
    title, artist = None, None
    try:
        # mutagen is optional at runtime: any import or parse failure simply
        # drops through to filename-based parsing below.
        import mutagen
        mf = mutagen.File(str(audio_path))
        if mf is not None and mf.tags is not None:
            # ID3 (MP3, AIFF)
            for key in ("TIT2",):
                val = mf.tags.get(key)
                if val:
                    title = _sanitize_tag(str(val))
                    break
            for key in ("TPE1", "TPE2"):
                val = mf.tags.get(key)
                if val:
                    artist = _sanitize_tag(str(val))
                    break
            # Vorbis (FLAC, OGG) and MP4 atoms
            if title is None:
                for key in ("title", "\xa9nam"):
                    vals = mf.tags.get(key)
                    if vals:
                        # Vorbis values are lists; MP4/other may be scalar.
                        raw = str(vals[0]) if isinstance(vals, list) else str(vals)
                        title = _sanitize_tag(raw)
                        break
            if artist is None:
                for key in ("artist", "\xa9ART", "albumartist", "aART"):
                    vals = mf.tags.get(key)
                    if vals:
                        raw = str(vals[0]) if isinstance(vals, list) else str(vals)
                        artist = _sanitize_tag(raw)
                        break
    except Exception:
        pass

    # Fallback to filename parsing
    if not title:
        stem = audio_path.stem
        match = _FILENAME_RE.match(stem)
        if match:
            # Keep a tag-derived artist if one was found; otherwise use the
            # filename's "Artist" half.
            artist = artist or match.group(1).strip()
            title = match.group(2).strip()
        else:
            title = stem.strip()

    return title or audio_path.stem, artist or ""
|
| 935 |
+
|
| 936 |
+
|
| 937 |
+
def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
    """Analyze an audio file and build a training caption.

    Uses faf mode only (CPU, ~2-3s per file): librosa beat_track for BPM,
    Krumhansl chroma for key, hardcoded 4/4 time signature. If the audio
    cannot be decoded, a tag/filename-only caption is produced instead.

    Args:
        audio_path: Path to the audio file.
        mode: Analysis mode (only "faf" supported).

    Returns:
        Dict with keys: caption, bpm, key, signature, lyrics, title, artist.
    """
    import librosa
    import numpy as np

    path = Path(audio_path)

    # Decode once; every detector reuses the same normalized waveform.
    try:
        samples, rate = librosa.load(str(path), sr=None, mono=True)
        # Trim silence, but only keep the trimmed signal when at least one
        # second of audio survives; then peak-normalize to [-1, 1].
        trimmed, _edges = librosa.effects.trim(samples, top_db=30)
        if len(trimmed) >= rate:
            samples = trimmed
        amplitude = np.max(np.abs(samples))
        if amplitude > 0:
            samples = samples / amplitude
    except Exception as exc:
        # Unreadable audio: fall back to tag/filename metadata only.
        logger.warning("Could not load audio for analysis: %s: %s", path.name, exc)
        title, artist = _extract_metadata_from_tags(path)
        fallback = f"A track by {artist}" if artist else f"A track titled {title}"
        return {
            "caption": fallback,
            "bpm": None, "key": None, "signature": "4/4",
            "lyrics": "[Instrumental]", "title": title, "artist": artist,
        }

    bpm = _detect_bpm_faf(samples, rate)
    key = _detect_key_faf(samples, rate)
    signature = _detect_time_signature_faf()
    title, artist = _extract_metadata_from_tags(path)

    # Assemble the ACE-Step caption, e.g.
    # "A track by X at 120 BPM in C major 4/4 time".
    pieces = ["A", f"track by {artist}" if artist else "track"]
    if bpm:
        pieces.append(f"at {bpm} BPM")
    if key:
        pieces.append(f"in {key}")
    pieces.append(f"{signature} time")
    caption = " ".join(pieces)

    result = {
        "caption": caption,
        "bpm": bpm,
        "key": key,
        "signature": signature,
        "lyrics": "[Instrumental]",
        "title": title,
        "artist": artist,
    }

    logger.info("Auto-caption for %s: %s", path.name, caption)
    return result
|
| 1006 |
+
|
| 1007 |
+
|
| 1008 |
+
def _write_caption_sidecar(audio_path: Path, analysis: Dict[str, Any]) -> Path:
    """Persist analysis results as a pretty-printed ``.json`` sidecar.

    The sidecar shares the audio file's stem with a ``.json`` suffix and is
    overwritten if it already exists.

    Args:
        audio_path: The audio file the sidecar accompanies.
        analysis: Caption/analysis dict to serialize.

    Returns:
        Path of the written sidecar file.
    """
    target = audio_path.with_suffix(".json")
    payload = json.dumps(analysis, indent=2, ensure_ascii=False)
    target.write_text(payload, encoding="utf-8")
    logger.info("Wrote caption sidecar: %s", target)
    return target
|
| 1015 |
+
|
| 1016 |
+
|
| 1017 |
+
def _read_caption_sidecar(audio_path: Path) -> Optional[Dict[str, Any]]:
|
| 1018 |
+
"""Read an existing .json caption sidecar if it exists."""
|
| 1019 |
+
sidecar_path = audio_path.with_suffix(".json")
|
| 1020 |
+
if not sidecar_path.is_file():
|
| 1021 |
+
return None
|
| 1022 |
+
try:
|
| 1023 |
+
with open(sidecar_path, "r", encoding="utf-8") as f:
|
| 1024 |
+
return json.load(f)
|
| 1025 |
+
except Exception:
|
| 1026 |
+
return None
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
# ============================================================================
|
| 1030 |
# PREPROCESSING (2-pass sequential)
|
| 1031 |
# ============================================================================
|
|
|
|
| 1101 |
lat_len = target_latents.shape[1]
|
| 1102 |
att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
|
| 1103 |
|
| 1104 |
+
# Auto-caption: read existing sidecar or analyze
|
| 1105 |
+
sidecar = _read_caption_sidecar(af)
|
| 1106 |
+
if sidecar and sidecar.get("caption"):
|
| 1107 |
+
caption = sidecar["caption"]
|
| 1108 |
+
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 1109 |
+
logger.info("Using existing caption for %s", af.name)
|
| 1110 |
+
else:
|
| 1111 |
+
try:
|
| 1112 |
+
analysis = analyze_and_caption(str(af))
|
| 1113 |
+
caption = analysis["caption"]
|
| 1114 |
+
lyrics = analysis.get("lyrics", "[Instrumental]")
|
| 1115 |
+
_write_caption_sidecar(af, analysis)
|
| 1116 |
+
except Exception as exc:
|
| 1117 |
+
logger.warning("Auto-caption failed for %s: %s, using filename", af.name, exc)
|
| 1118 |
+
caption = af.stem
|
| 1119 |
+
lyrics = "[Instrumental]"
|
| 1120 |
text_prompt = caption
|
| 1121 |
|
| 1122 |
with torch.no_grad():
|