Nekochu commited on
Commit
b38d0b1
·
1 Parent(s): 1d42836

add mid/sas analysis modes (Demucs + ensemble), auto-select by dataset size

Browse files
Files changed (1) hide show
  1. train_engine.py +989 -65
train_engine.py CHANGED
@@ -20,7 +20,9 @@ import math
20
  import os
21
  import random
22
  import re
 
23
  import sys
 
24
  import time
25
  import types
26
  import unicodedata
@@ -774,14 +776,43 @@ def _detect_max_duration(files: List[Path]) -> float:
774
 
775
 
776
  # ============================================================================
777
- # AUDIO ANALYSIS (ported from Side-Step, faf mode only -- CPU, ~2-3s/file)
778
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
779
 
780
- # Krumhansl key profile for single-profile key detection
781
- _KEY_PROFILE_MAJOR = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
782
- 2.52, 5.19, 2.39, 3.66, 2.29, 2.88]
783
- _KEY_PROFILE_MINOR = [6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
784
- 2.54, 4.75, 3.98, 2.69, 3.34, 3.17]
785
  _PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
786
  "F#", "G", "G#", "A", "A#", "B"]
787
 
@@ -789,6 +820,141 @@ _PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
789
  _FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$")
790
 
791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
792
  def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float:
793
  """Fold BPM into the musical sweet-spot range [lo, hi]."""
794
  if bpm <= 0:
@@ -803,68 +969,711 @@ def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> floa
803
  return candidate
804
 
805
 
806
- def _detect_bpm_faf(y, sr) -> Optional[int]:
807
- """Detect BPM using librosa beat_track + octave correction (faf mode)."""
 
 
 
808
  import librosa
809
  import numpy as np
810
 
 
 
 
811
  try:
812
- tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
813
- val = float(np.atleast_1d(tempo)[0])
814
- if val > 0:
815
- return int(round(_octave_correct_bpm(val)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816
  except Exception:
817
  pass
818
- return None
819
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
820
 
821
- def _detect_key_faf(y, sr) -> Optional[str]:
822
- """Detect key using Krumhansl profile on chroma_cens (faf mode)."""
 
 
 
 
 
 
 
 
 
 
 
 
823
  import librosa
824
  import numpy as np
825
 
826
  try:
827
- y_harmonic = librosa.effects.harmonic(y, margin=2.0)
828
- chroma = librosa.feature.chroma_cens(y=y_harmonic, sr=sr)
829
-
830
- # Energy-weighted average chroma
831
- rms = librosa.feature.rms(y=y_harmonic, frame_length=2048, hop_length=512)
832
- rms_vec = rms[0]
833
- min_len = min(chroma.shape[1], len(rms_vec))
834
- chroma = chroma[:, :min_len]
835
- rms_vec = rms_vec[:min_len]
836
- weights = rms_vec / (rms_vec.sum() + 1e-10)
837
- chroma_avg = (chroma * weights[None, :]).sum(axis=1)
838
- s = chroma_avg.sum()
839
- if s == 0:
840
- return None
841
- chroma_avg = chroma_avg / s
842
-
843
- major_norm = np.array(_KEY_PROFILE_MAJOR)
844
- major_norm = major_norm / major_norm.sum()
845
- minor_norm = np.array(_KEY_PROFILE_MINOR)
846
- minor_norm = minor_norm / minor_norm.sum()
847
-
848
- best_corr = -2.0
849
- best_key = "C major"
850
- for shift in range(12):
851
- rotated = np.roll(chroma_avg, -shift)
852
- corr_maj = float(np.corrcoef(rotated, major_norm)[0, 1])
853
- if corr_maj > best_corr:
854
- best_corr = corr_maj
855
- best_key = f"{_PITCH_CLASSES[shift]} major"
856
- corr_min = float(np.corrcoef(rotated, minor_norm)[0, 1])
857
- if corr_min > best_corr:
858
- best_corr = corr_min
859
- best_key = f"{_PITCH_CLASSES[shift]} minor"
860
- return best_key
861
- except Exception:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
862
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
863
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
864
 
865
- def _detect_time_signature_faf() -> str:
866
- """Faf mode returns hardcoded 4/4 (correct ~80%+ of the time)."""
867
- return "4/4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
 
870
  def _sanitize_tag(value: str) -> str:
@@ -934,24 +1743,39 @@ def _extract_metadata_from_tags(audio_path: Path) -> tuple:
934
  return title or audio_path.stem, artist or ""
935
 
936
 
937
- def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
 
 
 
 
938
  """Analyze an audio file and build a training caption.
939
 
940
- Uses faf mode only (CPU, ~2-3s per file): librosa beat_track for BPM,
941
- Krumhansl chroma for key, hardcoded 4/4 time signature.
 
 
 
 
 
942
 
943
  Args:
944
  audio_path: Path to the audio file.
945
- mode: Analysis mode (only "faf" supported).
 
946
 
947
  Returns:
948
- Dict with keys: caption, bpm, key, signature, lyrics, title, artist.
 
949
  """
950
  import librosa
951
  import numpy as np
952
 
953
  audio_path = Path(audio_path)
954
 
 
 
 
 
955
  # Load audio once, reuse for all detectors
956
  try:
957
  y, sr = librosa.load(str(audio_path), sr=None, mono=True)
@@ -969,11 +1793,76 @@ def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
969
  "caption": f"A track by {artist}" if artist else f"A track titled {title}",
970
  "bpm": None, "key": None, "signature": "4/4",
971
  "lyrics": "[Instrumental]", "title": title, "artist": artist,
 
972
  }
973
 
974
- bpm = _detect_bpm_faf(y, sr)
975
- key = _detect_key_faf(y, sr)
976
- signature = _detect_time_signature_faf()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
977
  title, artist = _extract_metadata_from_tags(audio_path)
978
 
979
  # Build caption string for ACE-Step training
@@ -999,9 +1888,10 @@ def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
999
  "lyrics": lyrics,
1000
  "title": title,
1001
  "artist": artist,
 
1002
  }
1003
 
1004
- logger.info("Auto-caption for %s: %s", audio_path.name, caption)
1005
  return result
1006
 
1007
 
@@ -1106,15 +1996,49 @@ def preprocess_audio(
1106
  if sidecar and sidecar.get("caption"):
1107
  caption = sidecar["caption"]
1108
  lyrics = sidecar.get("lyrics", "[Instrumental]")
1109
- logger.info("Using existing caption for %s", af.name)
1110
  else:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1111
  try:
1112
- analysis = analyze_and_caption(str(af))
 
 
 
1113
  caption = analysis["caption"]
1114
  lyrics = analysis.get("lyrics", "[Instrumental]")
1115
  _write_caption_sidecar(af, analysis)
 
1116
  except Exception as exc:
1117
- logger.warning("Auto-caption failed for %s: %s, using filename", af.name, exc)
1118
  caption = af.stem
1119
  lyrics = "[Instrumental]"
1120
  text_prompt = caption
 
20
  import os
21
  import random
22
  import re
23
+ import shutil
24
  import sys
25
+ import tempfile
26
  import time
27
  import types
28
  import unicodedata
 
776
 
777
 
778
  # ============================================================================
779
+ # AUDIO ANALYSIS (ported from Side-Step -- faf / mid / sas modes)
780
  # ============================================================================
781
+ #
782
+ # faf ("Fast As F*ck") ~2-3 s/file - single-method, no Demucs
783
+ # mid ~12 s/file - 3-method ensemble, Demucs stems
784
+ # sas ("Smart/Slow As Sh*t") ~30 s/file - deep multi-technique + chunked
785
+ #
786
+ # Demucs on CPU is SLOW (~2-5 min/file). mid/sas are designed for GPU
787
+ # stem separation but will still work on CPU -- just much slower.
788
+ # ============================================================================
789
+
790
# Supported analysis quality modes, cheapest to most thorough.
_ANALYSIS_MODES = ("faf", "mid", "sas")

# sas-mode chunked analysis: number of windows and length of each window.
_SAS_NUM_CHUNKS = 5
_SAS_CHUNK_SECONDS = 15  # seconds per analysis window

# Key profile families for multi-profile voting (mid / sas).
# Each family provides 12-element major/minor pitch-class weight templates,
# indexed chromatically starting at C (see _PITCH_CLASSES below).
_KEY_PROFILES = {
    "krumhansl": {
        "major": [6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
                  2.52, 5.19, 2.39, 3.66, 2.29, 2.88],
        "minor": [6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
                  2.54, 4.75, 3.98, 2.69, 3.34, 3.17],
    },
    "temperley": {
        "major": [5.0, 2.0, 3.5, 2.0, 4.5, 4.0,
                  2.0, 4.5, 2.0, 3.5, 1.5, 4.0],
        "minor": [5.0, 2.0, 3.5, 4.5, 2.0, 3.5,
                  2.0, 4.5, 3.5, 2.0, 1.5, 4.0],
    },
    "albrecht": {
        "major": [0.238, 0.006, 0.111, 0.006, 0.137, 0.094,
                  0.016, 0.214, 0.009, 0.080, 0.008, 0.081],
        "minor": [0.220, 0.006, 0.104, 0.123, 0.019, 0.103,
                  0.012, 0.214, 0.062, 0.022, 0.061, 0.052],
    },
}

# Pitch-class names indexed 0-11 starting at C; used to label detected keys.
_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
                  "F#", "G", "G#", "A", "A#", "B"]

# "Artist - Title" filename splitter (plain hyphen, en dash, or em dash).
_FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$")
821
 
822
 
823
# ---- Demucs stem separation (mid / sas) --------------------------------

def separate_stems(
    audio_path: Path,
    tmp_dir: Path,
    device: str = "cpu",
) -> Tuple[Path, Path]:
    """Run Demucs HTDemucs and return (drums_path, harmonics_path).

    Harmonics = bass + other stems summed. Vocals are discarded.

    Args:
        audio_path: Input audio file to separate.
        tmp_dir: Directory where the stem WAV files are written.
        device: torch device string ("cpu" or e.g. "cuda").

    Returns:
        Tuple of paths: (drums.wav, harmonics.wav) inside ``tmp_dir``.

    WARNING: On CPU this takes ~2-5 minutes per file.
    """
    import torchaudio
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    torch_device = torch.device(device)

    logger.info("Loading Demucs HTDemucs model on %s", device)
    if device == "cpu":
        logger.warning(
            "Demucs on CPU is slow (~2-5 min per file). "
            "Consider using 'faf' mode or running on a GPU machine."
        )
    model = get_model("htdemucs")
    model.to(torch_device)
    model.eval()

    wav, sr = torchaudio.load(str(audio_path))

    # Resample to model's expected rate (44100 Hz) if needed
    if sr != model.samplerate:
        wav = torchaudio.functional.resample(wav, sr, model.samplerate)
        sr = model.samplerate

    # HTDemucs requires stereo input; duplicate a mono channel.
    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)

    # apply_model expects a batch dimension: (batch, channels, samples).
    wav = wav.unsqueeze(0).to(torch_device)

    logger.info("Separating stems for %s", audio_path.name)
    with torch.no_grad():
        sources = apply_model(model, wav, device=torch_device)

    # model.sources gives stem names in output order; index by name so the
    # code does not depend on HTDemucs's internal ordering.
    source_map = {name: i for i, name in enumerate(model.sources)}
    drums = sources[0, source_map["drums"]].cpu()
    bass = sources[0, source_map["bass"]].cpu()
    other = sources[0, source_map["other"]].cpu()
    harmonics = bass + other

    drums_path = tmp_dir / "drums.wav"
    harmonics_path = tmp_dir / "harmonics.wav"
    torchaudio.save(str(drums_path), drums, sr)
    torchaudio.save(str(harmonics_path), harmonics, sr)

    # Drop the large tensors/model promptly to keep peak memory down when
    # called in a per-file loop.
    del model, sources, wav, drums, bass, other, harmonics
    gc.collect()

    logger.info("Stems written: %s, %s", drums_path, harmonics_path)
    return drums_path, harmonics_path
885
+
886
+
887
# ---- Chunk selection (sas mode) ----------------------------------------

def _select_chunks(
    y,  # np.ndarray
    sr: int,
    n_chunks: int = _SAS_NUM_CHUNKS,
    chunk_sec: float = _SAS_CHUNK_SECONDS,
    min_gap_sec: float = 10.0,
    use_onset: bool = True,
) -> list:
    """Pick the most informative windows of ``y`` for sas analysis.

    Windows are scored by onset density (or RMS when ``use_onset`` is
    False), below-median windows are discarded, and the survivors are
    chosen greedily so their centres stay at least ``min_gap_sec`` apart.
    """
    import librosa
    import numpy as np

    win = int(chunk_sec * sr)
    if len(y) < win:
        return [y]
    step = win // 2  # 50% overlap between candidate windows

    def _score(segment) -> float:
        if use_onset:
            return float(np.mean(librosa.onset.onset_strength(y=segment, sr=sr)))
        return float(np.sqrt(np.mean(segment ** 2)))

    candidates = [
        (off, _score(y[off : off + win]))
        for off in range(0, len(y) - win + 1, step)
    ]
    if not candidates:
        return [y]

    # Energy gate: keep only windows at or above the median score.
    cutoff = float(np.median(np.array([sc for _, sc in candidates])))
    gated = [item for item in candidates if item[1] >= cutoff] or candidates
    gated.sort(key=lambda item: item[1], reverse=True)

    # Greedy pick, best-first, enforcing a minimum centre-to-centre gap.
    gap = int(min_gap_sec * sr)
    picked = []
    for off, _ in gated:
        centre = off + win // 2
        if all(abs(centre - (p + win // 2)) >= gap for p in picked):
            picked.append(off)
            if len(picked) >= n_chunks:
                break

    # Top up from the remaining gated windows if the gap rule left us short.
    if len(picked) < n_chunks:
        for off, _ in gated:
            if off not in picked:
                picked.append(off)
                if len(picked) >= n_chunks:
                    break

    picked.sort()
    return [y[p : p + win] for p in picked]
954
+
955
+
956
+ # ---- BPM helpers --------------------------------------------------------
957
+
958
  def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float:
959
  """Fold BPM into the musical sweet-spot range [lo, hi]."""
960
  if bpm <= 0:
 
969
  return candidate
970
 
971
 
972
def _bpm_core_ensemble(y, sr) -> list:
    """Run the 3-method BPM ensemble on a single audio buffer (mid/sas).

    Returns a list of octave-corrected BPM estimates.
    """
    import librosa
    import numpy as np

    found = []

    def _keep(raw_bpm: float) -> None:
        # Collect only positive estimates, folded into the 70-180 range.
        if raw_bpm > 0:
            found.append(_octave_correct_bpm(float(raw_bpm)))

    # Method A: librosa's dynamic-programming beat tracker.
    try:
        tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
        _keep(float(np.atleast_1d(tempo)[0]))
    except Exception:
        pass

    # Method B: strongest peak of the time-averaged tempogram (30-300 BPM).
    try:
        env = librosa.onset.onset_strength(y=y, sr=sr)
        tgram = librosa.feature.tempogram(onset_envelope=env, sr=sr)
        profile = np.mean(tgram, axis=1)
        freqs = librosa.tempo_frequencies(tgram.shape[0], sr=sr)
        in_range = (freqs >= 30) & (freqs <= 300)
        if np.any(in_range):
            gated = np.where(in_range, profile, 0)
            _keep(float(freqs[np.argmax(gated)]))
    except Exception:
        pass

    # Method C: onset-envelope autocorrelation peak in the 30-300 BPM band.
    try:
        env = librosa.onset.onset_strength(y=y, sr=sr)
        ac = librosa.autocorrelate(env, max_size=len(env))
        hop = 512
        lo = int(60.0 * sr / (300.0 * hop))
        hi = min(int(60.0 * sr / (30.0 * hop)), len(ac) - 1)
        if lo < hi and hi > 0:
            lag = lo + int(np.argmax(ac[lo : hi + 1]))
            if lag > 0:
                _keep(60.0 * sr / (lag * hop))
    except Exception:
        pass

    return found
1028
+
1029
+
1030
def _bpm_consensus(estimates: list) -> Tuple[Optional[int], str]:
    """Find consensus BPM from a list of estimates + assign confidence."""
    import numpy as np

    if not estimates:
        return None, "low"

    values = np.array(estimates)

    # Largest cluster of mutually agreeing estimates: for each estimate as
    # reference, gather everything within 8% of it, keep the biggest group.
    best_cluster = []
    for ref in values:
        members = [v for v in values if abs(v - ref) / max(ref, 1) < 0.08]
        if len(members) > len(best_cluster):
            best_cluster = members

    centre = float(np.median(best_cluster)) if best_cluster else estimates[0]
    bpm = int(round(centre))
    if bpm <= 0:
        return None, "low"

    agree = len(best_cluster)
    total = len(estimates)
    if total >= 6:
        # sas thresholds (many data points): ratio-based.
        ratio = agree / total
        if ratio >= 0.7:
            confidence = "high"
        elif ratio >= 0.4:
            confidence = "medium"
        else:
            confidence = "low"
    else:
        # mid thresholds: absolute agreement counts.
        if agree >= 3:
            confidence = "high"
        elif agree >= 2:
            confidence = "medium"
        else:
            confidence = "low"

    return bpm, confidence
1070
+
1071
+
1072
# ---- Unified BPM detection ---------------------------------------------

def _detect_bpm(y, sr, mode: str = "faf") -> Tuple[Optional[int], str]:
    """Detect BPM with quality controlled by mode.

    faf: Single beat_track + octave correction.
    mid: 3-method ensemble (beat_track + tempogram + onset-AC).
    sas: mid ensemble + PLP + multi-hop + chunked analysis.

    Returns (bpm, confidence) where confidence is "high"/"medium"/"low".
    """
    import librosa
    import numpy as np

    try:
        # faf: single method, always reported as low confidence.
        if mode == "faf":
            try:
                tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
                val = float(np.atleast_1d(tempo)[0])
                if val > 0:
                    bpm = int(round(_octave_correct_bpm(val)))
                    logger.info("BPM faf: %d (raw: %.1f)", bpm, val)
                    return bpm, "low"
            except Exception:
                pass
            return None, "low"

        # mid: 3-method ensemble on the full buffer.
        estimates = _bpm_core_ensemble(y, sr)

        # sas: additional techniques feeding extra estimates into the pool.
        ibi_cv = 0.5  # inter-beat-interval coefficient of variation, default "unknown"
        if mode == "sas":
            # PLP (Predominant Local Pulse): autocorrelate the pulse curve and
            # convert the peak lag back to BPM within the 30-300 band.
            # NOTE(review): the lag->BPM conversion assumes onset_strength's
            # default hop of 512 samples -- confirm if the default changes.
            try:
                onset_env = librosa.onset.onset_strength(y=y, sr=sr)
                pulse = librosa.beat.plp(onset_envelope=onset_env, sr=sr)
                plp_ac = librosa.autocorrelate(pulse, max_size=len(pulse))
                hop = 512
                min_lag = int(60.0 * sr / (300.0 * hop))
                max_lag = int(60.0 * sr / (30.0 * hop))
                max_lag = min(max_lag, len(plp_ac) - 1)
                if min_lag < max_lag and max_lag > 0:
                    seg = plp_ac[min_lag:max_lag + 1]
                    peak_lag = min_lag + np.argmax(seg)
                    if peak_lag > 0:
                        plp_bpm = 60.0 * sr / (peak_lag * hop)
                        if plp_bpm > 0:
                            estimates.append(_octave_correct_bpm(plp_bpm))
            except Exception:
                pass

            # Multi-hop beat_track: re-run the tracker at finer/coarser hops
            # to decorrelate its errors from the default-hop estimate.
            for extra_hop in (256, 1024):
                try:
                    tempo_h, _ = librosa.beat.beat_track(y=y, sr=sr, hop_length=extra_hop)
                    val_h = float(np.atleast_1d(tempo_h)[0])
                    if val_h > 0:
                        estimates.append(_octave_correct_bpm(val_h))
                except Exception:
                    pass

            # Chunked ensemble: run the 3-method ensemble on the most
            # rhythmically active windows of the track.
            chunks = _select_chunks(y, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=True)
            for chunk in chunks:
                chunk_estimates = _bpm_core_ensemble(chunk, sr)
                estimates.extend(chunk_estimates)

            # IBI stability: low variance of inter-beat intervals indicates a
            # steady tempo and is used below to adjust confidence.
            try:
                _, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
                if beat_frames is not None and len(beat_frames) > 4:
                    beat_times = librosa.frames_to_time(beat_frames, sr=sr)
                    ibis = np.diff(beat_times)
                    ibi_cv = float(np.std(ibis) / (np.mean(ibis) + 1e-10))
                else:
                    ibi_cv = 0.5
            except Exception:
                ibi_cv = 0.5

        bpm, confidence = _bpm_consensus(estimates)

        # sas: IBI stability can upgrade medium->high or downgrade high->medium.
        if mode == "sas" and bpm is not None:
            if ibi_cv < 0.10 and confidence == "medium":
                confidence = "high"
            elif ibi_cv > 0.30 and confidence == "high":
                confidence = "medium"

        logger.info(
            "BPM [%s]: %s (estimates=%s, conf=%s)",
            mode, bpm,
            [round(e, 1) for e in estimates[:10]],
            confidence,
        )
        return bpm, confidence

    except Exception as exc:
        # Best-effort analysis: never let BPM detection break captioning.
        logger.warning("BPM detection failed: %s", exc)
        return None, "low"
1173
+
1174
+
1175
# ---- Key detection helpers ----------------------------------------------

def _best_key_for_profile(chroma_avg, major_profile, minor_profile):
    """Find the best key match for a single profile family.

    Rotates the chroma vector through all 12 tonics and correlates it
    against the normalized major and minor templates.

    Returns (key_label, correlation).
    """
    import numpy as np

    templates = []
    for quality, raw in (("major", major_profile), ("minor", minor_profile)):
        arr = np.array(raw, dtype=float)
        templates.append((quality, arr / arr.sum()))

    best_key, best_corr = "C major", -2.0
    for shift in range(12):
        rotated = np.roll(chroma_avg, -shift)
        for quality, template in templates:
            corr = float(np.corrcoef(rotated, template)[0, 1])
            if corr > best_corr:
                best_corr = corr
                best_key = f"{_PITCH_CLASSES[shift]} {quality}"

    return best_key, best_corr
1204
+
1205
+
1206
def _key_votes_from_chroma(chroma_avg, profiles=None) -> list:
    """Vote on key from a single chroma vector using specified profiles.

    Args:
        chroma_avg: Normalized 12-element average chroma vector.
        profiles: Mapping of family name -> {"major": [...], "minor": [...]};
            defaults to the module-level _KEY_PROFILES.

    Returns:
        List of (key_label, correlation) -- one per profile family.
    """
    if profiles is None:
        profiles = _KEY_PROFILES
    # The family name is never used in the vote, so iterate values only
    # and build the result as a comprehension.
    return [
        _best_key_for_profile(chroma_avg, pf["major"], pf["minor"])
        for pf in profiles.values()
    ]
1220
+
1221
+
1222
def _energy_weighted_chroma(chroma, y_harmonic):
    """Compute an energy-weighted average chroma vector.

    Frames are weighted by their RMS energy so loud passages dominate
    the average.

    Returns normalized chroma_avg or None if zero energy.
    """
    import librosa
    import numpy as np

    frame_rms = librosa.feature.rms(
        y=y_harmonic, frame_length=2048, hop_length=512,
    )[0]

    # chroma and RMS frame counts can differ by one; trim both to match.
    n = min(chroma.shape[1], len(frame_rms))
    trimmed = chroma[:, :n]
    energy = frame_rms[:n]

    weights = energy / (energy.sum() + 1e-10)
    averaged = (trimmed * weights[None, :]).sum(axis=1)

    total = averaged.sum()
    if total == 0:
        return None
    return averaged / total
1243
+
1244
+
1245
# ---- Unified key detection ----------------------------------------------

def _detect_key(y, sr, mode: str = "faf") -> Tuple[Optional[str], str]:
    """Detect musical key with quality controlled by mode.

    faf: Single Krumhansl profile on chroma_cens.
    mid: 3-profile x energy-weighted chroma_cens x 8s segment voting.
    sas: mid + multi-chroma fusion + tonnetz + tuning correction +
         ending resolution + chunked voting.

    Returns (key, confidence) where confidence is "high"/"medium"/"low".
    """
    import librosa
    import numpy as np
    from collections import Counter

    try:
        # Harmonic enhancement: suppress percussive content before chroma.
        # mid/sas use a stronger separation margin than faf.
        margin = 4.0 if mode != "faf" else 2.0
        y_harmonic = librosa.effects.harmonic(y, margin=margin)

        # sas: estimate deviation from A440 so chroma bins line up with
        # the actual tuning of the recording.
        tuning = 0.0
        if mode == "sas":
            try:
                tuning = float(librosa.estimate_tuning(y=y_harmonic, sr=sr))
            except Exception:
                tuning = 0.0

        # faf: single chroma, single (Krumhansl) profile, low confidence.
        if mode == "faf":
            chroma = librosa.feature.chroma_cens(y=y_harmonic, sr=sr)
            chroma_avg = _energy_weighted_chroma(chroma, y_harmonic)
            if chroma_avg is None:
                return None, "low"
            kr = _KEY_PROFILES["krumhansl"]
            key_label, corr = _best_key_for_profile(
                chroma_avg, kr["major"], kr["minor"],
            )
            logger.info("Key faf: %s (corr=%.3f)", key_label, corr)
            return key_label, "low"

        # mid / sas: accumulate weighted votes from many detectors.
        all_votes = []
        all_weights = []

        # sas fuses three chroma variants; mid uses chroma_cens only.
        if mode == "sas":
            chroma_types = {
                "cens": lambda: librosa.feature.chroma_cens(
                    y=y_harmonic, sr=sr, tuning=tuning,
                ),
                "cqt": lambda: librosa.feature.chroma_cqt(
                    y=y_harmonic, sr=sr, tuning=tuning,
                ),
                "stft": lambda: librosa.feature.chroma_stft(
                    y=y_harmonic, sr=sr, tuning=tuning,
                ),
            }
        else:
            chroma_types = {
                "cens": lambda: librosa.feature.chroma_cens(
                    y=y_harmonic, sr=sr,
                ),
            }

        for chroma_name, chroma_fn in chroma_types.items():
            try:
                chroma = chroma_fn()
            except Exception:
                continue

            chroma_avg = _energy_weighted_chroma(chroma, y_harmonic)
            if chroma_avg is None:
                continue

            # Global multi-profile vote: one vote per profile family.
            for key_label, corr in _key_votes_from_chroma(chroma_avg):
                all_votes.append(key_label)
                all_weights.append(1.0)

            # Segment-based voting: ~8 s energy-weighted windows.
            # NOTE(review): this RMS is recomputed identically on each
            # chroma-type iteration -- loop-invariant, could be hoisted.
            rms = librosa.feature.rms(
                y=y_harmonic, frame_length=2048, hop_length=512,
            )
            rms_vec = rms[0]
            min_len = min(chroma.shape[1], len(rms_vec))
            chroma_s = chroma[:, :min_len]
            rms_s = rms_vec[:min_len]

            seg_frames = int(8.0 * sr / 512)  # ~8 s at hop 512
            n_segments = max(1, chroma_s.shape[1] // seg_frames)

            for seg_i in range(n_segments):
                start = seg_i * seg_frames
                end = min(start + seg_frames, chroma_s.shape[1])
                seg_chroma = chroma_s[:, start:end]
                seg_w = rms_s[start:end]

                w_sum = seg_w.sum()
                if w_sum < 1e-10:
                    continue  # near-silent segment: no useful key evidence

                seg_w_norm = seg_w / w_sum
                seg_avg = (seg_chroma * seg_w_norm[None, :]).sum(axis=1)
                s = seg_avg.sum()
                if s < 1e-10:
                    continue
                seg_avg = seg_avg / s

                for key_label, _ in _key_votes_from_chroma(seg_avg):
                    all_votes.append(key_label)
                    all_weights.append(1.0)

        # sas-only extras
        if mode == "sas":
            # Tonnetz -- weighted vote for major/minor disambiguation.
            # NOTE(review): dims 2-5 of the tonnetz are treated as
            # minor/major third energy -- confirm against librosa docs.
            try:
                tonnetz = librosa.feature.tonnetz(y=y_harmonic, sr=sr)
                tonnetz_avg = np.mean(tonnetz, axis=1)
                major_energy = float(np.sum(tonnetz_avg[4:6] ** 2))
                minor_energy = float(np.sum(tonnetz_avg[2:4] ** 2))
                tonnetz_ratio = major_energy / (minor_energy + 1e-10)

                if all_votes:
                    temp_counts = Counter(all_votes)
                    leader = temp_counts.most_common(1)[0][0]
                    leader_is_major = "major" in leader
                    tonnetz_says_major = tonnetz_ratio > 1.0

                    if leader_is_major == tonnetz_says_major:
                        # Tonnetz agrees with the current leader: boost it.
                        all_votes.extend([leader] * 3)
                        all_weights.extend([1.5] * 3)
                    else:
                        # Disagreement: add one vote per profile family for
                        # the best key of the opposite (alt) mode.
                        alt_mode = "minor" if leader_is_major else "major"
                        chroma_cens = librosa.feature.chroma_cens(
                            y=y_harmonic, sr=sr, tuning=tuning,
                        )
                        ca = _energy_weighted_chroma(chroma_cens, y_harmonic)
                        if ca is not None:
                            for name, pf in _KEY_PROFILES.items():
                                prof = np.array(pf[alt_mode], dtype=float)
                                prof_norm = prof / prof.sum()
                                best_corr = -2.0
                                best_k = ""
                                for shift in range(12):
                                    rotated = np.roll(ca, -shift)
                                    c = float(np.corrcoef(rotated, prof_norm)[0, 1])
                                    if c > best_corr:
                                        best_corr = c
                                        best_k = f"{_PITCH_CLASSES[shift]} {alt_mode}"
                                if best_k:
                                    all_votes.append(best_k)
                                    all_weights.append(1.0)
            except Exception:
                pass

            # Ending resolution -- songs usually resolve to the tonic, so
            # the last ~5 s get double-weight votes.
            try:
                end_samples = min(int(5.0 * sr), len(y_harmonic))
                y_end = y_harmonic[-end_samples:]
                chroma_end = librosa.feature.chroma_cens(
                    y=y_end, sr=sr, tuning=tuning,
                )
                end_avg = np.mean(chroma_end, axis=1)
                s = end_avg.sum()
                if s > 1e-10:
                    end_avg = end_avg / s
                    for key_label, _ in _key_votes_from_chroma(end_avg):
                        all_votes.append(key_label)
                        all_weights.append(2.0)
            except Exception:
                pass

            # Chunked voting: votes from the highest-energy windows.
            chunks = _select_chunks(
                y_harmonic, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=False,
            )
            for chunk in chunks:
                try:
                    ch_chroma = librosa.feature.chroma_cens(
                        y=chunk, sr=sr, tuning=tuning,
                    )
                    ch_avg = _energy_weighted_chroma(ch_chroma, chunk)
                    if ch_avg is not None:
                        for key_label, _ in _key_votes_from_chroma(ch_avg):
                            all_votes.append(key_label)
                            all_weights.append(1.0)
                except Exception:
                    pass

        # Final weighted majority vote.
        if not all_votes:
            return None, "low"

        weighted_counts = {}
        for vote, w in zip(all_votes, all_weights):
            weighted_counts[vote] = weighted_counts.get(vote, 0.0) + w

        best_key = max(weighted_counts, key=weighted_counts.get)
        total_weight = sum(all_weights)
        best_weight = weighted_counts[best_key]
        share = best_weight / total_weight

        # Confidence from the winner's share of the total vote weight.
        if share >= 0.55:
            confidence = "high"
        elif share >= 0.35:
            confidence = "medium"
        else:
            confidence = "low"

        logger.info(
            "Key [%s]: %s (share=%.0f%%, votes=%d, conf=%s)",
            mode, best_key, share * 100, len(all_votes), confidence,
        )
        return best_key, confidence

    except Exception as exc:
        # Best-effort analysis: never let key detection break captioning.
        logger.warning("Key detection failed: %s", exc)
        return None, "low"
1464
+
1465
+
1466
# ---- Time-signature helpers ---------------------------------------------

def _timesig_core_scores(y, sr) -> dict:
    """Compute 3-signal time-signature scores on a single buffer (mid/sas).

    Returns dict mapping signature labels ("3/4", "4/4", "6/8") to raw
    scores; empty dict when too few beats are tracked to analyze.
    """
    import librosa
    import numpy as np

    scores = {}

    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
    # Need a minimum number of beats for any meter analysis.
    if beat_frames is None or len(beat_frames) < 8:
        return scores

    onset_env = librosa.onset.onset_strength(y=y, sr=sr)
    # Onset strength sampled at each tracked beat (guard against frames
    # past the end of the envelope).
    beat_strengths = onset_env[beat_frames[beat_frames < len(onset_env)]]
    if len(beat_strengths) < 8:
        return scores

    # Signal 1: Accent pattern analysis -- group beats by candidate meter
    # and score the downbeat/offbeat strength contrast.
    for label, grouping in [("3/4", 3), ("4/4", 4), ("6/8", 6)]:
        if len(beat_strengths) < grouping * 2:
            scores[label] = 0.0
            continue
        usable = len(beat_strengths) - (len(beat_strengths) % grouping)
        grouped = beat_strengths[:usable].reshape(-1, grouping)
        downbeat_mean = float(np.mean(grouped[:, 0]))
        offbeat_mean = float(np.mean(grouped[:, 1:]))
        contrast = downbeat_mean / offbeat_mean if offbeat_mean > 0 else 1.0
        scores[label] = contrast

    # Signal 2: Autocorrelation of the onset envelope at bar-length lags
    # (beat period x 3/4/6), normalized by the zero-lag value.
    # NOTE(review): the frame conversion assumes onset_strength's default
    # hop of 512 samples -- confirm if the default changes.
    hop = 512
    beat_times = librosa.frames_to_time(beat_frames, sr=sr)
    intervals = np.diff(beat_times)
    if len(intervals) > 0:
        median_interval = float(np.median(intervals))
        beat_period = int(round(median_interval * sr / hop))
        if beat_period > 0:
            ac = librosa.autocorrelate(onset_env, max_size=len(onset_env))
            for label, mult in [("3/4", 3), ("4/4", 4), ("6/8", 6)]:
                period = beat_period * mult
                if period < len(ac):
                    # Average a small neighborhood around the exact lag to
                    # tolerate slight tempo drift.
                    lo = max(0, period - 2)
                    hi = min(len(ac), period + 3)
                    ac_score = float(np.mean(ac[lo:hi]))
                    if ac[0] > 0:
                        ac_score /= float(ac[0])
                    scores[label] = scores.get(label, 0.0) + ac_score

    # Signal 3: Beat-strength variance within each candidate bar -- higher
    # within-bar variance suggests a clearer accent pattern (3/4 and 4/4 only).
    for label, grouping in [("3/4", 3), ("4/4", 4)]:
        usable = len(beat_strengths) - (len(beat_strengths) % grouping)
        if usable >= grouping * 2:
            grouped = beat_strengths[:usable].reshape(-1, grouping)
            row_vars = np.var(grouped, axis=1)
            scores[label] = scores.get(label, 0.0) + float(np.mean(row_vars))

    return scores
1527
+
1528
+
1529
+ # ---- Unified time-signature detection -----------------------------------
1530
+
1531
def _detect_time_sig(y, sr, mode: str = "faf") -> Tuple[Optional[str], str]:
    """Estimate time signature with quality controlled by mode.

    faf: Hardcoded "4/4" (correct ~80%+ of the time).
    mid: Beat-sync accent + AC + variance + 4/4 prior.
    sas: mid signals + PLP periodicity + multi-band onset +
         tempogram harmonic ratios + chunked voting.

    Args:
        y: Mono audio buffer (1-D float array).
        sr: Sample rate of ``y`` in Hz.
        mode: One of "faf", "mid", "sas".

    Returns:
        Tuple of (signature label, confidence level "low"/"medium"/"high").
        Falls back to ("4/4", "low") on any analysis failure.
    """
    if mode == "faf":
        return "4/4", "low"

    import librosa
    import numpy as np

    try:
        # mid: core 3-signal scoring
        scores = _timesig_core_scores(y, sr)

        # sas: additional techniques layered on top of the core scores.
        if mode == "sas":
            onset_env = librosa.onset.onset_strength(y=y, sr=sr)
            hop = 512  # onset_strength default hop_length

            # BUGFIX: estimate tempo ONCE up front. Previously tempo_est /
            # tempo_val were assigned inside the PLP try-block but read by
            # the multi-band and tempogram try-blocks; if plp() raised
            # first, those blocks died on NameError and were silently
            # swallowed, disabling two of the four sas techniques.
            try:
                tempo_est, _ = librosa.beat.beat_track(y=y, sr=sr)
                tempo_val = float(np.atleast_1d(tempo_est)[0])
            except Exception:
                tempo_val = 0.0
            # Beat period in onset-envelope frames (0 when tempo unknown).
            beat_period = int(round(60.0 / tempo_val * sr / hop)) if tempo_val > 0 else 0

            # PLP periodicity: autocorrelate the predominant local pulse
            # curve at whole-measure lags.
            try:
                pulse = librosa.beat.plp(onset_envelope=onset_env, sr=sr)
                plp_ac = librosa.autocorrelate(pulse, max_size=len(pulse))
                if beat_period > 0:
                    for label, mult in [("3/4", 3), ("4/4", 4), ("6/8", 6)]:
                        lag = beat_period * mult
                        if lag < len(plp_ac):
                            lo = max(0, lag - 2)
                            hi = min(len(plp_ac), lag + 3)
                            s = float(np.mean(plp_ac[lo:hi]))
                            if plp_ac[0] > 0:
                                s /= float(plp_ac[0])
                            scores[label] = scores.get(label, 0.0) + s
            except Exception:
                pass

            # Multi-band onset analysis (low/mid/high thirds of the
            # spectrum). Low band gets extra weight: kick/bass carries the
            # clearest meter information.
            try:
                S = np.abs(librosa.stft(y))
                third = S.shape[0] // 3
                bands = {
                    "low": S[:third, :],
                    "mid_band": S[third:2 * third, :],
                    "high": S[2 * third:, :],
                }
                if beat_period > 0:
                    for band_name, band_S in bands.items():
                        band_onset = librosa.onset.onset_strength(S=band_S, sr=sr)
                        band_ac = librosa.autocorrelate(
                            band_onset, max_size=len(band_onset),
                        )
                        if band_ac[0] > 0:
                            for label, mult in [("3/4", 3), ("4/4", 4)]:
                                lag = beat_period * mult
                                if lag < len(band_ac):
                                    lo = max(0, lag - 2)
                                    hi = min(len(band_ac), lag + 3)
                                    s = float(np.mean(band_ac[lo:hi])) / float(band_ac[0])
                                    w = 1.5 if band_name == "low" else 1.0
                                    scores[label] = scores.get(label, 0.0) + s * w
            except Exception:
                pass

            # Tempogram harmonic ratios: energy at 2x tempo supports duple
            # meter (4/4), energy at 3x supports triple meter (3/4).
            try:
                tempogram = librosa.feature.tempogram(
                    onset_envelope=onset_env, sr=sr,
                )
                avg_tg = np.mean(tempogram, axis=1)
                bpm_axis = librosa.tempo_frequencies(tempogram.shape[0], sr=sr)
                if tempo_val > 0:
                    base_idx = np.argmin(np.abs(bpm_axis - tempo_val))
                    base_energy = float(avg_tg[base_idx]) + 1e-10
                    for t_mult in (2.0, 3.0):
                        target_bpm = tempo_val * t_mult
                        if target_bpm < 300:
                            idx = np.argmin(np.abs(bpm_axis - target_bpm))
                            ratio = float(avg_tg[idx]) / base_energy
                            label = "4/4" if t_mult == 2.0 else "3/4"
                            scores[label] = scores.get(label, 0.0) + ratio
            except Exception:
                pass

            # Chunked voting: re-run the core scorer on salient excerpts;
            # each chunk's winner (with the same 4/4 prior) adds one vote.
            chunks = _select_chunks(y, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=True)
            for chunk in chunks:
                cs = _timesig_core_scores(chunk, sr)
                if cs:
                    cs["4/4"] = cs.get("4/4", 0.0) * 1.15
                    winner = max(cs, key=cs.get)
                    scores[winner] = scores.get(winner, 0.0) + 1.0

        if not scores:
            return "4/4", "low"

        # Bayesian prior: bias toward 4/4, by far the most common meter.
        scores["4/4"] = scores.get("4/4", 0.0) * 1.15

        best = max(scores, key=scores.get)

        # Confidence from the margin between the top two candidates.
        sorted_scores = sorted(scores.values(), reverse=True)
        if len(sorted_scores) >= 2 and sorted_scores[1] > 0:
            margin = sorted_scores[0] / sorted_scores[1]
        else:
            margin = 1.0

        if margin > 1.4:
            confidence = "high"
        elif margin > 1.15:
            confidence = "medium"
        else:
            confidence = "low"

        logger.info(
            "TimeSig [%s]: %s (scores=%s, margin=%.2f, conf=%s)",
            mode, best,
            {k: round(v, 3) for k, v in scores.items()},
            margin, confidence,
        )
        return best, confidence

    except Exception as exc:
        logger.warning("Time signature detection failed: %s", exc)
        return "4/4", "low"
1677
 
1678
 
1679
  def _sanitize_tag(value: str) -> str:
 
1743
  return title or audio_path.stem, artist or ""
1744
 
1745
 
1746
+ def analyze_and_caption(
1747
+ audio_path: str,
1748
+ mode: str = "faf",
1749
+ device: str = "cpu",
1750
+ ) -> Dict[str, Any]:
1751
  """Analyze an audio file and build a training caption.
1752
 
1753
+ Supports three quality modes:
1754
+ faf - CPU, ~2-3s/file. Single-method detection on raw mix.
1755
+ mid - ~12s/file. Demucs stems + 3-method ensemble.
1756
+ sas - ~30s/file. Deep multi-technique + chunked analysis.
1757
+
1758
+ For mid/sas, Demucs separates drums and harmonics stems first.
1759
+ On CPU, Demucs adds ~2-5 minutes per file.
1760
 
1761
  Args:
1762
  audio_path: Path to the audio file.
1763
+ mode: Analysis mode ("faf", "mid", or "sas").
1764
+ device: Torch device for Demucs ("cpu").
1765
 
1766
  Returns:
1767
+ Dict with keys: caption, bpm, key, signature, lyrics, title, artist,
1768
+ confidence (dict of per-field confidence levels).
1769
  """
1770
  import librosa
1771
  import numpy as np
1772
 
1773
  audio_path = Path(audio_path)
1774
 
1775
+ if mode not in _ANALYSIS_MODES:
1776
+ logger.warning("Unknown analysis mode '%s', falling back to 'faf'", mode)
1777
+ mode = "faf"
1778
+
1779
  # Load audio once, reuse for all detectors
1780
  try:
1781
  y, sr = librosa.load(str(audio_path), sr=None, mono=True)
 
1793
  "caption": f"A track by {artist}" if artist else f"A track titled {title}",
1794
  "bpm": None, "key": None, "signature": "4/4",
1795
  "lyrics": "[Instrumental]", "title": title, "artist": artist,
1796
+ "confidence": {},
1797
  }
1798
 
1799
+ confidence = {}
1800
+ tmp_dir = None
1801
+
1802
+ try:
1803
+ if mode in ("mid", "sas"):
1804
+ # Demucs stem separation -- run BPM/timesig on drums,
1805
+ # key detection on harmonics
1806
+ tmp_dir = Path(tempfile.mkdtemp(prefix="ace_analysis_"))
1807
+ try:
1808
+ drums_path, harmonics_path = separate_stems(
1809
+ audio_path, tmp_dir, device=device,
1810
+ )
1811
+ # Load separated stems for analysis
1812
+ y_drums, sr_drums = librosa.load(
1813
+ str(drums_path), sr=None, mono=True,
1814
+ )
1815
+ y_harmonics, sr_harmonics = librosa.load(
1816
+ str(harmonics_path), sr=None, mono=True,
1817
+ )
1818
+ # Preprocess stems
1819
+ y_drums_trimmed, _ = librosa.effects.trim(y_drums, top_db=30)
1820
+ if len(y_drums_trimmed) >= sr_drums:
1821
+ y_drums = y_drums_trimmed
1822
+ peak_d = np.max(np.abs(y_drums))
1823
+ if peak_d > 0:
1824
+ y_drums = y_drums / peak_d
1825
+
1826
+ y_harm_trimmed, _ = librosa.effects.trim(y_harmonics, top_db=30)
1827
+ if len(y_harm_trimmed) >= sr_harmonics:
1828
+ y_harmonics = y_harm_trimmed
1829
+ peak_h = np.max(np.abs(y_harmonics))
1830
+ if peak_h > 0:
1831
+ y_harmonics = y_harmonics / peak_h
1832
+
1833
+ # BPM + time sig on drums stem
1834
+ bpm, bpm_conf = _detect_bpm(y_drums, sr_drums, mode)
1835
+ signature, sig_conf = _detect_time_sig(y_drums, sr_drums, mode)
1836
+ # Key on harmonics stem
1837
+ key, key_conf = _detect_key(y_harmonics, sr_harmonics, mode)
1838
+
1839
+ confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf}
1840
+
1841
+ except Exception as exc:
1842
+ logger.warning(
1843
+ "Demucs separation failed for %s: %s -- "
1844
+ "falling back to analysis on raw mix",
1845
+ audio_path.name, exc,
1846
+ )
1847
+ # Fallback: run detectors on raw mix
1848
+ bpm, bpm_conf = _detect_bpm(y, sr, mode)
1849
+ key, key_conf = _detect_key(y, sr, mode)
1850
+ signature, sig_conf = _detect_time_sig(y, sr, mode)
1851
+ confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf}
1852
+ else:
1853
+ # faf: all detectors on raw mix
1854
+ bpm, bpm_conf = _detect_bpm(y, sr, mode)
1855
+ key, key_conf = _detect_key(y, sr, mode)
1856
+ signature, sig_conf = _detect_time_sig(y, sr, mode)
1857
+ confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf}
1858
+
1859
+ finally:
1860
+ if tmp_dir is not None:
1861
+ try:
1862
+ shutil.rmtree(tmp_dir)
1863
+ except OSError as exc:
1864
+ logger.debug("Could not clean temp dir %s: %s", tmp_dir, exc)
1865
+
1866
  title, artist = _extract_metadata_from_tags(audio_path)
1867
 
1868
  # Build caption string for ACE-Step training
 
1888
  "lyrics": lyrics,
1889
  "title": title,
1890
  "artist": artist,
1891
+ "confidence": confidence,
1892
  }
1893
 
1894
+ logger.info("Auto-caption [%s] for %s: %s", mode, audio_path.name, caption)
1895
  return result
1896
 
1897
 
 
1996
  if sidecar and sidecar.get("caption"):
1997
  caption = sidecar["caption"]
1998
  lyrics = sidecar.get("lyrics", "[Instrumental]")
1999
+ logger.info("[Caption] %s: using existing sidecar", af.name)
2000
  else:
2001
+ # Auto-select analysis mode based on dataset size
2002
+ if total <= 20:
2003
+ analysis_mode = "sas"
2004
+ elif total <= 100:
2005
+ analysis_mode = "mid"
2006
+ else:
2007
+ analysis_mode = "faf"
2008
+
2009
+ # Log mode selection with reasoning (first file only)
2010
+ if i == 0:
2011
+ _MODE_DESC = {
2012
+ "faf": "fast, ~3s/file",
2013
+ "mid": "balanced, ~12s/file",
2014
+ "sas": "best quality, ~30s/file on GPU, slower on CPU",
2015
+ }
2016
+ logger.info(
2017
+ "[Analysis] Mode auto-selected: '%s' (%s) "
2018
+ "for %d files (<=20: sas, 21-100: mid, 100+: faf)",
2019
+ analysis_mode, _MODE_DESC[analysis_mode], total,
2020
+ )
2021
+ if analysis_mode in ("mid", "sas") and device == "cpu":
2022
+ logger.warning(
2023
+ "[Analysis] Mode '%s' uses Demucs stem separation "
2024
+ "which is SLOW on CPU (~2-5 min/file). "
2025
+ "Total estimated time: ~%d-%d min for %d files. "
2026
+ "Use 'faf' mode or a GPU machine for faster processing.",
2027
+ analysis_mode,
2028
+ total * 2, total * 5, total,
2029
+ )
2030
+
2031
  try:
2032
+ logger.info("[Caption] %s: analyzing (mode=%s)...", af.name, analysis_mode)
2033
+ analysis = analyze_and_caption(
2034
+ str(af), mode=analysis_mode, device=device,
2035
+ )
2036
  caption = analysis["caption"]
2037
  lyrics = analysis.get("lyrics", "[Instrumental]")
2038
  _write_caption_sidecar(af, analysis)
2039
+ logger.info("[Caption] %s: %s", af.name, caption)
2040
  except Exception as exc:
2041
+ logger.warning("[Caption] %s: analysis failed (%s), using filename", af.name, exc)
2042
  caption = af.stem
2043
  lyrics = "[Instrumental]"
2044
  text_prompt = caption