Yuchan commited on
Commit
a638654
Β·
verified Β·
1 Parent(s): 19949b0

Update AlphaS2S.py

Browse files
Files changed (1) hide show
  1. AlphaS2S.py +46 -70
AlphaS2S.py CHANGED
@@ -81,87 +81,61 @@ def ids_to_text(ids):
81
  return sp.decode(ids)
82
 
83
  # =======================
84
- # 2) 데이터셋 생성 ν•¨μˆ˜ (κΈ°μ‘΄ μ½”λ“œμ™€ 동일)
85
  # =======================
86
-
87
  def jsonl_stream(file_path):
88
  with open(file_path, "r", encoding="utf-8") as f:
89
  for line in f:
90
  data = json.loads(line)
91
- conversations = data.get("conversations", [])
92
- for i in range(0, len(conversations) - 1, 2):
93
- human_msg = conversations[i]
94
- gpt_msg = conversations[i + 1]
95
- if human_msg.get("from") != "human" or gpt_msg.get("from") != "gpt":
96
- continue
97
-
98
- prompt = human_msg.get("value", "").strip()
99
- response = gpt_msg.get("value", "").strip()
100
- full = f"<start> {prompt} <sep> {response} <end>"
101
- if "<sep>" not in full:
102
- continue
103
-
104
- sep_index = full.index("<sep>")
105
-
106
- # 인코더 μž…λ ₯은 <start> ν”„λ‘¬ν”„νŠΈ <sep> λΆ€λΆ„, 디코더 μž…λ ₯은 <sep> 응닡 <end> λΆ€λΆ„
107
- # (Unified Input: 인코더/디코더 μž…λ ₯ λͺ¨λ‘ full_input을 μ‚¬μš©)
108
- input_text = full
109
-
110
- # νƒ€κ²Ÿ μ‹œν€€μŠ€λŠ” 응닡 μ‹œμž‘ λΆ€λΆ„λΆ€ν„° <end>κΉŒμ§€μ΄λ©°, μž…λ ₯보닀 ν•œ μΉΈ μ‹œν”„νŠΈλ¨
111
- # μ—¬κΈ°μ„œ target_textλŠ” 응닡 λΆ€λΆ„λ§Œ μΆ”μΆœν•˜μ—¬ νƒ€κ²Ÿ λ§ˆμŠ€ν‚Ήμ— μ‚¬μš©λ©λ‹ˆλ‹€.
112
- target_text_raw = full[sep_index + len("<sep>"):]
113
-
114
- input_ids = text_to_ids(input_text) # 전체 μ‹œν€€μŠ€
115
- target_ids_raw = text_to_ids(target_text_raw) # 응닡 λΆ€λΆ„λ§Œ
116
-
117
- # 길이 처리 및 λ§ˆμŠ€ν‚Ή λ‘œμ§μ€ κΈ°μ‘΄ μ½”λ“œλ₯Ό κ·ΈλŒ€λ‘œ μœ μ§€
118
- full_input = input_ids[:max_len]
119
- target_ids = target_ids_raw[:max_len - len(input_ids)]
120
-
121
- available_len = max_len - len(input_ids)
122
-
123
- if available_len <= 0:
124
- input_ids = input_ids[-max_len:]
125
- target_ids = []
126
- target_mask = [0] * len(input_ids)
127
- else:
128
- target_ids = target_ids[:available_len]
129
- target_mask = [0] * len(input_ids) + [1] * len(target_ids)
130
-
131
- full_input = input_ids + target_ids
132
- pad_len = max_len - len(full_input)
133
- full_input += [pad_id] * pad_len
134
- target_mask += [0] * pad_len
135
-
136
- # νƒ€κ²Ÿ μ‹œν€€μŠ€λŠ” μž…λ ₯ μ‹œν€€μŠ€λ³΄λ‹€ ν•œ μΉΈ μ‹œν”„νŠΈλœ ν˜•νƒœ
137
- target_seq = full_input[1:] + [end_id]
138
- target_seq = target_seq[:max_len]
139
-
140
- # λ§ˆμŠ€ν‚Ήλœ νƒ€κ²Ÿ 생성 (ν”„λ‘¬ν”„νŠΈ/νŒ¨λ”© 뢀뢄은 pad_id둜 λŒ€μ²΄)
141
- masked_target = [
142
- t if m == 1 else pad_id
143
- for t, m in zip(target_seq, target_mask)
144
- ]
145
-
146
- # AlphaS2SλŠ” 인코더/디코더 μž…λ ₯으둜 같은 μ‹œν€€μŠ€λ₯Ό μ‚¬μš©
147
- # μž…λ ₯ μ‹œν€€μŠ€ = full_input
148
- # νƒ€κ²Ÿ μ‹œν€€μŠ€ = masked_target
149
- yield (
150
- tf.convert_to_tensor(full_input, dtype=tf.int32),
151
- tf.convert_to_tensor(full_input, dtype=tf.int32), # 디코더 μž…λ ₯도 λ™μΌν•˜κ²Œ 전달
152
- tf.convert_to_tensor(masked_target, dtype=tf.int32) # μ‹€μ œ νƒ€κ²Ÿ
153
- )
154
 
 
 
 
155
  dataset = tf.data.Dataset.from_generator(
156
  lambda: jsonl_stream(DATA_PATH),
157
  output_signature=(
158
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # enc_inputs
159
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # dec_inputs
160
- tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # target
161
- ),
162
  )
163
 
164
- # ν•™μŠ΅μ„ μœ„ν•΄ λ”•μ…”λ„ˆλ¦¬ ν˜•νƒœλ‘œ λ§΅ν•‘
165
  def map_fn(enc_input, dec_input, dec_target):
166
  return {"enc_inputs": enc_input, "dec_inputs": dec_input}, dec_target
167
 
@@ -171,6 +145,8 @@ dataset = dataset.shuffle(1000, seed=SEED).batch(batch_size, drop_remainder=True
171
  with strategy.scope():
172
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
173
 
 
 
174
  # =======================
175
  # 3) λͺ¨λΈ λ ˆμ΄μ–΄ (κΈ°μ‘΄ μ½”λ“œ μœ μ§€)
176
  # =======================
 
81
  return sp.decode(ids)
82
 
83
  # =======================
84
+ # JSONL β†’ TF Dataset λ‘œλ“œ (ID 레벨 특수 토큰 포함)
85
  # =======================
 
86
  def jsonl_stream(file_path):
87
  with open(file_path, "r", encoding="utf-8") as f:
88
  for line in f:
89
  data = json.loads(line)
90
+ context = data["context"]
91
+ prompt = data["prompt"]
92
+ answer = data["answer"]
93
+
94
+ # =======================
95
+ # Encoder input: ID λ ˆλ²¨μ—μ„œ 특수 토큰 λͺ…μ‹œ
96
+ # =======================
97
+ enc_ids = [context_s_id] + text_to_ids(context) + [context_e_id] + \
98
+ [user_s_id] + text_to_ids(prompt) + [user_e_id]
99
+ enc_ids = enc_ids[:max_len] # max_len μ œν•œ
100
+
101
+ # =======================
102
+ # Decoder input: <sos> + answer
103
+ # =======================
104
+ dec_input_ids = [start_id] + text_to_ids(answer)
105
+ dec_input_ids = dec_input_ids[:max_len]
106
+
107
+ # =======================
108
+ # Target: answer + <eos>
109
+ # =======================
110
+ target_ids = text_to_ids(answer) + [end_id]
111
+ target_ids = target_ids[:max_len]
112
+
113
+ # =======================
114
+ # Padding
115
+ # =======================
116
+ enc_ids += [pad_id] * (max_len - len(enc_ids))
117
+ dec_input_ids += [pad_id] * (max_len - len(dec_input_ids))
118
+ target_ids += [pad_id] * (max_len - len(target_ids))
119
+
120
+ yield (
121
+ tf.convert_to_tensor(enc_ids, dtype=tf.int32),
122
+ tf.convert_to_tensor(dec_input_ids, dtype=tf.int32),
123
+ tf.convert_to_tensor(target_ids, dtype=tf.int32),
124
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # =======================
127
+ # TF Dataset 생성
128
+ # =======================
129
  dataset = tf.data.Dataset.from_generator(
130
  lambda: jsonl_stream(DATA_PATH),
131
  output_signature=(
132
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # enc_inputs
133
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # dec_inputs
134
+ tf.TensorSpec(shape=(max_len,), dtype=tf.int32), # target
135
+ )
136
  )
137
 
138
+ # ν•™μŠ΅μ„ μœ„ν•΄ λ”•μ…”λ„ˆλ¦¬ ν˜•νƒœλ‘œ λ§€ν•‘
139
  def map_fn(enc_input, dec_input, dec_target):
140
  return {"enc_inputs": enc_input, "dec_inputs": dec_input}, dec_target
141
 
 
145
  with strategy.scope():
146
  dist_dataset = strategy.experimental_distribute_dataset(dataset)
147
 
148
+ print("βœ… ID 레벨 특수 토큰 적용 Dataset λ‘œλ“œ μ™„λ£Œ:", dist_dataset)
149
+
150
  # =======================
151
  # 3) λͺ¨λΈ λ ˆμ΄μ–΄ (κΈ°μ‘΄ μ½”λ“œ μœ μ§€)
152
  # =======================