Nekochu commited on
Commit
4619f39
·
1 Parent(s): d3618ec

add fast captioning module (CLAP + faster-whisper + Silero VAD), update deps

Browse files
Files changed (2) hide show
  1. Dockerfile +3 -1
  2. caption_fast.py +260 -0
Dockerfile CHANGED
@@ -78,7 +78,8 @@ RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/w
78
  "gradio[mcp]>=6.0.0,<7.0.0" requests torch safetensors \
79
  "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
80
  loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
81
- einops vector_quantize_pytorch librosa mutagen demucs-infer
 
82
 
83
  # Clone ACE-Step repo for training module
84
  RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
@@ -92,6 +93,7 @@ RUN python3 -c "from huggingface_hub import snapshot_download; \
92
  # Copy application files
93
  COPY app.py /app/app.py
94
  COPY train_engine.py /app/train_engine.py
 
95
  COPY start.sh /app/start.sh
96
  RUN chmod +x /app/start.sh
97
 
 
78
  "gradio[mcp]>=6.0.0,<7.0.0" requests torch safetensors \
79
  "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
80
  loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
81
+ einops vector_quantize_pytorch librosa mutagen demucs-infer \
82
+ faster-whisper silero-vad
83
 
84
  # Clone ACE-Step repo for training module
85
  RUN git clone --depth 1 https://github.com/ace-step/ACE-Step-1.5 /app/ace-step-source
 
93
  # Copy application files
94
  COPY app.py /app/app.py
95
  COPY train_engine.py /app/train_engine.py
96
+ COPY caption_fast.py /app/caption_fast.py
97
  COPY start.sh /app/start.sh
98
  RUN chmod +x /app/start.sh
99
 
caption_fast.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Fast audio captioning: CLAP tags + Silero VAD + faster-whisper lyrics.
2
+
3
+ Provides mood/genre/instrument tagging via CLAP zero-shot classification,
4
+ speech detection via Silero VAD, and lyrics extraction via faster-whisper.
5
+ All models run on CPU. Total: ~3-5 min per file.
6
+
7
+ Usage:
8
+ from caption_fast import caption_audio
9
+ result = caption_audio("song.mp3")
10
+ # {"caption": "Pop, Energetic, Guitar, Melodic, Upbeat",
11
+ # "lyrics": "[Verse]\nSome lyrics here...",
12
+ # "bpm": 120, "key": "C major", "signature": "4/4",
13
+ # "tags": ["Pop", "Energetic", "Guitar", ...]}
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import json
19
+ import logging
20
+ import os
21
+ from pathlib import Path
22
+ from typing import Dict, List, Optional
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+ # Tag list for CLAP zero-shot classification (from clap-interrogator)
27
+ TAGS = [
28
+ "Fast", "Slow", "Upbeat", "Downbeat", "Moderate",
29
+ "Happy", "Sad", "Energetic", "Relaxed", "Melancholic", "Uplifting",
30
+ "Aggressive", "Peaceful", "Romantic", "Dark", "Light", "Mysterious",
31
+ "Dreamy", "Somber", "Hopeful", "Gloomy", "Cheerful", "Reflective",
32
+ "Nostalgic", "Tense", "Calm",
33
+ "Piano", "Guitar", "Violin", "Drums", "Bass", "Synthesizer",
34
+ "Saxophone", "Trumpet", "Flute", "Cello", "Clarinet", "Harp",
35
+ "Percussion", "Organ", "Accordion", "Electronic", "Acoustic",
36
+ "Electric Guitar", "Acoustic Guitar", "Synth Pad", "Keyboards",
37
+ "Rock", "Pop", "Jazz", "Classical", "Electronic", "Folk", "Hip-Hop",
38
+ "Blues", "Ambient", "Country", "Reggae", "Funk", "Soul", "Metal",
39
+ "Dance", "Disco", "House", "Techno", "Trance", "Soundtrack", "World",
40
+ "Indie", "Alternative", "R&B", "EDM", "Chillwave", "Dubstep",
41
+ "Lo-fi Hip-Hop", "Drum and Bass", "Jazz Fusion", "Neo-Soul", "Trap",
42
+ "K-Pop", "J-Pop", "Reggaeton", "Punk", "Grunge",
43
+ "Bright", "Warm", "Smooth", "Distorted", "Clean", "Lo-fi",
44
+ "Layered", "Minimalist", "Cinematic", "Atmospheric", "Ethereal",
45
+ "Groovy", "Rhythmic", "Melodic", "Harmonic",
46
+ "Live", "Studio", "Instrumental",
47
+ ]
48
+
49
+ _clap_model = None
50
+ _clap_processor = None
51
+ _whisper_model = None
52
+ _vad_model = None
53
+
54
+
55
+ def _load_clap():
56
+ global _clap_model, _clap_processor
57
+ if _clap_model is not None:
58
+ return _clap_model, _clap_processor
59
+ from transformers import ClapModel, ClapProcessor
60
+ logger.info("[CLAP] Loading laion/larger_clap_music...")
61
+ _clap_processor = ClapProcessor.from_pretrained("laion/larger_clap_music")
62
+ _clap_model = ClapModel.from_pretrained("laion/larger_clap_music")
63
+ _clap_model.eval()
64
+ logger.info("[CLAP] Ready (~780MB)")
65
+ return _clap_model, _clap_processor
66
+
67
+
68
+ def _load_whisper():
69
+ global _whisper_model
70
+ if _whisper_model is not None:
71
+ return _whisper_model
72
+ from faster_whisper import WhisperModel
73
+ logger.info("[Whisper] Loading large-v3-turbo (int8, CPU)...")
74
+ _whisper_model = WhisperModel(
75
+ "large-v3-turbo",
76
+ device="cpu",
77
+ compute_type="int8",
78
+ )
79
+ logger.info("[Whisper] Ready (~1.5GB)")
80
+ return _whisper_model
81
+
82
+
83
+ def _load_vad():
84
+ global _vad_model
85
+ if _vad_model is not None:
86
+ return _vad_model
87
+ import torch
88
+ logger.info("[VAD] Loading Silero VAD...")
89
+ _vad_model, _vad_utils = torch.hub.load(
90
+ repo_or_dir='snakers4/silero-vad',
91
+ model='silero_vad',
92
+ onnx=True,
93
+ trust_repo=True,
94
+ )
95
+ logger.info("[VAD] Ready (~2MB)")
96
+ return _vad_model
97
+
98
+
99
+ def unload_caption_models():
100
+ """Free all captioning models from memory."""
101
+ global _clap_model, _clap_processor, _whisper_model, _vad_model
102
+ import gc
103
+ _clap_model = None
104
+ _clap_processor = None
105
+ _whisper_model = None
106
+ _vad_model = None
107
+ gc.collect()
108
+ logger.info("[Caption] All models unloaded")
109
+
110
+
111
+ def tag_audio(audio_path: str, top_n: int = 10) -> List[str]:
112
+ """Get top-N CLAP tags for an audio file."""
113
+ import librosa
114
+ import torch
115
+
116
+ model, processor = _load_clap()
117
+ audio, sr = librosa.load(audio_path, sr=48000, mono=True)
118
+
119
+ inputs = processor(
120
+ text=TAGS,
121
+ audios=[audio],
122
+ sampling_rate=48000,
123
+ return_tensors="pt",
124
+ padding=True,
125
+ )
126
+
127
+ with torch.no_grad():
128
+ outputs = model(**inputs)
129
+
130
+ probs = outputs.logits_per_audio.softmax(dim=-1)
131
+ top_probs, top_indices = probs.topk(top_n, dim=1)
132
+ return [TAGS[i] for i in top_indices[0].tolist()]
133
+
134
+
135
+ def detect_speech(audio_path: str, threshold: float = 5.0) -> bool:
136
+ """Check if audio contains speech using Silero VAD.
137
+ Returns True if speech detected for more than `threshold` seconds.
138
+ """
139
+ import torch
140
+ import torchaudio
141
+
142
+ vad = _load_vad()
143
+ wav, sr = torchaudio.load(audio_path)
144
+ if wav.shape[0] > 1:
145
+ wav = wav.mean(dim=0, keepdim=True)
146
+ if sr != 16000:
147
+ wav = torchaudio.functional.resample(wav, sr, 16000)
148
+
149
+ speech_timestamps = []
150
+ window_size = 512
151
+ for i in range(0, wav.shape[1], window_size):
152
+ chunk = wav[0, i:i + window_size]
153
+ if len(chunk) < window_size:
154
+ break
155
+ prob = vad(chunk, 16000).item()
156
+ if prob > 0.5:
157
+ speech_timestamps.append(i / 16000)
158
+
159
+ speech_duration = len(speech_timestamps) * (window_size / 16000)
160
+ logger.info("[VAD] Speech: %.1fs detected in %s", speech_duration, os.path.basename(audio_path))
161
+ return speech_duration > threshold
162
+
163
+
164
+ def transcribe_lyrics(audio_path: str) -> str:
165
+ """Extract lyrics from audio using faster-whisper."""
166
+ model = _load_whisper()
167
+
168
+ segments, info = model.transcribe(
169
+ audio_path,
170
+ language=None,
171
+ beam_size=5,
172
+ vad_filter=True,
173
+ )
174
+
175
+ lines = []
176
+ for segment in segments:
177
+ text = segment.text.strip()
178
+ if text:
179
+ lines.append(text)
180
+
181
+ lyrics = "\n".join(lines)
182
+ if not lyrics.strip():
183
+ return "[Instrumental]"
184
+
185
+ logger.info("[Whisper] Transcribed %d lines (lang=%s, prob=%.2f)",
186
+ len(lines), info.language, info.language_probability)
187
+ return lyrics
188
+
189
+
190
+ def get_bpm_key(audio_path: str) -> Dict[str, str]:
191
+ """Get BPM and key via librosa."""
192
+ import librosa
193
+ import numpy as np
194
+
195
+ y, sr = librosa.load(audio_path, sr=None, mono=True)
196
+
197
+ tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
198
+ bpm = int(round(float(tempo.item() if hasattr(tempo, 'item') else tempo)))
199
+
200
+ chroma = librosa.feature.chroma_cens(y=y, sr=sr)
201
+ chroma_avg = np.mean(chroma, axis=1)
202
+ keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
203
+ major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
204
+ minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])
205
+
206
+ best_corr = -1
207
+ best_key = "C major"
208
+ for i in range(12):
209
+ maj_corr = float(np.corrcoef(np.roll(major_profile, i), chroma_avg)[0, 1])
210
+ min_corr = float(np.corrcoef(np.roll(minor_profile, i), chroma_avg)[0, 1])
211
+ if maj_corr > best_corr:
212
+ best_corr = maj_corr
213
+ best_key = f"{keys[i]} major"
214
+ if min_corr > best_corr:
215
+ best_corr = min_corr
216
+ best_key = f"{keys[i]} minor"
217
+
218
+ return {"bpm": str(bpm), "key": best_key, "signature": "4/4"}
219
+
220
+
221
+ def caption_audio(
222
+ audio_path: str,
223
+ top_n: int = 10,
224
+ extract_lyrics: bool = True,
225
+ speech_threshold: float = 5.0,
226
+ ) -> Dict[str, str]:
227
+ """Full fast captioning pipeline for one audio file.
228
+
229
+ Returns dict with: caption, lyrics, bpm, key, signature, tags
230
+ """
231
+ fname = os.path.basename(audio_path)
232
+ logger.info("[Caption] Processing %s...", fname)
233
+
234
+ # 1. CLAP tags (mood, genre, instruments)
235
+ tags = tag_audio(audio_path, top_n=top_n)
236
+ caption = ", ".join(tags)
237
+ logger.info("[Caption] %s: tags=%s", fname, caption)
238
+
239
+ # 2. BPM + key via librosa
240
+ bpm_key = get_bpm_key(audio_path)
241
+ logger.info("[Caption] %s: BPM=%s, key=%s", fname, bpm_key["bpm"], bpm_key["key"])
242
+
243
+ # 3. Speech detection + lyrics
244
+ lyrics = "[Instrumental]"
245
+ if extract_lyrics:
246
+ has_speech = detect_speech(audio_path, threshold=speech_threshold)
247
+ if has_speech:
248
+ logger.info("[Caption] %s: speech detected, transcribing lyrics...", fname)
249
+ lyrics = transcribe_lyrics(audio_path)
250
+ else:
251
+ logger.info("[Caption] %s: no speech, marking instrumental", fname)
252
+
253
+ return {
254
+ "caption": caption,
255
+ "lyrics": lyrics,
256
+ "bpm": bpm_key["bpm"],
257
+ "key": bpm_key["key"],
258
+ "signature": bpm_key["signature"],
259
+ "tags": tags,
260
+ }