lhallee committed on
Commit
95c09ea
·
verified ·
1 Parent(s): e2ff200

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +22 -6
modeling_esm_plusplus.py CHANGED
@@ -1116,14 +1116,26 @@ class RotaryEmbedding(torch.nn.Module):
1116
  self.register_buffer("scale", scale)
1117
 
1118
  def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
1119
- """Compute inverse frequency bands."""
1120
- return 1 / (
 
 
 
 
 
 
 
 
 
1121
  self.base
1122
  ** (
1123
- torch.arange(0, self.dim, 2, device=device, dtype=torch.float32)
1124
  / self.dim
1125
  )
1126
  )
 
 
 
1127
 
1128
  def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
1129
  """Update the cached cosine and sine values."""
@@ -1174,9 +1186,13 @@ class RotaryEmbedding(torch.nn.Module):
1174
  Returns:
1175
  Tuple of rotated query and key tensors
1176
  """
1177
- assert self._inv_freq_compute_device is not None, "Rotary inv_freq compute device should be set after initialization."
1178
- if self._inv_freq_compute_device != q.device:
1179
- self.reset_parameters()
 
 
 
 
1180
  self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
1181
  assert self._cos_cached is not None
1182
  assert self._sin_cached is not None
 
1116
  self.register_buffer("scale", scale)
1117
 
1118
  def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
1119
+ """Compute inverse frequency bands.
1120
+
1121
+ Always computes on CPU then moves to the requested device. This matches
1122
+ native EvolutionaryScale ESMC, which computes inv_freq on CPU at
1123
+ `__init__` and migrates via `.to(device)`. Computing directly on GPU
1124
+ gives a ~3.7e-9 bit-level difference in inv_freq (fp32 transcendental
1125
+ precision differs between CPU and GPU), which compounds through the 30
1126
+ attention layers to ~1e-3 mse divergence from native at
1127
+ `hidden_states[-2]`. See testing/parity_debug_rotary.py.
1128
+ """
1129
+ cpu_inv_freq = 1 / (
1130
  self.base
1131
  ** (
1132
+ torch.arange(0, self.dim, 2, device="cpu", dtype=torch.float32)
1133
  / self.dim
1134
  )
1135
  )
1136
+ if device is not None and torch.device(device).type != "cpu":
1137
+ return cpu_inv_freq.to(device)
1138
+ return cpu_inv_freq
1139
 
1140
  def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
1141
  """Update the cached cosine and sine values."""
 
1186
  Returns:
1187
  Tuple of rotated query and key tensors
1188
  """
1189
+ # NOTE: do NOT recompute inv_freq here if device has changed. The native
1190
+ # ESMC implementation computes inv_freq once on CPU at __init__ and
1191
+ # relies on PyTorch's `.to(device)` to migrate the buffer. Recomputing
1192
+ # the values directly on GPU gives a ~3.7e-9 bit-level difference vs the
1193
+ # CPU-computed-then-moved values due to fp32 transcendental precision,
1194
+ # which compounds through 30 attention layers to ~1e-3 mse divergence
1195
+ # from native at `hidden_states[-2]`. See testing/parity_debug_rotary.py.
1196
  self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
1197
  assert self._cos_cached is not None
1198
  assert self._sin_cached is not None