| from datasets import load_dataset, Audio |
|
|
# Passed as `num_proc` to the `filter` and `map` calls further down.
N_PROC = None

# Load the parsed dataset and drop five columns in one chained expression;
# the remaining steps in this script read only "audio", "metadata" and "oni".
ds = load_dataset("JacobLinCool/taiko-1000-parsed").remove_columns(
    ["tja", "hard", "normal", "easy", "ura"]
)
|
|
|
|
def filter_out_broken(example):
    """Return True if the example's audio array can be read, else False.

    Used as a ``filter`` predicate: the function simply tries to access
    ``example["audio"]["array"]`` and reports whether that access succeeded.

    Args:
        example: A mapping with an ``"audio"`` entry that exposes ``"array"``.

    Returns:
        ``True`` when the access succeeds, ``False`` when it raises an
        ordinary exception (missing key, failure while reading the audio, ...).
    """
    try:
        example["audio"]["array"]
    except Exception:
        # Only ordinary errors mark the example as broken. KeyboardInterrupt
        # and SystemExit are not subclasses of Exception, so they propagate
        # and a long filtering pass can still be interrupted.
        return False
    return True
|
|
|
|
# Keep only the examples accepted by `filter_out_broken`, then cast the
# "audio" column to an Audio feature with a 16000 Hz sampling rate.
ds = ds.filter(
    filter_out_broken,
    batch_size=32,
    writer_batch_size=32,
    num_proc=N_PROC,
).cast_column("audio", Audio(sampling_rate=16000))
|
|
|
|
def build_beat_and_downbeat_labels(example):
    """Extract beat and downbeat times from the "oni" chart segments.

    - Downbeats: the timestamp of every segment (first beat of each measure).
    - Beats: ``measure_num`` evenly spaced beats per segment, starting at the
      segment timestamp and spaced ``(60 / bpm) * (4 / measure_den)`` apart.

    The BPM of a segment is taken from its first note. A segment without
    notes borrows the BPM of the first note of the nearest later segment that
    has notes. If that still yields no usable value (missing or <= 0), the
    most recent usable BPM is reused, and only when none has been seen yet
    does the function fall back to 120 BPM.

    Args:
        example: A mapping with ``example["metadata"]["TITLE"]`` and
            ``example["oni"]["segments"]``; every segment provides
            ``timestamp``, ``measure_num``, ``measure_den`` and ``notes``
            (a list of mappings with a ``bpm`` entry).

    Returns:
        A dict with ``title``, and the sorted, de-duplicated ``beats`` and
        ``downbeats`` lists of times in seconds.
    """
    fallback_bpm = 120.0

    title = example["metadata"]["TITLE"]
    segments = example["oni"]["segments"]

    # One reverse pass gives, for each segment, the first-note BPM of the
    # nearest segment at or after it that has notes. This replaces a forward
    # scan that restarted for every empty segment.
    upcoming_bpm = [None] * len(segments)
    carried = None
    for idx in range(len(segments) - 1, -1, -1):
        seg_notes = segments[idx]["notes"]
        if seg_notes:
            carried = seg_notes[0]["bpm"]
        upcoming_bpm[idx] = carried

    beats = []
    downbeats = []
    last_valid_bpm = None

    for segment, bpm in zip(segments, upcoming_bpm):
        start = segment["timestamp"]
        downbeats.append(start)

        if bpm is None or bpm <= 0:
            # No usable tempo for this measure: keep the last known tempo
            # rather than jumping to an unrelated default mid-song.
            bpm = fallback_bpm if last_valid_bpm is None else last_valid_bpm
        else:
            last_valid_bpm = bpm

        # 60 / bpm is the length of a quarter note; scale it to the note
        # value named by the time-signature denominator.
        beat_duration = (60.0 / bpm) * (4.0 / segment["measure_den"])

        beats.extend(
            start + beat_idx * beat_duration
            for beat_idx in range(segment["measure_num"])
        )

    return {
        "title": title,
        "beats": sorted(set(beats)),
        "downbeats": sorted(set(downbeats)),
    }
|
|
|
|
# Compute the beat/downbeat labels for every example, drop the "oni" and
# "metadata" columns in the same call, and switch the output format to "torch".
ds = ds.map(
    build_beat_and_downbeat_labels,
    remove_columns=["oni", "metadata"],
    batch_size=32,
    writer_batch_size=32,
    num_proc=N_PROC,
).with_format("torch")
|
|
if __name__ == "__main__":
    # When run as a script, print the processed dataset followed by the
    # feature schema of its "train" split.
    for summary in (ds, ds["train"].features):
        print(summary)
|
|