Nekochu committed
Commit 1d42836 · 1 Parent(s): c0f2a13

add auto-captioning (BPM/key/signature via librosa), add librosa+mutagen deps

Files changed (2)
  1. Dockerfile +1 -1
  2. train_engine.py +271 -2
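To see how the two changes fit together, here is a minimal sketch of the flow this commit enables. The module import path and the file name are assumptions for illustration; `analyze_and_caption` and `_write_caption_sidecar` are defined in train_engine.py below.

from pathlib import Path

from train_engine import analyze_and_caption, _write_caption_sidecar  # assumed import path

# Hypothetical input file; any librosa-readable format works.
analysis = analyze_and_caption("my_song.mp3")
print(analysis["caption"])  # e.g. "A track by Some Artist at 120 BPM in A minor 4/4 time"

# Persist the analysis next to the audio so later runs can skip re-analysis.
_write_caption_sidecar(Path("my_song.mp3"), analysis)  # writes my_song.json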
Dockerfile CHANGED
@@ -78,7 +78,7 @@ RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/w
     "gradio[mcp]==5.29.0" requests torch safetensors \
     "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
     loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
-    einops vector_quantize_pytorch
+    einops vector_quantize_pytorch librosa mutagen
 
 # Clone ACE-Step repo for training module
 RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
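A quick way to confirm the two new wheels resolved in the built image is an import check; a sketch, assuming you run it with the image's Python interpreter:

# Fails fast if the librosa or mutagen wheels did not install.
import librosa
import mutagen

print("librosa", librosa.__version__)
print("mutagen", mutagen.version_string)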
train_engine.py CHANGED
@@ -19,9 +19,11 @@ import logging
 import math
 import os
 import random
+import re
 import sys
 import time
 import types
+import unicodedata
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
@@ -771,6 +773,259 @@ def _detect_max_duration(files: List[Path]) -> float:
     return min(max_dur if max_dur > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION)
 
 
+# ============================================================================
+# AUDIO ANALYSIS (ported from Side-Step, faf mode only -- CPU, ~2-3s/file)
+# ============================================================================
+
+# Krumhansl key profile for single-profile key detection
+_KEY_PROFILE_MAJOR = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
+                      2.52, 5.19, 2.39, 3.66, 2.29, 2.88]
+_KEY_PROFILE_MINOR = [6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
+                      2.54, 4.75, 3.98, 2.69, 3.34, 3.17]
+_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
+                  "F#", "G", "G#", "A", "A#", "B"]
+
+# Filename pattern: "Artist - Title"
+_FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$")
+
+
+def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float:
+    """Fold BPM into the musical sweet-spot range [lo, hi]."""
+    if bpm <= 0:
+        return bpm
+    candidate = bpm
+    while candidate > hi:
+        candidate /= 2.0
+    while candidate < lo:
+        candidate *= 2.0
+    if candidate < lo or candidate > hi:
+        return bpm
+    return candidate
+
+
+def _detect_bpm_faf(y, sr) -> Optional[int]:
+    """Detect BPM using librosa beat_track + octave correction (faf mode)."""
+    import librosa
+    import numpy as np
+
+    try:
+        tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
+        val = float(np.atleast_1d(tempo)[0])
+        if val > 0:
+            return int(round(_octave_correct_bpm(val)))
+    except Exception:
+        pass
+    return None
+
+
+def _detect_key_faf(y, sr) -> Optional[str]:
+    """Detect key using Krumhansl profile on chroma_cens (faf mode)."""
+    import librosa
+    import numpy as np
+
+    try:
+        y_harmonic = librosa.effects.harmonic(y, margin=2.0)
+        chroma = librosa.feature.chroma_cens(y=y_harmonic, sr=sr)
+
+        # Energy-weighted average chroma
+        rms = librosa.feature.rms(y=y_harmonic, frame_length=2048, hop_length=512)
+        rms_vec = rms[0]
+        min_len = min(chroma.shape[1], len(rms_vec))
+        chroma = chroma[:, :min_len]
+        rms_vec = rms_vec[:min_len]
+        weights = rms_vec / (rms_vec.sum() + 1e-10)
+        chroma_avg = (chroma * weights[None, :]).sum(axis=1)
+        s = chroma_avg.sum()
+        if s == 0:
+            return None
+        chroma_avg = chroma_avg / s
+
+        major_norm = np.array(_KEY_PROFILE_MAJOR)
+        major_norm = major_norm / major_norm.sum()
+        minor_norm = np.array(_KEY_PROFILE_MINOR)
+        minor_norm = minor_norm / minor_norm.sum()
+
+        best_corr = -2.0
+        best_key = "C major"
+        for shift in range(12):
+            rotated = np.roll(chroma_avg, -shift)
+            corr_maj = float(np.corrcoef(rotated, major_norm)[0, 1])
+            if corr_maj > best_corr:
+                best_corr = corr_maj
+                best_key = f"{_PITCH_CLASSES[shift]} major"
+            corr_min = float(np.corrcoef(rotated, minor_norm)[0, 1])
+            if corr_min > best_corr:
+                best_corr = corr_min
+                best_key = f"{_PITCH_CLASSES[shift]} minor"
+        return best_key
+    except Exception:
+        return None
+
+
+def _detect_time_signature_faf() -> str:
+    """Faf mode returns hardcoded 4/4 (correct ~80%+ of the time)."""
+    return "4/4"
+
+
+def _sanitize_tag(value: str) -> str:
+    """Normalize a tag value: NFKC normalize, strip invisible chars."""
+    value = unicodedata.normalize("NFKC", value)
+    value = (
+        value
+        .replace("\ufeff", "").replace("\ufffe", "")
+        .replace("\u200b", "").replace("\u200c", "")
+        .replace("\u200d", "").replace("\u200e", "")
+        .replace("\u200f", "").replace("\u202a", "")
+        .replace("\u202c", "")
+    )
+    value = "".join(
+        c for c in value
+        if c in ("\n", "\r", "\t", " ") or unicodedata.category(c)[0] != "C"
+    )
+    return value.strip()
+
+
+def _extract_metadata_from_tags(audio_path: Path) -> tuple:
+    """Extract (title, artist) from audio tags via mutagen, fallback to filename."""
+    title, artist = None, None
+    try:
+        import mutagen
+        mf = mutagen.File(str(audio_path))
+        if mf is not None and mf.tags is not None:
+            # ID3 (MP3, AIFF)
+            for key in ("TIT2",):
+                val = mf.tags.get(key)
+                if val:
+                    title = _sanitize_tag(str(val))
+                    break
+            for key in ("TPE1", "TPE2"):
+                val = mf.tags.get(key)
+                if val:
+                    artist = _sanitize_tag(str(val))
+                    break
+            # Vorbis (FLAC, OGG) and MP4 atoms
+            if title is None:
+                for key in ("title", "\xa9nam"):
+                    vals = mf.tags.get(key)
+                    if vals:
+                        raw = str(vals[0]) if isinstance(vals, list) else str(vals)
+                        title = _sanitize_tag(raw)
+                        break
+            if artist is None:
+                for key in ("artist", "\xa9ART", "albumartist", "aART"):
+                    vals = mf.tags.get(key)
+                    if vals:
+                        raw = str(vals[0]) if isinstance(vals, list) else str(vals)
+                        artist = _sanitize_tag(raw)
+                        break
+    except Exception:
+        pass
+
+    # Fallback to filename parsing
+    if not title:
+        stem = audio_path.stem
+        match = _FILENAME_RE.match(stem)
+        if match:
+            artist = artist or match.group(1).strip()
+            title = match.group(2).strip()
+        else:
+            title = stem.strip()
+
+    return title or audio_path.stem, artist or ""
+
+
+def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
+    """Analyze an audio file and build a training caption.
+
+    Uses faf mode only (CPU, ~2-3s per file): librosa beat_track for BPM,
+    Krumhansl chroma for key, hardcoded 4/4 time signature.
+
+    Args:
+        audio_path: Path to the audio file.
+        mode: Analysis mode (only "faf" supported).
+
+    Returns:
+        Dict with keys: caption, bpm, key, signature, lyrics, title, artist.
+    """
+    import librosa
+    import numpy as np
+
+    audio_path = Path(audio_path)
+
+    # Load audio once, reuse for all detectors
+    try:
+        y, sr = librosa.load(str(audio_path), sr=None, mono=True)
+        # Trim silence + peak normalize
+        y_trimmed, _ = librosa.effects.trim(y, top_db=30)
+        if len(y_trimmed) >= sr:
+            y = y_trimmed
+        peak = np.max(np.abs(y))
+        if peak > 0:
+            y = y / peak
+    except Exception as exc:
+        logger.warning("Could not load audio for analysis: %s: %s", audio_path.name, exc)
+        title, artist = _extract_metadata_from_tags(audio_path)
+        return {
+            "caption": f"A track by {artist}" if artist else f"A track titled {title}",
+            "bpm": None, "key": None, "signature": "4/4",
+            "lyrics": "[Instrumental]", "title": title, "artist": artist,
+        }
+
+    bpm = _detect_bpm_faf(y, sr)
+    key = _detect_key_faf(y, sr)
+    signature = _detect_time_signature_faf()
+    title, artist = _extract_metadata_from_tags(audio_path)
+
+    # Build caption string for ACE-Step training
+    parts = ["A"]
+    if artist:
+        parts.append(f"track by {artist}")
+    else:
+        parts.append("track")
+    if bpm:
+        parts.append(f"at {bpm} BPM")
+    if key:
+        parts.append(f"in {key}")
+    parts.append(f"{signature} time")
+    caption = " ".join(parts)
+
+    lyrics = "[Instrumental]"
+
+    result = {
+        "caption": caption,
+        "bpm": bpm,
+        "key": key,
+        "signature": signature,
+        "lyrics": lyrics,
+        "title": title,
+        "artist": artist,
+    }
+
+    logger.info("Auto-caption for %s: %s", audio_path.name, caption)
+    return result
+
+
+def _write_caption_sidecar(audio_path: Path, analysis: Dict[str, Any]) -> Path:
+    """Write analysis results as a .json sidecar next to the audio file."""
+    sidecar_path = audio_path.with_suffix(".json")
+    with open(sidecar_path, "w", encoding="utf-8") as f:
+        json.dump(analysis, f, indent=2, ensure_ascii=False)
+    logger.info("Wrote caption sidecar: %s", sidecar_path)
+    return sidecar_path
+
+
+def _read_caption_sidecar(audio_path: Path) -> Optional[Dict[str, Any]]:
+    """Read an existing .json caption sidecar if it exists."""
+    sidecar_path = audio_path.with_suffix(".json")
+    if not sidecar_path.is_file():
+        return None
+    try:
+        with open(sidecar_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except Exception:
+        return None
+
+
 # ============================================================================
 # PREPROCESSING (2-pass sequential)
 # ============================================================================
@@ -846,8 +1101,22 @@
     lat_len = target_latents.shape[1]
     att_mask = torch.ones(1, lat_len, device=device, dtype=dtype)
 
-    caption = af.stem
-    lyrics = "[Instrumental]"
+    # Auto-caption: read existing sidecar or analyze
+    sidecar = _read_caption_sidecar(af)
+    if sidecar and sidecar.get("caption"):
+        caption = sidecar["caption"]
+        lyrics = sidecar.get("lyrics", "[Instrumental]")
+        logger.info("Using existing caption for %s", af.name)
+    else:
+        try:
+            analysis = analyze_and_caption(str(af))
+            caption = analysis["caption"]
+            lyrics = analysis.get("lyrics", "[Instrumental]")
+            _write_caption_sidecar(af, analysis)
+        except Exception as exc:
+            logger.warning("Auto-caption failed for %s: %s, using filename", af.name, exc)
+            caption = af.stem
+            lyrics = "[Instrumental]"
     text_prompt = caption
 
     with torch.no_grad():
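To make the octave correction concrete: librosa's beat tracker often reports double- or half-time tempi, and `_octave_correct_bpm` halves or doubles the estimate until it lands in [70, 180]. A small worked example against the function defined above (the import path is an assumption; the arithmetic follows directly from the loop bounds):

from train_engine import _octave_correct_bpm  # assumed import path

assert _octave_correct_bpm(196.0) == 98.0    # 196 > 180, halved once
assert _octave_correct_bpm(45.0) == 90.0     # 45 < 70, doubled once
assert _octave_correct_bpm(120.0) == 120.0   # already in range, unchanged
assert _octave_correct_bpm(0.0) == 0.0       # non-positive input passes through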
 
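Putting the sidecar cache together with the preprocessing hook above, a round-trip sketch (the file path is hypothetical and the import path an assumption; the cached dict shape matches the `result` built in `analyze_and_caption`):

from pathlib import Path

from train_engine import (  # assumed import path
    analyze_and_caption,
    _read_caption_sidecar,
    _write_caption_sidecar,
)

af = Path("dataset/track01.flac")   # hypothetical training file
cached = _read_caption_sidecar(af)  # None on the first pass
if cached is None:
    cached = analyze_and_caption(str(af))
    _write_caption_sidecar(af, cached)  # writes dataset/track01.json

# A no-artist instrumental would yield something like:
# {"caption": "A track at 98 BPM in D minor 4/4 time", "bpm": 98,
#  "key": "D minor", "signature": "4/4", "lyrics": "[Instrumental]",
#  "title": "track01", "artist": ""}
print(cached["caption"])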