Spaces:
Running
Running
add mid/sas analysis modes (Demucs + ensemble), auto-select by dataset size
Browse files- train_engine.py +989 -65
train_engine.py
CHANGED
|
@@ -20,7 +20,9 @@ import math
|
|
| 20 |
import os
|
| 21 |
import random
|
| 22 |
import re
|
|
|
|
| 23 |
import sys
|
|
|
|
| 24 |
import time
|
| 25 |
import types
|
| 26 |
import unicodedata
|
|
@@ -774,14 +776,43 @@ def _detect_max_duration(files: List[Path]) -> float:
|
|
| 774 |
|
| 775 |
|
| 776 |
# ============================================================================
|
| 777 |
-
# AUDIO ANALYSIS (ported from Side-Step
|
| 778 |
# ============================================================================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
|
| 780 |
-
# Krumhansl key profile for single-profile key detection
|
| 781 |
-
_KEY_PROFILE_MAJOR = [6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
|
| 782 |
-
2.52, 5.19, 2.39, 3.66, 2.29, 2.88]
|
| 783 |
-
_KEY_PROFILE_MINOR = [6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
|
| 784 |
-
2.54, 4.75, 3.98, 2.69, 3.34, 3.17]
|
| 785 |
_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
|
| 786 |
"F#", "G", "G#", "A", "A#", "B"]
|
| 787 |
|
|
@@ -789,6 +820,141 @@ _PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
|
|
| 789 |
_FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$")
|
| 790 |
|
| 791 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float:
|
| 793 |
"""Fold BPM into the musical sweet-spot range [lo, hi]."""
|
| 794 |
if bpm <= 0:
|
|
@@ -803,68 +969,711 @@ def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> floa
|
|
| 803 |
return candidate
|
| 804 |
|
| 805 |
|
| 806 |
-
def
|
| 807 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 808 |
import librosa
|
| 809 |
import numpy as np
|
| 810 |
|
|
|
|
|
|
|
|
|
|
| 811 |
try:
|
| 812 |
-
|
| 813 |
-
|
| 814 |
-
if
|
| 815 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 816 |
except Exception:
|
| 817 |
pass
|
| 818 |
-
return None
|
| 819 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 820 |
|
| 821 |
-
|
| 822 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 823 |
import librosa
|
| 824 |
import numpy as np
|
| 825 |
|
| 826 |
try:
|
| 827 |
-
|
| 828 |
-
|
| 829 |
-
|
| 830 |
-
|
| 831 |
-
|
| 832 |
-
|
| 833 |
-
|
| 834 |
-
|
| 835 |
-
|
| 836 |
-
|
| 837 |
-
|
| 838 |
-
|
| 839 |
-
|
| 840 |
-
|
| 841 |
-
|
| 842 |
-
|
| 843 |
-
|
| 844 |
-
|
| 845 |
-
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
-
|
| 852 |
-
|
| 853 |
-
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
|
| 857 |
-
|
| 858 |
-
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 862 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 863 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 864 |
|
| 865 |
-
|
| 866 |
-
|
| 867 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
|
| 869 |
|
| 870 |
def _sanitize_tag(value: str) -> str:
|
|
@@ -934,24 +1743,39 @@ def _extract_metadata_from_tags(audio_path: Path) -> tuple:
|
|
| 934 |
return title or audio_path.stem, artist or ""
|
| 935 |
|
| 936 |
|
| 937 |
-
def analyze_and_caption(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 938 |
"""Analyze an audio file and build a training caption.
|
| 939 |
|
| 940 |
-
|
| 941 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 942 |
|
| 943 |
Args:
|
| 944 |
audio_path: Path to the audio file.
|
| 945 |
-
mode: Analysis mode (
|
|
|
|
| 946 |
|
| 947 |
Returns:
|
| 948 |
-
Dict with keys: caption, bpm, key, signature, lyrics, title, artist
|
|
|
|
| 949 |
"""
|
| 950 |
import librosa
|
| 951 |
import numpy as np
|
| 952 |
|
| 953 |
audio_path = Path(audio_path)
|
| 954 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 955 |
# Load audio once, reuse for all detectors
|
| 956 |
try:
|
| 957 |
y, sr = librosa.load(str(audio_path), sr=None, mono=True)
|
|
@@ -969,11 +1793,76 @@ def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
|
|
| 969 |
"caption": f"A track by {artist}" if artist else f"A track titled {title}",
|
| 970 |
"bpm": None, "key": None, "signature": "4/4",
|
| 971 |
"lyrics": "[Instrumental]", "title": title, "artist": artist,
|
|
|
|
| 972 |
}
|
| 973 |
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 977 |
title, artist = _extract_metadata_from_tags(audio_path)
|
| 978 |
|
| 979 |
# Build caption string for ACE-Step training
|
|
@@ -999,9 +1888,10 @@ def analyze_and_caption(audio_path: str, mode: str = "faf") -> Dict[str, Any]:
|
|
| 999 |
"lyrics": lyrics,
|
| 1000 |
"title": title,
|
| 1001 |
"artist": artist,
|
|
|
|
| 1002 |
}
|
| 1003 |
|
| 1004 |
-
logger.info("Auto-caption for %s: %s", audio_path.name, caption)
|
| 1005 |
return result
|
| 1006 |
|
| 1007 |
|
|
@@ -1106,15 +1996,49 @@ def preprocess_audio(
|
|
| 1106 |
if sidecar and sidecar.get("caption"):
|
| 1107 |
caption = sidecar["caption"]
|
| 1108 |
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 1109 |
-
logger.info("
|
| 1110 |
else:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1111 |
try:
|
| 1112 |
-
|
|
|
|
|
|
|
|
|
|
| 1113 |
caption = analysis["caption"]
|
| 1114 |
lyrics = analysis.get("lyrics", "[Instrumental]")
|
| 1115 |
_write_caption_sidecar(af, analysis)
|
|
|
|
| 1116 |
except Exception as exc:
|
| 1117 |
-
logger.warning("
|
| 1118 |
caption = af.stem
|
| 1119 |
lyrics = "[Instrumental]"
|
| 1120 |
text_prompt = caption
|
|
|
|
| 20 |
import os
|
| 21 |
import random
|
| 22 |
import re
|
| 23 |
+
import shutil
|
| 24 |
import sys
|
| 25 |
+
import tempfile
|
| 26 |
import time
|
| 27 |
import types
|
| 28 |
import unicodedata
|
|
|
|
| 776 |
|
| 777 |
|
| 778 |
# ============================================================================
# AUDIO ANALYSIS (ported from Side-Step -- faf / mid / sas modes)
# ============================================================================
#
# faf ("Fast As F*ck") ~2-3 s/file - single-method, no Demucs
# mid ~12 s/file - 3-method ensemble, Demucs stems
# sas ("Smart/Slow As Sh*t") ~30 s/file - deep multi-technique + chunked
#
# Demucs on CPU is SLOW (~2-5 min/file). mid/sas are designed for GPU
# stem separation but will still work on CPU -- just much slower.
# ============================================================================

# Recognized analysis-quality modes, cheapest to most thorough.
_ANALYSIS_MODES = ("faf", "mid", "sas")
# Number of windows sampled per track in "sas" mode.
_SAS_NUM_CHUNKS = 5
_SAS_CHUNK_SECONDS = 15  # seconds per analysis window

# Key profile families for multi-profile voting (mid / sas).
# Each family maps "major"/"minor" -> 12 pitch-class weights starting at C.
_KEY_PROFILES = {
    "krumhansl": {
        "major": [6.35, 2.23, 3.48, 2.33, 4.38, 4.09,
                  2.52, 5.19, 2.39, 3.66, 2.29, 2.88],
        "minor": [6.33, 2.68, 3.52, 5.38, 2.60, 3.53,
                  2.54, 4.75, 3.98, 2.69, 3.34, 3.17],
    },
    "temperley": {
        "major": [5.0, 2.0, 3.5, 2.0, 4.5, 4.0,
                  2.0, 4.5, 2.0, 3.5, 1.5, 4.0],
        "minor": [5.0, 2.0, 3.5, 4.5, 2.0, 3.5,
                  2.0, 4.5, 3.5, 2.0, 1.5, 4.0],
    },
    "albrecht": {
        "major": [0.238, 0.006, 0.111, 0.006, 0.137, 0.094,
                  0.016, 0.214, 0.009, 0.080, 0.008, 0.081],
        "minor": [0.220, 0.006, 0.104, 0.123, 0.019, 0.103,
                  0.012, 0.214, 0.062, 0.022, 0.061, 0.052],
    },
}

# Pitch-class names indexed 0-11 (C..B), used to render detected keys.
_PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F",
                  "F#", "G", "G#", "A", "A#", "B"]
|
| 818 |
|
|
|
|
| 820 |
_FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$")
|
| 821 |
|
| 822 |
|
| 823 |
+
# ---- Demucs stem separation (mid / sas) --------------------------------

def separate_stems(
    audio_path: Path,
    tmp_dir: Path,
    device: str = "cpu",
) -> Tuple[Path, Path]:
    """Run Demucs HTDemucs and return (drums_path, harmonics_path).

    Harmonics = bass + other stems summed. Vocals are discarded.

    WARNING: On CPU this takes ~2-5 minutes per file.

    Args:
        audio_path: Source audio file to separate.
        tmp_dir: Directory that receives ``drums.wav`` and ``harmonics.wav``.
        device: Torch device string (e.g. "cpu" or "cuda").

    Returns:
        Tuple of (drums_path, harmonics_path) inside ``tmp_dir``.
    """
    import torchaudio
    from demucs.pretrained import get_model
    from demucs.apply import apply_model

    torch_device = torch.device(device)

    logger.info("Loading Demucs HTDemucs model on %s", device)
    if device == "cpu":
        logger.warning(
            "Demucs on CPU is slow (~2-5 min per file). "
            "Consider using 'faf' mode or running on a GPU machine."
        )
    model = get_model("htdemucs")
    model.to(torch_device)
    model.eval()

    wav, sr = torchaudio.load(str(audio_path))

    # Resample to model's expected rate (44100 Hz) if needed
    if sr != model.samplerate:
        wav = torchaudio.functional.resample(wav, sr, model.samplerate)
        sr = model.samplerate

    # HTDemucs requires stereo input
    if wav.shape[0] == 1:
        wav = wav.repeat(2, 1)

    # Add a batch dimension and move to the inference device.
    wav = wav.unsqueeze(0).to(torch_device)

    logger.info("Separating stems for %s", audio_path.name)
    with torch.no_grad():
        sources = apply_model(model, wav, device=torch_device)

    # Map stem names to output indices; keep drums, merge bass + other.
    source_map = {name: i for i, name in enumerate(model.sources)}
    drums = sources[0, source_map["drums"]].cpu()
    bass = sources[0, source_map["bass"]].cpu()
    other = sources[0, source_map["other"]].cpu()
    harmonics = bass + other

    drums_path = tmp_dir / "drums.wav"
    harmonics_path = tmp_dir / "harmonics.wav"
    torchaudio.save(str(drums_path), drums, sr)
    torchaudio.save(str(harmonics_path), harmonics, sr)

    # Release the (large) model and tensors before returning.
    del model, sources, wav, drums, bass, other, harmonics
    gc.collect()

    logger.info("Stems written: %s, %s", drums_path, harmonics_path)
    return drums_path, harmonics_path
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
# ---- Chunk selection (sas mode) ----------------------------------------

def _select_chunks(
    y,  # np.ndarray
    sr: int,
    n_chunks: int = _SAS_NUM_CHUNKS,
    chunk_sec: float = _SAS_CHUNK_SECONDS,
    min_gap_sec: float = 10.0,
    use_onset: bool = True,
) -> list:
    """Select the most informative audio chunks for sas analysis.

    Half-overlapping windows are scored by onset density (or RMS energy
    when ``use_onset`` is False), everything below the median score is
    discarded, and the survivors are picked greedily best-first while
    keeping picks at least ``min_gap_sec`` apart.
    """
    import librosa
    import numpy as np

    win = int(chunk_sec * sr)
    step = win // 2
    if len(y) < win:
        return [y]

    scored = []
    for offset in range(0, len(y) - win + 1, step):
        segment = y[offset : offset + win]
        if use_onset:
            strength = librosa.onset.onset_strength(y=segment, sr=sr)
            scored.append((offset, float(np.mean(strength))))
        else:
            scored.append((offset, float(np.sqrt(np.mean(segment ** 2)))))

    if not scored:
        return [y]

    # Energy gate: keep only windows scoring at or above the median.
    cutoff = float(np.median(np.array([s for _, s in scored])))
    survivors = [pair for pair in scored if pair[1] >= cutoff] or scored
    survivors.sort(key=lambda pair: pair[1], reverse=True)

    gap = int(min_gap_sec * sr)
    picks = []
    for offset, _ in survivors:
        # Centre-to-centre distance equals start-to-start distance here,
        # because every window has the same length.
        if all(abs(offset - p) >= gap for p in picks):
            picks.append(offset)
            if len(picks) >= n_chunks:
                break

    if len(picks) < n_chunks:
        # Not enough spread-apart windows: relax the spacing constraint.
        for offset, _ in survivors:
            if offset not in picks:
                picks.append(offset)
                if len(picks) >= n_chunks:
                    break

    picks.sort()
    return [y[p : p + win] for p in picks]
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
# ---- BPM helpers --------------------------------------------------------
|
| 957 |
+
|
| 958 |
def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float:
|
| 959 |
"""Fold BPM into the musical sweet-spot range [lo, hi]."""
|
| 960 |
if bpm <= 0:
|
|
|
|
| 969 |
return candidate
|
| 970 |
|
| 971 |
|
| 972 |
+
def _bpm_core_ensemble(y, sr) -> list:
    """Run the 3-method BPM ensemble on a single audio buffer (mid/sas).

    Each method is attempted independently; failures are skipped so the
    ensemble degrades gracefully.

    Returns a list of octave-corrected BPM estimates.
    """
    import librosa
    import numpy as np

    estimates = []

    # Method A: beat_track
    try:
        tempo_a, _ = librosa.beat.beat_track(y=y, sr=sr)
        # beat_track may return a scalar or a 0-d/1-d array depending on
        # librosa version; atleast_1d normalizes both cases.
        val_a = float(np.atleast_1d(tempo_a)[0])
        if val_a > 0:
            estimates.append(_octave_correct_bpm(val_a))
    except Exception:
        pass

    # Method B: tempogram peak
    try:
        onset_env = librosa.onset.onset_strength(y=y, sr=sr)
        tempogram = librosa.feature.tempogram(onset_envelope=onset_env, sr=sr)
        avg_tempogram = np.mean(tempogram, axis=1)
        bpm_axis = librosa.tempo_frequencies(tempogram.shape[0], sr=sr)
        # Only accept peaks in the plausible 30-300 BPM range.
        valid = (bpm_axis >= 30) & (bpm_axis <= 300)
        if np.any(valid):
            masked = avg_tempogram.copy()
            masked[~valid] = 0
            peak_idx = np.argmax(masked)
            val_b = float(bpm_axis[peak_idx])
            if val_b > 0:
                estimates.append(_octave_correct_bpm(val_b))
    except Exception:
        pass

    # Method C: onset autocorrelation
    try:
        onset_env = librosa.onset.onset_strength(y=y, sr=sr)
        ac = librosa.autocorrelate(onset_env, max_size=len(onset_env))
        hop = 512  # matches onset_strength's default hop_length
        # Lag bounds corresponding to 300 BPM (shortest) and 30 BPM (longest).
        min_lag = int(60.0 * sr / (300.0 * hop))
        max_lag = int(60.0 * sr / (30.0 * hop))
        max_lag = min(max_lag, len(ac) - 1)
        if min_lag < max_lag and max_lag > 0:
            segment = ac[min_lag:max_lag + 1]
            peak_offset = np.argmax(segment)
            peak_lag = min_lag + peak_offset
            if peak_lag > 0:
                val_c = 60.0 * sr / (peak_lag * hop)
                if val_c > 0:
                    estimates.append(_octave_correct_bpm(val_c))
    except Exception:
        pass

    return estimates
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
def _bpm_consensus(estimates: list) -> Tuple[Optional[int], str]:
|
| 1031 |
+
"""Find consensus BPM from a list of estimates + assign confidence."""
|
| 1032 |
+
import numpy as np
|
| 1033 |
+
|
| 1034 |
+
if not estimates:
|
| 1035 |
+
return None, "low"
|
| 1036 |
+
|
| 1037 |
+
estimates_arr = np.array(estimates)
|
| 1038 |
+
best_cluster = []
|
| 1039 |
+
for ref in estimates_arr:
|
| 1040 |
+
cluster = [e for e in estimates_arr
|
| 1041 |
+
if abs(e - ref) / max(ref, 1) < 0.08]
|
| 1042 |
+
if len(cluster) > len(best_cluster):
|
| 1043 |
+
best_cluster = cluster
|
| 1044 |
+
|
| 1045 |
+
consensus = float(np.median(best_cluster)) if best_cluster else estimates[0]
|
| 1046 |
+
bpm = int(round(consensus))
|
| 1047 |
+
if bpm <= 0:
|
| 1048 |
+
return None, "low"
|
| 1049 |
+
|
| 1050 |
+
n_agree = len(best_cluster)
|
| 1051 |
+
n_total = len(estimates)
|
| 1052 |
+
if n_total >= 6:
|
| 1053 |
+
# sas thresholds (many data points)
|
| 1054 |
+
if n_agree / n_total >= 0.7:
|
| 1055 |
+
confidence = "high"
|
| 1056 |
+
elif n_agree / n_total >= 0.4:
|
| 1057 |
+
confidence = "medium"
|
| 1058 |
+
else:
|
| 1059 |
+
confidence = "low"
|
| 1060 |
+
else:
|
| 1061 |
+
# mid thresholds
|
| 1062 |
+
if n_agree >= 3:
|
| 1063 |
+
confidence = "high"
|
| 1064 |
+
elif n_agree >= 2:
|
| 1065 |
+
confidence = "medium"
|
| 1066 |
+
else:
|
| 1067 |
+
confidence = "low"
|
| 1068 |
|
| 1069 |
+
return bpm, confidence
|
| 1070 |
+
|
| 1071 |
+
|
| 1072 |
+
# ---- Unified BPM detection ---------------------------------------------

def _detect_bpm(y, sr, mode: str = "faf") -> Tuple[Optional[int], str]:
    """Detect BPM with quality controlled by mode.

    faf: Single beat_track + octave correction.
    mid: 3-method ensemble (beat_track + tempogram + onset-AC).
    sas: mid ensemble + PLP + multi-hop + chunked analysis.

    Returns (bpm, confidence).
    """
    import librosa
    import numpy as np

    try:
        # faf: single method
        if mode == "faf":
            try:
                tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
                val = float(np.atleast_1d(tempo)[0])
                if val > 0:
                    bpm = int(round(_octave_correct_bpm(val)))
                    logger.info("BPM faf: %d (raw: %.1f)", bpm, val)
                    # faf is a single estimator, so confidence is always low.
                    return bpm, "low"
            except Exception:
                pass
            return None, "low"

        # mid: 3-method ensemble
        estimates = _bpm_core_ensemble(y, sr)

        # sas: additional techniques
        # ibi_cv = coefficient of variation of inter-beat intervals;
        # 0.5 is the pessimistic default when it can't be measured.
        ibi_cv = 0.5
        if mode == "sas":
            # PLP (Predominant Local Pulse)
            try:
                onset_env = librosa.onset.onset_strength(y=y, sr=sr)
                pulse = librosa.beat.plp(onset_envelope=onset_env, sr=sr)
                plp_ac = librosa.autocorrelate(pulse, max_size=len(pulse))
                hop = 512  # matches onset_strength's default hop_length
                # Lag bounds for the plausible 30-300 BPM range.
                min_lag = int(60.0 * sr / (300.0 * hop))
                max_lag = int(60.0 * sr / (30.0 * hop))
                max_lag = min(max_lag, len(plp_ac) - 1)
                if min_lag < max_lag and max_lag > 0:
                    seg = plp_ac[min_lag:max_lag + 1]
                    peak_lag = min_lag + np.argmax(seg)
                    if peak_lag > 0:
                        plp_bpm = 60.0 * sr / (peak_lag * hop)
                        if plp_bpm > 0:
                            estimates.append(_octave_correct_bpm(plp_bpm))
            except Exception:
                pass

            # Multi-hop beat_track (256, 1024)
            for extra_hop in (256, 1024):
                try:
                    tempo_h, _ = librosa.beat.beat_track(y=y, sr=sr, hop_length=extra_hop)
                    val_h = float(np.atleast_1d(tempo_h)[0])
                    if val_h > 0:
                        estimates.append(_octave_correct_bpm(val_h))
                except Exception:
                    pass

            # Chunked ensemble: re-run the core ensemble on the most
            # informative windows of the track and pool all estimates.
            chunks = _select_chunks(y, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=True)
            for chunk in chunks:
                chunk_estimates = _bpm_core_ensemble(chunk, sr)
                estimates.extend(chunk_estimates)

            # IBI stability: steady inter-beat intervals (low CV) indicate
            # a reliable tempo track.
            try:
                _, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
                if beat_frames is not None and len(beat_frames) > 4:
                    beat_times = librosa.frames_to_time(beat_frames, sr=sr)
                    ibis = np.diff(beat_times)
                    ibi_cv = float(np.std(ibis) / (np.mean(ibis) + 1e-10))
                else:
                    ibi_cv = 0.5
            except Exception:
                ibi_cv = 0.5

        bpm, confidence = _bpm_consensus(estimates)

        # sas: IBI stability can upgrade medium->high or downgrade
        if mode == "sas" and bpm is not None:
            if ibi_cv < 0.10 and confidence == "medium":
                confidence = "high"
            elif ibi_cv > 0.30 and confidence == "high":
                confidence = "medium"

        logger.info(
            "BPM [%s]: %s (estimates=%s, conf=%s)",
            mode, bpm,
            [round(e, 1) for e in estimates[:10]],
            confidence,
        )
        return bpm, confidence

    except Exception as exc:
        logger.warning("BPM detection failed: %s", exc)
        return None, "low"
|
| 1173 |
+
|
| 1174 |
+
|
| 1175 |
+
# ---- Key detection helpers ----------------------------------------------

def _best_key_for_profile(chroma_avg, major_profile, minor_profile):
    """Find the best key match for a single profile family.

    Correlates the chroma vector against all 12 rotations of the
    (normalized) major and minor profiles and keeps the strongest match.

    Returns (key_label, correlation).
    """
    import numpy as np

    # Normalize each profile to unit sum so families are comparable.
    qualities = []
    for quality, raw in (("major", major_profile), ("minor", minor_profile)):
        profile = np.array(raw, dtype=float)
        qualities.append((quality, profile / profile.sum()))

    top_key = "C major"
    top_corr = -2.0  # below any possible Pearson correlation

    for tonic in range(12):
        # Rotating the chroma by -tonic aligns pitch class `tonic` with C.
        shifted = np.roll(chroma_avg, -tonic)
        for quality, profile in qualities:
            corr = float(np.corrcoef(shifted, profile)[0, 1])
            if corr > top_corr:
                top_corr = corr
                top_key = f"{_PITCH_CLASSES[tonic]} {quality}"

    return top_key, top_corr
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def _key_votes_from_chroma(chroma_avg, profiles=None) -> list:
    """Vote on key from a single chroma vector using specified profiles.

    Falls back to the module-level ``_KEY_PROFILES`` when no profile
    mapping is supplied.

    Returns list of (key_label, correlation) -- one per profile family.
    """
    families = _KEY_PROFILES if profiles is None else profiles
    return [
        _best_key_for_profile(chroma_avg, family["major"], family["minor"])
        for family in families.values()
    ]
|
| 1220 |
+
|
| 1221 |
+
|
| 1222 |
+
def _energy_weighted_chroma(chroma, y_harmonic):
    """Compute an energy-weighted average chroma vector.

    Frames are weighted by their RMS energy so loud passages dominate
    the key estimate.

    Returns normalized chroma_avg or None if zero energy.
    """
    import librosa
    import numpy as np

    frame_energy = librosa.feature.rms(
        y=y_harmonic, frame_length=2048, hop_length=512
    )[0]

    # Trim both series to the common frame count (frame grids can differ).
    n = min(chroma.shape[1], len(frame_energy))
    trimmed = chroma[:, :n]
    frame_energy = frame_energy[:n]

    # Epsilon guards against division by zero on silent input.
    weights = frame_energy / (frame_energy.sum() + 1e-10)
    averaged = (trimmed * weights[None, :]).sum(axis=1)

    total = averaged.sum()
    return None if total == 0 else averaged / total
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
# ---- Unified key detection ----------------------------------------------
|
| 1246 |
+
|
| 1247 |
+
def _detect_key(y, sr, mode: str = "faf") -> Tuple[Optional[str], str]:
    """Detect musical key with quality controlled by mode.

    faf: Single Krumhansl profile on chroma_cens.
    mid: 3-profile x energy-weighted chroma_cens x 8s segment voting.
    sas: mid + multi-chroma fusion + tonnetz + tuning correction +
         ending resolution + chunked voting.

    Args:
        y: Mono audio buffer.
        sr: Sample rate of ``y``.
        mode: Analysis quality mode ("faf", "mid", or "sas").

    Returns (key, confidence) where key is e.g. "C major" (or None when
    no vote could be cast) and confidence is "high" / "medium" / "low".
    Never raises: any failure is logged and reported as (None, "low").
    """
    import librosa
    import numpy as np
    from collections import Counter

    try:
        # Harmonic enhancement: suppress percussive content before any
        # chroma analysis.  A larger HPSS margin (sas/mid) separates
        # more aggressively at extra cost.
        margin = 4.0 if mode != "faf" else 2.0
        y_harmonic = librosa.effects.harmonic(y, margin=margin)

        # sas: tuning correction -- estimate deviation (in fractional
        # bins) from A440 so off-pitch recordings map to the right bins.
        tuning = 0.0
        if mode == "sas":
            try:
                tuning = float(librosa.estimate_tuning(y=y_harmonic, sr=sr))
            except Exception:
                tuning = 0.0

        # faf: single chroma, single profile -- fast path, no voting.
        if mode == "faf":
            chroma = librosa.feature.chroma_cens(y=y_harmonic, sr=sr)
            chroma_avg = _energy_weighted_chroma(chroma, y_harmonic)
            if chroma_avg is None:
                return None, "low"
            kr = _KEY_PROFILES["krumhansl"]
            key_label, corr = _best_key_for_profile(
                chroma_avg, kr["major"], kr["minor"],
            )
            logger.info("Key faf: %s (corr=%.3f)", key_label, corr)
            # faf never claims more than "low" confidence by design.
            return key_label, "low"

        # mid / sas: multi-profile voting.  Votes accumulate in
        # parallel lists; all_weights[i] is the weight of all_votes[i].
        all_votes = []
        all_weights = []

        # sas fuses three chroma variants; mid uses CENS only.
        if mode == "sas":
            chroma_types = {
                "cens": lambda: librosa.feature.chroma_cens(
                    y=y_harmonic, sr=sr, tuning=tuning,
                ),
                "cqt": lambda: librosa.feature.chroma_cqt(
                    y=y_harmonic, sr=sr, tuning=tuning,
                ),
                "stft": lambda: librosa.feature.chroma_stft(
                    y=y_harmonic, sr=sr, tuning=tuning,
                ),
            }
        else:
            chroma_types = {
                "cens": lambda: librosa.feature.chroma_cens(
                    y=y_harmonic, sr=sr,
                ),
            }

        for chroma_name, chroma_fn in chroma_types.items():
            # Any single chroma variant failing just drops out of the
            # ensemble rather than aborting detection.
            try:
                chroma = chroma_fn()
            except Exception:
                continue

            chroma_avg = _energy_weighted_chroma(chroma, y_harmonic)
            if chroma_avg is None:
                continue

            # Global multi-profile vote (one vote per profile family).
            for key_label, corr in _key_votes_from_chroma(chroma_avg):
                all_votes.append(key_label)
                all_weights.append(1.0)

            # Segment-based voting: split this chroma into ~8 s windows
            # and let each energetic window cast its own votes, so a
            # long stable section outvotes a brief modulation.
            rms = librosa.feature.rms(
                y=y_harmonic, frame_length=2048, hop_length=512,
            )
            rms_vec = rms[0]
            min_len = min(chroma.shape[1], len(rms_vec))
            chroma_s = chroma[:, :min_len]
            rms_s = rms_vec[:min_len]

            # 8 s expressed in chroma frames at hop 512.
            seg_frames = int(8.0 * sr / 512)
            n_segments = max(1, chroma_s.shape[1] // seg_frames)

            for seg_i in range(n_segments):
                start = seg_i * seg_frames
                end = min(start + seg_frames, chroma_s.shape[1])
                seg_chroma = chroma_s[:, start:end]
                seg_w = rms_s[start:end]

                w_sum = seg_w.sum()
                # Skip silent segments -- they carry no key evidence.
                if w_sum < 1e-10:
                    continue

                seg_w_norm = seg_w / w_sum
                seg_avg = (seg_chroma * seg_w_norm[None, :]).sum(axis=1)
                s = seg_avg.sum()
                if s < 1e-10:
                    continue
                seg_avg = seg_avg / s

                for key_label, _ in _key_votes_from_chroma(seg_avg):
                    all_votes.append(key_label)
                    all_weights.append(1.0)

        # sas-only extras
        if mode == "sas":
            # Tonnetz -- weighted vote for major/minor disambiguation.
            # If tonal-centroid energy agrees with the current leader's
            # mode, reinforce it; otherwise cast one vote for the best
            # key of the opposite mode.
            try:
                tonnetz = librosa.feature.tonnetz(y=y_harmonic, sr=sr)
                tonnetz_avg = np.mean(tonnetz, axis=1)
                # NOTE(review): dims 4:6 treated as "major" energy and
                # 2:4 as "minor" -- heuristic mapping, not a librosa
                # guarantee; confirm against the tonnetz axis layout.
                major_energy = float(np.sum(tonnetz_avg[4:6] ** 2))
                minor_energy = float(np.sum(tonnetz_avg[2:4] ** 2))
                tonnetz_ratio = major_energy / (minor_energy + 1e-10)

                if all_votes:
                    temp_counts = Counter(all_votes)
                    leader = temp_counts.most_common(1)[0][0]
                    leader_is_major = "major" in leader
                    tonnetz_says_major = tonnetz_ratio > 1.0

                    if leader_is_major == tonnetz_says_major:
                        # Agreement: 3 extra votes at weight 1.5 each.
                        all_votes.extend([leader] * 3)
                        all_weights.extend([1.5] * 3)
                    else:
                        # Disagreement: score every profile's best key
                        # of the opposite mode and add it as one vote.
                        alt_mode = "minor" if leader_is_major else "major"
                        chroma_cens = librosa.feature.chroma_cens(
                            y=y_harmonic, sr=sr, tuning=tuning,
                        )
                        ca = _energy_weighted_chroma(chroma_cens, y_harmonic)
                        if ca is not None:
                            for name, pf in _KEY_PROFILES.items():
                                prof = np.array(pf[alt_mode], dtype=float)
                                prof_norm = prof / prof.sum()
                                best_corr = -2.0
                                best_k = ""
                                # Rotate chroma through all 12 tonics.
                                for shift in range(12):
                                    rotated = np.roll(ca, -shift)
                                    c = float(np.corrcoef(rotated, prof_norm)[0, 1])
                                    if c > best_corr:
                                        best_corr = c
                                        best_k = f"{_PITCH_CLASSES[shift]} {alt_mode}"
                                if best_k:
                                    all_votes.append(best_k)
                                    all_weights.append(1.0)
            except Exception:
                pass

            # Ending resolution -- last ~5 s weighted extra (songs tend
            # to resolve to the tonic at the end), weight 2.0.
            try:
                end_samples = min(int(5.0 * sr), len(y_harmonic))
                y_end = y_harmonic[-end_samples:]
                chroma_end = librosa.feature.chroma_cens(
                    y=y_end, sr=sr, tuning=tuning,
                )
                end_avg = np.mean(chroma_end, axis=1)
                s = end_avg.sum()
                if s > 1e-10:
                    end_avg = end_avg / s
                    for key_label, _ in _key_votes_from_chroma(end_avg):
                        all_votes.append(key_label)
                        all_weights.append(2.0)
            except Exception:
                pass

            # Chunked voting: re-run the profile vote on a handful of
            # representative excerpts of the harmonic signal.
            chunks = _select_chunks(
                y_harmonic, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=False,
            )
            for chunk in chunks:
                try:
                    ch_chroma = librosa.feature.chroma_cens(
                        y=chunk, sr=sr, tuning=tuning,
                    )
                    ch_avg = _energy_weighted_chroma(ch_chroma, chunk)
                    if ch_avg is not None:
                        for key_label, _ in _key_votes_from_chroma(ch_avg):
                            all_votes.append(key_label)
                            all_weights.append(1.0)
                except Exception:
                    pass

        # Final weighted majority vote
        if not all_votes:
            return None, "low"

        weighted_counts = {}
        for vote, w in zip(all_votes, all_weights):
            weighted_counts[vote] = weighted_counts.get(vote, 0.0) + w

        best_key = max(weighted_counts, key=weighted_counts.get)
        total_weight = sum(all_weights)
        best_weight = weighted_counts[best_key]
        # Confidence = winner's share of the total vote weight.
        share = best_weight / total_weight

        if share >= 0.55:
            confidence = "high"
        elif share >= 0.35:
            confidence = "medium"
        else:
            confidence = "low"

        logger.info(
            "Key [%s]: %s (share=%.0f%%, votes=%d, conf=%s)",
            mode, best_key, share * 100, len(all_votes), confidence,
        )
        return best_key, confidence

    except Exception as exc:
        # Key detection is best-effort metadata extraction; never let
        # it abort the caption pipeline.
        logger.warning("Key detection failed: %s", exc)
        return None, "low"
|
| 1464 |
+
|
| 1465 |
+
|
| 1466 |
+
# ---- Time-signature helpers ---------------------------------------------
|
| 1467 |
+
|
| 1468 |
+
def _timesig_core_scores(y, sr) -> dict:
    """Compute 3-signal time-signature scores on a single buffer (mid/sas).

    Combines three additive evidence signals on beat-synchronous onset
    strengths:
      1. downbeat-vs-offbeat accent contrast per candidate grouping,
      2. onset-envelope autocorrelation at each meter's bar period,
      3. within-bar beat-strength variance (3/4 and 4/4 only).

    Args:
        y: Mono audio buffer.
        sr: Sample rate of ``y``.

    Returns dict mapping signature labels ("3/4", "4/4", "6/8") to raw
    scores; empty dict when too few beats (< 8) were tracked to score.
    """
    import librosa
    import numpy as np

    scores = {}

    # Beat tracking; tempo itself is unused here, only beat positions.
    tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)
    if beat_frames is None or len(beat_frames) < 8:
        return scores

    # Onset strength sampled at each tracked beat; frames past the
    # envelope's end are dropped before indexing.
    onset_env = librosa.onset.onset_strength(y=y, sr=sr)
    beat_strengths = onset_env[beat_frames[beat_frames < len(onset_env)]]
    if len(beat_strengths) < 8:
        return scores

    # Signal 1: Accent pattern analysis -- fold beats into bars of the
    # candidate grouping and compare the first beat's mean strength to
    # the rest.  A strong downbeat yields contrast > 1.
    for label, grouping in [("3/4", 3), ("4/4", 4), ("6/8", 6)]:
        if len(beat_strengths) < grouping * 2:
            # Not enough bars to judge; neutral zero score.
            scores[label] = 0.0
            continue
        usable = len(beat_strengths) - (len(beat_strengths) % grouping)
        grouped = beat_strengths[:usable].reshape(-1, grouping)
        downbeat_mean = float(np.mean(grouped[:, 0]))
        offbeat_mean = float(np.mean(grouped[:, 1:]))
        contrast = downbeat_mean / offbeat_mean if offbeat_mean > 0 else 1.0
        scores[label] = contrast

    # Signal 2: Autocorrelation at meter periods -- the onset envelope
    # should self-correlate at one bar (beat period x beats-per-bar).
    # hop matches onset_strength's default hop length.
    hop = 512
    beat_times = librosa.frames_to_time(beat_frames, sr=sr)
    intervals = np.diff(beat_times)
    if len(intervals) > 0:
        # Median inter-beat interval converted to envelope frames.
        median_interval = float(np.median(intervals))
        beat_period = int(round(median_interval * sr / hop))
        if beat_period > 0:
            ac = librosa.autocorrelate(onset_env, max_size=len(onset_env))
            for label, mult in [("3/4", 3), ("4/4", 4), ("6/8", 6)]:
                period = beat_period * mult
                if period < len(ac):
                    # Average a +-2-frame window around the target lag
                    # to tolerate beat-tracking jitter.
                    lo = max(0, period - 2)
                    hi = min(len(ac), period + 3)
                    ac_score = float(np.mean(ac[lo:hi]))
                    # Normalize by the zero-lag energy.
                    if ac[0] > 0:
                        ac_score /= float(ac[0])
                    scores[label] = scores.get(label, 0.0) + ac_score

    # Signal 3: Beat-strength variance ratio -- a real meter shows
    # systematic strong/weak alternation within each bar, i.e. higher
    # per-bar variance when the grouping is correct.
    for label, grouping in [("3/4", 3), ("4/4", 4)]:
        usable = len(beat_strengths) - (len(beat_strengths) % grouping)
        if usable >= grouping * 2:
            grouped = beat_strengths[:usable].reshape(-1, grouping)
            row_vars = np.var(grouped, axis=1)
            scores[label] = scores.get(label, 0.0) + float(np.mean(row_vars))

    return scores
|
| 1527 |
+
|
| 1528 |
+
|
| 1529 |
+
# ---- Unified time-signature detection -----------------------------------
|
| 1530 |
+
|
| 1531 |
+
def _detect_time_sig(y, sr, mode: str = "faf") -> Tuple[Optional[str], str]:
|
| 1532 |
+
"""Estimate time signature with quality controlled by mode.
|
| 1533 |
+
|
| 1534 |
+
faf: Hardcoded "4/4" (correct ~80%+ of the time).
|
| 1535 |
+
mid: Beat-sync accent + AC + variance + 4/4 prior.
|
| 1536 |
+
sas: mid signals + PLP periodicity + multi-band onset +
|
| 1537 |
+
tempogram harmonic ratios + chunked voting.
|
| 1538 |
+
|
| 1539 |
+
Returns (signature, confidence).
|
| 1540 |
+
"""
|
| 1541 |
+
if mode == "faf":
|
| 1542 |
+
return "4/4", "low"
|
| 1543 |
+
|
| 1544 |
+
import librosa
|
| 1545 |
+
import numpy as np
|
| 1546 |
+
|
| 1547 |
+
try:
|
| 1548 |
+
# mid: core 3-signal scoring
|
| 1549 |
+
scores = _timesig_core_scores(y, sr)
|
| 1550 |
|
| 1551 |
+
# sas: additional techniques
|
| 1552 |
+
if mode == "sas":
|
| 1553 |
+
onset_env = librosa.onset.onset_strength(y=y, sr=sr)
|
| 1554 |
+
|
| 1555 |
+
# PLP periodicity
|
| 1556 |
+
try:
|
| 1557 |
+
pulse = librosa.beat.plp(onset_envelope=onset_env, sr=sr)
|
| 1558 |
+
plp_ac = librosa.autocorrelate(pulse, max_size=len(pulse))
|
| 1559 |
+
tempo_est, _ = librosa.beat.beat_track(y=y, sr=sr)
|
| 1560 |
+
tempo_val = float(np.atleast_1d(tempo_est)[0])
|
| 1561 |
+
if tempo_val > 0:
|
| 1562 |
+
hop = 512
|
| 1563 |
+
bp = int(round(60.0 / tempo_val * sr / hop))
|
| 1564 |
+
if bp > 0:
|
| 1565 |
+
for label, mult in [("3/4", 3), ("4/4", 4), ("6/8", 6)]:
|
| 1566 |
+
lag = bp * mult
|
| 1567 |
+
if lag < len(plp_ac):
|
| 1568 |
+
lo = max(0, lag - 2)
|
| 1569 |
+
hi = min(len(plp_ac), lag + 3)
|
| 1570 |
+
s = float(np.mean(plp_ac[lo:hi]))
|
| 1571 |
+
if plp_ac[0] > 0:
|
| 1572 |
+
s /= float(plp_ac[0])
|
| 1573 |
+
scores[label] = scores.get(label, 0.0) + s
|
| 1574 |
+
except Exception:
|
| 1575 |
+
pass
|
| 1576 |
+
|
| 1577 |
+
# Multi-band onset analysis (low/mid/high)
|
| 1578 |
+
try:
|
| 1579 |
+
S = np.abs(librosa.stft(y))
|
| 1580 |
+
n_bins = S.shape[0]
|
| 1581 |
+
third = n_bins // 3
|
| 1582 |
+
bands = {
|
| 1583 |
+
"low": S[:third, :],
|
| 1584 |
+
"mid_band": S[third:2*third, :],
|
| 1585 |
+
"high": S[2*third:, :],
|
| 1586 |
+
}
|
| 1587 |
+
for band_name, band_S in bands.items():
|
| 1588 |
+
band_onset = librosa.onset.onset_strength(S=band_S, sr=sr)
|
| 1589 |
+
band_ac = librosa.autocorrelate(
|
| 1590 |
+
band_onset, max_size=len(band_onset),
|
| 1591 |
+
)
|
| 1592 |
+
tempo_val2 = float(np.atleast_1d(tempo_est)[0])
|
| 1593 |
+
if tempo_val2 > 0:
|
| 1594 |
+
hop = 512
|
| 1595 |
+
bp2 = int(round(60.0 / tempo_val2 * sr / hop))
|
| 1596 |
+
if bp2 > 0 and band_ac[0] > 0:
|
| 1597 |
+
for label, mult in [("3/4", 3), ("4/4", 4)]:
|
| 1598 |
+
lag = bp2 * mult
|
| 1599 |
+
if lag < len(band_ac):
|
| 1600 |
+
lo = max(0, lag - 2)
|
| 1601 |
+
hi = min(len(band_ac), lag + 3)
|
| 1602 |
+
s = float(np.mean(band_ac[lo:hi]))
|
| 1603 |
+
s /= float(band_ac[0])
|
| 1604 |
+
w = 1.5 if band_name == "low" else 1.0
|
| 1605 |
+
scores[label] = scores.get(label, 0.0) + s * w
|
| 1606 |
+
except Exception:
|
| 1607 |
+
pass
|
| 1608 |
+
|
| 1609 |
+
# Tempogram harmonic ratios
|
| 1610 |
+
try:
|
| 1611 |
+
tempogram = librosa.feature.tempogram(
|
| 1612 |
+
onset_envelope=onset_env, sr=sr,
|
| 1613 |
+
)
|
| 1614 |
+
avg_tg = np.mean(tempogram, axis=1)
|
| 1615 |
+
bpm_axis = librosa.tempo_frequencies(tempogram.shape[0], sr=sr)
|
| 1616 |
+
if tempo_val > 0:
|
| 1617 |
+
for mult_label, t_mult in [("duple", 2.0), ("triple", 3.0)]:
|
| 1618 |
+
target_bpm = tempo_val * t_mult
|
| 1619 |
+
if target_bpm < 300:
|
| 1620 |
+
idx = np.argmin(np.abs(bpm_axis - target_bpm))
|
| 1621 |
+
energy = float(avg_tg[idx])
|
| 1622 |
+
base_idx = np.argmin(np.abs(bpm_axis - tempo_val))
|
| 1623 |
+
base_energy = float(avg_tg[base_idx]) + 1e-10
|
| 1624 |
+
ratio = energy / base_energy
|
| 1625 |
+
if t_mult == 2.0:
|
| 1626 |
+
scores["4/4"] = scores.get("4/4", 0.0) + ratio
|
| 1627 |
+
else:
|
| 1628 |
+
scores["3/4"] = scores.get("3/4", 0.0) + ratio
|
| 1629 |
+
except Exception:
|
| 1630 |
+
pass
|
| 1631 |
+
|
| 1632 |
+
# Chunked voting
|
| 1633 |
+
chunks = _select_chunks(y, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=True)
|
| 1634 |
+
chunk_votes = []
|
| 1635 |
+
for chunk in chunks:
|
| 1636 |
+
cs = _timesig_core_scores(chunk, sr)
|
| 1637 |
+
if cs:
|
| 1638 |
+
cs["4/4"] = cs.get("4/4", 0.0) * 1.15
|
| 1639 |
+
best_c = max(cs, key=cs.get)
|
| 1640 |
+
chunk_votes.append(best_c)
|
| 1641 |
+
for vote in chunk_votes:
|
| 1642 |
+
scores[vote] = scores.get(vote, 0.0) + 1.0
|
| 1643 |
+
|
| 1644 |
+
# Bayesian prior: bias toward 4/4
|
| 1645 |
+
scores["4/4"] = scores.get("4/4", 0.0) * 1.15
|
| 1646 |
+
|
| 1647 |
+
if not scores:
|
| 1648 |
+
return "4/4", "low"
|
| 1649 |
+
|
| 1650 |
+
best = max(scores, key=scores.get)
|
| 1651 |
+
|
| 1652 |
+
# Confidence: margin between top 2
|
| 1653 |
+
sorted_scores = sorted(scores.values(), reverse=True)
|
| 1654 |
+
if len(sorted_scores) >= 2 and sorted_scores[1] > 0:
|
| 1655 |
+
margin = sorted_scores[0] / sorted_scores[1]
|
| 1656 |
+
else:
|
| 1657 |
+
margin = 1.0
|
| 1658 |
+
|
| 1659 |
+
if margin > 1.4:
|
| 1660 |
+
confidence = "high"
|
| 1661 |
+
elif margin > 1.15:
|
| 1662 |
+
confidence = "medium"
|
| 1663 |
+
else:
|
| 1664 |
+
confidence = "low"
|
| 1665 |
+
|
| 1666 |
+
logger.info(
|
| 1667 |
+
"TimeSig [%s]: %s (scores=%s, margin=%.2f, conf=%s)",
|
| 1668 |
+
mode, best,
|
| 1669 |
+
{k: round(v, 3) for k, v in scores.items()},
|
| 1670 |
+
margin, confidence,
|
| 1671 |
+
)
|
| 1672 |
+
return best, confidence
|
| 1673 |
+
|
| 1674 |
+
except Exception as exc:
|
| 1675 |
+
logger.warning("Time signature detection failed: %s", exc)
|
| 1676 |
+
return "4/4", "low"
|
| 1677 |
|
| 1678 |
|
| 1679 |
def _sanitize_tag(value: str) -> str:
|
|
|
|
| 1743 |
return title or audio_path.stem, artist or ""
|
| 1744 |
|
| 1745 |
|
| 1746 |
+
def analyze_and_caption(
|
| 1747 |
+
audio_path: str,
|
| 1748 |
+
mode: str = "faf",
|
| 1749 |
+
device: str = "cpu",
|
| 1750 |
+
) -> Dict[str, Any]:
|
| 1751 |
"""Analyze an audio file and build a training caption.
|
| 1752 |
|
| 1753 |
+
Supports three quality modes:
|
| 1754 |
+
faf - CPU, ~2-3s/file. Single-method detection on raw mix.
|
| 1755 |
+
mid - ~12s/file. Demucs stems + 3-method ensemble.
|
| 1756 |
+
sas - ~30s/file. Deep multi-technique + chunked analysis.
|
| 1757 |
+
|
| 1758 |
+
For mid/sas, Demucs separates drums and harmonics stems first.
|
| 1759 |
+
On CPU, Demucs adds ~2-5 minutes per file.
|
| 1760 |
|
| 1761 |
Args:
|
| 1762 |
audio_path: Path to the audio file.
|
| 1763 |
+
mode: Analysis mode ("faf", "mid", or "sas").
|
| 1764 |
+
device: Torch device for Demucs ("cpu").
|
| 1765 |
|
| 1766 |
Returns:
|
| 1767 |
+
Dict with keys: caption, bpm, key, signature, lyrics, title, artist,
|
| 1768 |
+
confidence (dict of per-field confidence levels).
|
| 1769 |
"""
|
| 1770 |
import librosa
|
| 1771 |
import numpy as np
|
| 1772 |
|
| 1773 |
audio_path = Path(audio_path)
|
| 1774 |
|
| 1775 |
+
if mode not in _ANALYSIS_MODES:
|
| 1776 |
+
logger.warning("Unknown analysis mode '%s', falling back to 'faf'", mode)
|
| 1777 |
+
mode = "faf"
|
| 1778 |
+
|
| 1779 |
# Load audio once, reuse for all detectors
|
| 1780 |
try:
|
| 1781 |
y, sr = librosa.load(str(audio_path), sr=None, mono=True)
|
|
|
|
| 1793 |
"caption": f"A track by {artist}" if artist else f"A track titled {title}",
|
| 1794 |
"bpm": None, "key": None, "signature": "4/4",
|
| 1795 |
"lyrics": "[Instrumental]", "title": title, "artist": artist,
|
| 1796 |
+
"confidence": {},
|
| 1797 |
}
|
| 1798 |
|
| 1799 |
+
confidence = {}
|
| 1800 |
+
tmp_dir = None
|
| 1801 |
+
|
| 1802 |
+
try:
|
| 1803 |
+
if mode in ("mid", "sas"):
|
| 1804 |
+
# Demucs stem separation -- run BPM/timesig on drums,
|
| 1805 |
+
# key detection on harmonics
|
| 1806 |
+
tmp_dir = Path(tempfile.mkdtemp(prefix="ace_analysis_"))
|
| 1807 |
+
try:
|
| 1808 |
+
drums_path, harmonics_path = separate_stems(
|
| 1809 |
+
audio_path, tmp_dir, device=device,
|
| 1810 |
+
)
|
| 1811 |
+
# Load separated stems for analysis
|
| 1812 |
+
y_drums, sr_drums = librosa.load(
|
| 1813 |
+
str(drums_path), sr=None, mono=True,
|
| 1814 |
+
)
|
| 1815 |
+
y_harmonics, sr_harmonics = librosa.load(
|
| 1816 |
+
str(harmonics_path), sr=None, mono=True,
|
| 1817 |
+
)
|
| 1818 |
+
# Preprocess stems
|
| 1819 |
+
y_drums_trimmed, _ = librosa.effects.trim(y_drums, top_db=30)
|
| 1820 |
+
if len(y_drums_trimmed) >= sr_drums:
|
| 1821 |
+
y_drums = y_drums_trimmed
|
| 1822 |
+
peak_d = np.max(np.abs(y_drums))
|
| 1823 |
+
if peak_d > 0:
|
| 1824 |
+
y_drums = y_drums / peak_d
|
| 1825 |
+
|
| 1826 |
+
y_harm_trimmed, _ = librosa.effects.trim(y_harmonics, top_db=30)
|
| 1827 |
+
if len(y_harm_trimmed) >= sr_harmonics:
|
| 1828 |
+
y_harmonics = y_harm_trimmed
|
| 1829 |
+
peak_h = np.max(np.abs(y_harmonics))
|
| 1830 |
+
if peak_h > 0:
|
| 1831 |
+
y_harmonics = y_harmonics / peak_h
|
| 1832 |
+
|
| 1833 |
+
# BPM + time sig on drums stem
|
| 1834 |
+
bpm, bpm_conf = _detect_bpm(y_drums, sr_drums, mode)
|
| 1835 |
+
signature, sig_conf = _detect_time_sig(y_drums, sr_drums, mode)
|
| 1836 |
+
# Key on harmonics stem
|
| 1837 |
+
key, key_conf = _detect_key(y_harmonics, sr_harmonics, mode)
|
| 1838 |
+
|
| 1839 |
+
confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf}
|
| 1840 |
+
|
| 1841 |
+
except Exception as exc:
|
| 1842 |
+
logger.warning(
|
| 1843 |
+
"Demucs separation failed for %s: %s -- "
|
| 1844 |
+
"falling back to analysis on raw mix",
|
| 1845 |
+
audio_path.name, exc,
|
| 1846 |
+
)
|
| 1847 |
+
# Fallback: run detectors on raw mix
|
| 1848 |
+
bpm, bpm_conf = _detect_bpm(y, sr, mode)
|
| 1849 |
+
key, key_conf = _detect_key(y, sr, mode)
|
| 1850 |
+
signature, sig_conf = _detect_time_sig(y, sr, mode)
|
| 1851 |
+
confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf}
|
| 1852 |
+
else:
|
| 1853 |
+
# faf: all detectors on raw mix
|
| 1854 |
+
bpm, bpm_conf = _detect_bpm(y, sr, mode)
|
| 1855 |
+
key, key_conf = _detect_key(y, sr, mode)
|
| 1856 |
+
signature, sig_conf = _detect_time_sig(y, sr, mode)
|
| 1857 |
+
confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf}
|
| 1858 |
+
|
| 1859 |
+
finally:
|
| 1860 |
+
if tmp_dir is not None:
|
| 1861 |
+
try:
|
| 1862 |
+
shutil.rmtree(tmp_dir)
|
| 1863 |
+
except OSError as exc:
|
| 1864 |
+
logger.debug("Could not clean temp dir %s: %s", tmp_dir, exc)
|
| 1865 |
+
|
| 1866 |
title, artist = _extract_metadata_from_tags(audio_path)
|
| 1867 |
|
| 1868 |
# Build caption string for ACE-Step training
|
|
|
|
| 1888 |
"lyrics": lyrics,
|
| 1889 |
"title": title,
|
| 1890 |
"artist": artist,
|
| 1891 |
+
"confidence": confidence,
|
| 1892 |
}
|
| 1893 |
|
| 1894 |
+
logger.info("Auto-caption [%s] for %s: %s", mode, audio_path.name, caption)
|
| 1895 |
return result
|
| 1896 |
|
| 1897 |
|
|
|
|
| 1996 |
if sidecar and sidecar.get("caption"):
|
| 1997 |
caption = sidecar["caption"]
|
| 1998 |
lyrics = sidecar.get("lyrics", "[Instrumental]")
|
| 1999 |
+
logger.info("[Caption] %s: using existing sidecar", af.name)
|
| 2000 |
else:
|
| 2001 |
+
# Auto-select analysis mode based on dataset size
|
| 2002 |
+
if total <= 20:
|
| 2003 |
+
analysis_mode = "sas"
|
| 2004 |
+
elif total <= 100:
|
| 2005 |
+
analysis_mode = "mid"
|
| 2006 |
+
else:
|
| 2007 |
+
analysis_mode = "faf"
|
| 2008 |
+
|
| 2009 |
+
# Log mode selection with reasoning (first file only)
|
| 2010 |
+
if i == 0:
|
| 2011 |
+
_MODE_DESC = {
|
| 2012 |
+
"faf": "fast, ~3s/file",
|
| 2013 |
+
"mid": "balanced, ~12s/file",
|
| 2014 |
+
"sas": "best quality, ~30s/file on GPU, slower on CPU",
|
| 2015 |
+
}
|
| 2016 |
+
logger.info(
|
| 2017 |
+
"[Analysis] Mode auto-selected: '%s' (%s) "
|
| 2018 |
+
"for %d files (<=20: sas, 21-100: mid, 100+: faf)",
|
| 2019 |
+
analysis_mode, _MODE_DESC[analysis_mode], total,
|
| 2020 |
+
)
|
| 2021 |
+
if analysis_mode in ("mid", "sas") and device == "cpu":
|
| 2022 |
+
logger.warning(
|
| 2023 |
+
"[Analysis] Mode '%s' uses Demucs stem separation "
|
| 2024 |
+
"which is SLOW on CPU (~2-5 min/file). "
|
| 2025 |
+
"Total estimated time: ~%d-%d min for %d files. "
|
| 2026 |
+
"Use 'faf' mode or a GPU machine for faster processing.",
|
| 2027 |
+
analysis_mode,
|
| 2028 |
+
total * 2, total * 5, total,
|
| 2029 |
+
)
|
| 2030 |
+
|
| 2031 |
try:
|
| 2032 |
+
logger.info("[Caption] %s: analyzing (mode=%s)...", af.name, analysis_mode)
|
| 2033 |
+
analysis = analyze_and_caption(
|
| 2034 |
+
str(af), mode=analysis_mode, device=device,
|
| 2035 |
+
)
|
| 2036 |
caption = analysis["caption"]
|
| 2037 |
lyrics = analysis.get("lyrics", "[Instrumental]")
|
| 2038 |
_write_caption_sidecar(af, analysis)
|
| 2039 |
+
logger.info("[Caption] %s: %s", af.name, caption)
|
| 2040 |
except Exception as exc:
|
| 2041 |
+
logger.warning("[Caption] %s: analysis failed (%s), using filename", af.name, exc)
|
| 2042 |
caption = af.stem
|
| 2043 |
lyrics = "[Instrumental]"
|
| 2044 |
text_prompt = caption
|