Commit: Upload modeling_esm_plusplus.py with huggingface_hub

File changed: modeling_esm_plusplus.py (+22 −6)
|
@@ -1116,14 +1116,26 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 1116 |
self.register_buffer("scale", scale)
|
| 1117 |
|
| 1118 |
def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
|
| 1119 |
-
"""Compute inverse frequency bands.
|
| 1120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1121 |
self.base
|
| 1122 |
** (
|
| 1123 |
-
torch.arange(0, self.dim, 2, device=
|
| 1124 |
/ self.dim
|
| 1125 |
)
|
| 1126 |
)
|
|
|
|
|
|
|
|
|
|
| 1127 |
|
| 1128 |
def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
| 1129 |
"""Update the cached cosine and sine values."""
|
|
@@ -1174,9 +1186,13 @@ class RotaryEmbedding(torch.nn.Module):
|
|
| 1174 |
Returns:
|
| 1175 |
Tuple of rotated query and key tensors
|
| 1176 |
"""
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1180 |
self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
|
| 1181 |
assert self._cos_cached is not None
|
| 1182 |
assert self._sin_cached is not None
|
|
|
|
| 1116 |
self.register_buffer("scale", scale)
|
| 1117 |
|
| 1118 |
def _compute_inv_freq(self, device: Optional[torch.device] = None) -> torch.Tensor:
|
| 1119 |
+
"""Compute inverse frequency bands.
|
| 1120 |
+
|
| 1121 |
+
Always computes on CPU then moves to the requested device. This matches
|
| 1122 |
+
native EvolutionaryScale ESMC, which computes inv_freq on CPU at
|
| 1123 |
+
`__init__` and migrates via `.to(device)`. Computing directly on GPU
|
| 1124 |
+
gives a ~3.7e-9 bit-level difference in inv_freq (fp32 transcendental
|
| 1125 |
+
precision differs between CPU and GPU), which compounds through the 30
|
| 1126 |
+
attention layers to ~1e-3 mse divergence from native at
|
| 1127 |
+
`hidden_states[-2]`. See testing/parity_debug_rotary.py.
|
| 1128 |
+
"""
|
| 1129 |
+
cpu_inv_freq = 1 / (
|
| 1130 |
self.base
|
| 1131 |
** (
|
| 1132 |
+
torch.arange(0, self.dim, 2, device="cpu", dtype=torch.float32)
|
| 1133 |
/ self.dim
|
| 1134 |
)
|
| 1135 |
)
|
| 1136 |
+
if device is not None and torch.device(device).type != "cpu":
|
| 1137 |
+
return cpu_inv_freq.to(device)
|
| 1138 |
+
return cpu_inv_freq
|
| 1139 |
|
| 1140 |
def _update_cos_sin_cache(self, seqlen: int, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
| 1141 |
"""Update the cached cosine and sine values."""
|
|
|
|
| 1186 |
Returns:
|
| 1187 |
Tuple of rotated query and key tensors
|
| 1188 |
"""
|
| 1189 |
+
# NOTE: do NOT recompute inv_freq here if device has changed. The native
|
| 1190 |
+
# ESMC implementation computes inv_freq once on CPU at __init__ and
|
| 1191 |
+
# relies on PyTorch's `.to(device)` to migrate the buffer. Recomputing
|
| 1192 |
+
# the values directly on GPU gives a ~3.7e-9 bit-level difference vs the
|
| 1193 |
+
# CPU-computed-then-moved values due to fp32 transcendental precision,
|
| 1194 |
+
# which compounds through 30 attention layers to ~1e-3 mse divergence
|
| 1195 |
+
# from native at `hidden_states[-2]`. See testing/parity_debug_rotary.py.
|
| 1196 |
self._update_cos_sin_cache(q.shape[1], device=q.device, dtype=q.dtype)
|
| 1197 |
assert self._cos_cached is not None
|
| 1198 |
assert self._sin_cached is not None
|