vgtomahawk committed
Commit dececb8 · verified · 1 Parent(s): 78d2009

Upload train_sft_qwen.py with huggingface_hub

Files changed (1):
  train_sft_qwen.py  +71 -62
train_sft_qwen.py CHANGED
@@ -1,101 +1,110 @@
 # /// script
 # dependencies = [
-#     "trl>=0.12.0",
-#     "peft>=0.7.0",
-#     "trackio",
-#     "transformers>=4.40.0",
-#     "datasets>=2.18.0",
-#     "torch>=2.0.0",
 # ]
 # ///

 """
-SFT (Supervised Fine-Tuning) training script for Qwen/Qwen2.5-0.5B
-Uses TRL with LoRA, Trackio monitoring, and automatic Hub push
 """

 from datasets import load_dataset
 from peft import LoraConfig
 from trl import SFTTrainer, SFTConfig
-import trackio

-# Load a high-quality instruction dataset
 dataset = load_dataset("trl-lib/Capybara", split="train")

-# Create train/eval split for monitoring training progress
 dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
-
-# Configure LoRA for efficient fine-tuning
-peft_config = LoraConfig(
-    r=16,  # LoRA rank
-    lora_alpha=32,  # LoRA alpha scaling
-    lora_dropout=0.05,  # Dropout for regularization
-    bias="none",  # Don't train bias parameters
-    task_type="CAUSAL_LM",  # Causal language modeling
-    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Qwen attention layers
-)
-
-# Configure trainer
-training_config = SFTConfig(
-    # Model and output
-    output_dir="qwen-sft-capybara",
-
-    # Hub configuration - CRITICAL for saving results
     push_to_hub=True,
-    hub_model_id="qwen-sft-capybara-demo",  # Will use format: username/qwen-sft-capybara-demo
-    hub_strategy="every_save",  # Push checkpoints during training
-    hub_private_repo=False,

     # Training parameters
     num_train_epochs=3,
-    per_device_train_batch_size=2,
-    gradient_accumulation_steps=4,  # Effective batch size: 2 * 4 = 8

-    # Optimization
-    learning_rate=2e-4,
-    lr_scheduler_type="cosine",
-    warmup_ratio=0.1,

     # Evaluation
     eval_strategy="steps",
-    eval_steps=50,
-    per_device_eval_batch_size=2,

-    # Checkpointing
-    save_strategy="steps",
-    save_steps=100,
-    save_total_limit=3,  # Keep only last 3 checkpoints

-    # Logging - Trackio integration
-    logging_steps=10,
     report_to="trackio",
-    run_name="qwen-0.5b-sft-demo",
-
-    # Performance optimization
-    bf16=True,  # Use bfloat16 for better performance on modern GPUs
-    gradient_checkpointing=True,  # Reduce memory usage

-    # Misc
-    seed=42,
-    dataloader_num_workers=4,
 )

 # Initialize trainer
 trainer = SFTTrainer(
     model="Qwen/Qwen2.5-0.5B",
-    train_dataset=dataset_split["train"],
-    eval_dataset=dataset_split["test"],
     peft_config=peft_config,
-    args=training_config,
 )

-# Train the model
-print("Starting training...")
 trainer.train()

-# Final push to Hub
-print("Training complete! Pushing final model to Hub...")
 trainer.push_to_hub()

-print("✅ Training complete and model saved to Hub!")
-print(f"Model available at: https://huggingface.co/{trainer.hub_model_id}")
 
+#!/usr/bin/env python3
 # /// script
 # dependencies = [
+#     "trl>=0.12.0",
+#     "peft>=0.7.0",
+#     "transformers>=4.36.0",
+#     "accelerate>=0.24.0",
+#     "trackio",
 # ]
 # ///

 """
+SFT training script for Qwen/Qwen2.5-0.5B model.
+
+This script demonstrates:
+- Trackio integration for real-time monitoring
+- LoRA/PEFT for efficient training
+- Proper Hub saving configuration
+- Train/eval split for monitoring progress
+- Optimized training parameters for small model testing
 """

+import trackio
 from datasets import load_dataset
 from peft import LoraConfig
 from trl import SFTTrainer, SFTConfig

+# Load dataset
+print("📦 Loading dataset...")
 dataset = load_dataset("trl-lib/Capybara", split="train")
+print(f"✅ Dataset loaded: {len(dataset)} examples")

+# Create train/eval split for monitoring
+print("🔀 Creating train/eval split...")
 dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
+train_dataset = dataset_split["train"]
+eval_dataset = dataset_split["test"]
+print(f"   Train: {len(train_dataset)} examples")
+print(f"   Eval: {len(eval_dataset)} examples")
+
+# Training configuration
+print("⚙️ Configuring training parameters...")
+config = SFTConfig(
+    # CRITICAL: Hub settings - save model to Hugging Face Hub
+    output_dir="qwen-0.5b-sft-capybara",
     push_to_hub=True,
+    hub_model_id="vgtomahawk/qwen-0.5b-sft-capybara",
+    hub_strategy="every_save",  # Push checkpoints to Hub

     # Training parameters
     num_train_epochs=3,
+    per_device_train_batch_size=4,
+    gradient_accumulation_steps=4,  # Effective batch size = 4 * 4 = 16
+    learning_rate=2e-5,

+    # Logging & checkpointing
+    logging_steps=10,
+    save_strategy="steps",
+    save_steps=100,
+    save_total_limit=2,  # Keep only last 2 checkpoints

     # Evaluation
     eval_strategy="steps",
+    eval_steps=100,

+    # Optimization
+    warmup_ratio=0.1,
+    lr_scheduler_type="cosine",

+    # Monitoring with Trackio
     report_to="trackio",
+    project="qwen-sft-demo",
+    run_name="qwen-0.5b-baseline",
+)

+# LoRA configuration for efficient training
+print("🔧 Setting up LoRA configuration...")
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=32,
+    lora_dropout=0.05,
+    bias="none",
+    task_type="CAUSAL_LM",
+    target_modules=["q_proj", "v_proj"],
+)

 # Initialize trainer
+print("🎯 Initializing SFT trainer...")
 trainer = SFTTrainer(
     model="Qwen/Qwen2.5-0.5B",
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    args=config,
     peft_config=peft_config,
 )

+# Start training
+print("🚀 Starting training...")
+print("=" * 60)
 trainer.train()

+# Push final model to Hub
+print("=" * 60)
+print("💾 Pushing final model to Hub...")
 trainer.push_to_hub()

+# Complete
+print("✅ Training complete!")
+print(f"📊 Model available at: https://huggingface.co/vgtomahawk/qwen-0.5b-sft-capybara")
+print(f"📈 View training metrics at: https://huggingface.co/spaces/vgtomahawk/trackio")