Instructions to use recursionpharma/OpenPhenom with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use recursionpharma/OpenPhenom with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="recursionpharma/OpenPhenom", trust_remote_code=True)
```

```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("recursionpharma/OpenPhenom", trust_remote_code=True, dtype="auto")
```

- Notebooks
- Google Colab
- Kaggle
| # © Recursion Pharmaceuticals 2024 | |
| from functools import partial | |
| from typing import Tuple, Union | |
| import torch | |
| import torch.nn as nn | |
| from timm.models.helpers import checkpoint_seq | |
| from timm.models.vision_transformer import Block, Mlp, VisionTransformer | |
| from .masking import transformer_random_masking | |
| from .vit import channel_agnostic_vit | |
| # If you are interested in training new MAEs, combine an encoder and a decoder into a new module, | |
| # and use the flattening and unflattening utilities from mae_utils.py as needed. | |
| # Be sure to add a Linear projection layer between the encoder and the decoder so that the encoder's output dimension matches the decoder's dimension. | |
| # As described in the paper, images are self-standardized at the start. | |
class SelfStandardize(nn.Module):
    """Rescale raw pixel values and instance-normalize them.

    The input is cast to float, divided by 255, and passed through a
    non-affine ``nn.LazyInstanceNorm2d`` that keeps no running statistics,
    so the normalization statistics come from the input on every call.
    """

    def __init__(self) -> None:
        super().__init__()
        # Lazy variant: the number of channels is inferred on first use.
        norm = nn.LazyInstanceNorm2d(affine=False, track_running_stats=False)
        self.self_standardize = norm

    def forward(self, pixels: torch.Tensor) -> torch.Tensor:
        """Return the standardized version of ``pixels``.

        Args:
            pixels: image tensor; any dtype that ``Tensor.float()`` accepts.
                # NOTE(review): presumably 0-255 intensities given the /255
                # rescale -- confirm against the data loader.

        Returns:
            Float tensor with the same shape as ``pixels``.
        """
        scaled = pixels.float() / 255.0
        standardized = self.self_standardize(scaled)
        return standardized
class MAEEncoder(nn.Module):
    """Encoder side of a masked autoencoder built around a ViT backbone.

    When ``channel_agnostic`` is true the backbone is first passed through
    ``channel_agnostic_vit`` (with ``max_in_chans``); otherwise it is used
    unchanged.
    """

    def __init__(
        self,
        vit_backbone: VisionTransformer,
        max_in_chans: int = 6,
        channel_agnostic: bool = False,
    ) -> None:
        """Store the (optionally channel-agnostic) backbone and its settings.

        Args:
            vit_backbone: the vision transformer to wrap.
            max_in_chans: forwarded to ``channel_agnostic_vit`` when
                ``channel_agnostic`` is set.
            channel_agnostic: whether to convert the backbone with
                ``channel_agnostic_vit``.
        """
        super().__init__()
        self.vit_backbone = (
            channel_agnostic_vit(vit_backbone, max_in_chans=max_in_chans)
            if channel_agnostic
            else vit_backbone
        )
        self.max_in_chans = max_in_chans
        self.channel_agnostic = channel_agnostic

    def embed_dim(self) -> int:
        """Return the backbone's ``embed_dim`` as an ``int``."""
        return int(self.vit_backbone.embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the backbone's feature extractor followed by its head."""
        features = self.vit_backbone.forward_features(x)
        return self.vit_backbone.forward_head(features)  # type: ignore[no-any-return]

    def forward_masked(
        self,
        x: torch.Tensor,
        mask_ratio: float,
        constant_noise: Union[torch.Tensor, None] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode ``x`` after randomly masking its patch tokens.

        Args:
            x: input passed to the backbone's ``patch_embed``.
            mask_ratio: forwarded to ``transformer_random_masking``.
            constant_noise: optional tensor forwarded to
                ``transformer_random_masking``.

        Returns:
            ``(tokens, mask, ind_restore)``: the normalized output of the
            transformer blocks (class token first), plus the mask and restore
            indices returned by ``transformer_random_masking``.
        """
        backbone = self.vit_backbone
        tokens = backbone._pos_embed(backbone.patch_embed(x))  # adds class token

        # The class token is never masked: split it off, mask the rest.
        cls_token = tokens[:, :1, :]
        patch_tokens = tokens[:, 1:, :]
        kept_tokens, mask, ind_restore = transformer_random_masking(
            patch_tokens, mask_ratio, constant_noise
        )
        tokens = torch.cat([cls_token, kept_tokens], dim=1)

        tokens = backbone.norm_pre(tokens)
        use_checkpointing = (
            backbone.grad_checkpointing and not torch.jit.is_scripting()
        )
        if use_checkpointing:
            tokens = checkpoint_seq(backbone.blocks, tokens)
        else:
            tokens = backbone.blocks(tokens)
        return backbone.norm(tokens), mask, ind_restore
class MAEDecoder(nn.Module):
    """Transformer decoder for a masked autoencoder.

    A stack of ``depth`` timm ``Block`` layers followed by a final norm.
    ``pos_embeddings`` starts as ``None`` and must be assigned from outside
    (the MAE class) before either forward method is called.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        """Build the mask token, the block stack and the output norm.

        Args:
            embed_dim: token width used by every layer.
            depth: number of ``Block`` layers.
            num_heads: attention heads per block.
            mlp_ratio: passed to each ``Block``.
            qkv_bias: passed to each ``Block``.
            norm_layer: factory called with ``embed_dim`` to build norms.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # to be overwritten by MAE class
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        layers = [
            Block(
                embed_dim,
                num_heads,
                mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
            )
            for _ in range(depth)
        ]
        self.blocks = nn.Sequential(*layers)
        self.norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings, then apply the blocks and final norm."""
        return self.norm(self.blocks(x + self.pos_embeddings))  # type: ignore[no-any-return]

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        """Decode a masked sequence after re-inserting mask tokens.

        Args:
            x: encoder output with the class token at position 0 followed by
                the tokens that survived masking.
            ind_restore: per-sample indices used to put the patch tokens back
                in their original order; its second dimension is the full
                number of patch tokens.

        Returns:
            The decoded full-length sequence, class token first.
        """
        batch, visible_len, width = x.shape
        # +1 accounts for the class token, which is present in x but not in
        # ind_restore.
        num_masked = ind_restore.shape[1] + 1 - visible_len
        fill = self.mask_token.repeat(batch, num_masked, 1)

        cls_token = x[:, :1, :]
        patches = torch.cat([x[:, 1:, :], fill], dim=1)  # remove class token
        gather_index = ind_restore.unsqueeze(-1).repeat(1, 1, width)
        patches = torch.gather(patches, dim=1, index=gather_index)  # unshuffle

        full = torch.cat([cls_token, patches], dim=1)  # add class token
        return self.norm(self.blocks(full + self.pos_embeddings))  # type: ignore[no-any-return]
class CrossAttention(nn.Module):
    """Multi-head cross-attention.

    Queries are projected from ``x``; keys and values are projected from
    ``context``. Both inputs must share the same batch size and channel width
    (``embed_dim``), but may have different sequence lengths.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        """Create the projection layers.

        Args:
            embed_dim: channel width of both inputs and of the output.
            num_heads: number of attention heads; must divide ``embed_dim``.
            qkv_bias: whether the query and key/value projections have a bias.
            attn_drop: dropout probability applied to the attention weights.
            proj_drop: dropout probability applied after the output projection.

        Raises:
            ValueError: if ``num_heads`` is not positive or does not divide
                ``embed_dim``.
        """
        super().__init__()
        # Fail fast: otherwise floor division silently yields a wrong head_dim
        # and the error only shows up as an opaque reshape failure in forward().
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        # Keys and values share one projection; they are split in forward().
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """Attend from ``x`` to ``context``.

        Args:
            x: query tokens of shape ``(B, N, C)``.
            context: key/value tokens of shape ``(B, M, C)``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        _, M, _ = context.shape
        head_dim = C // self.num_heads

        # (B, N, C) -> (B, heads, N, head_dim)
        q = self.q(x).reshape(B, N, self.num_heads, head_dim).permute(0, 2, 1, 3)
        # (B, M, 2C) -> (2, B, heads, M, head_dim)
        kv = (
            self.kv(context)
            .reshape(B, M, 2, self.num_heads, head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class CAMAEDecoder(nn.Module):
    """Channel-agnostic MAE decoder with one small decoder per modality.

    The patch tokens are treated as ``num_modalities`` consecutive groups of
    ``tokens_per_modality`` tokens. Each group cross-attends to all patch
    tokens through a shared ``CrossAttention`` + ``Mlp`` and is then passed
    through its own stack of ``depth`` timm ``Block`` layers.
    ``pos_embeddings`` starts as ``None`` and must be assigned from outside
    (the MAE class) before either forward method is called.
    """

    def __init__(
        self,
        num_modalities: int = 6,
        tokens_per_modality: int = 256,
        embed_dim: int = 256,
        depth: int = 2,
        num_heads: int = 16,
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),  # type: ignore[assignment]
    ) -> None:
        """Build the tokens, the shared cross-attention and the decoders.

        Args:
            num_modalities: number of token groups / per-modality decoders.
            tokens_per_modality: number of patch tokens in each group.
            embed_dim: token width used by every layer.
            depth: number of ``Block`` layers in each per-modality decoder.
            num_heads: attention heads of each ``Block``.
            mlp_ratio: hidden-size ratio for the shared ``Mlp`` and the blocks.
            qkv_bias: passed to each ``Block``.
            norm_layer: factory called with ``embed_dim`` to build norms.
        """
        super().__init__()
        self.num_modalities = num_modalities
        self.tokens_per_modality = tokens_per_modality
        self.embed_dim = embed_dim
        self.pos_embeddings = None  # to be overwritten by MAE class
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Frozen zero token that occupies the class-token slot when the
        # modality tokens are tiled over the sequence.
        self.placeholder = nn.Parameter(
            torch.zeros(1, 1, embed_dim), requires_grad=False
        )
        self.modality_tokens = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                for _ in range(self.num_modalities)
            ]
        )

        # NOTE: built with CrossAttention's own defaults for num_heads and
        # qkv_bias, not with this constructor's arguments. Left unchanged so
        # parameter shapes (and existing state dicts) stay the same.
        self.cross_attention = CrossAttention(embed_dim=self.embed_dim)
        self.mlp = Mlp(self.embed_dim, hidden_features=int(self.embed_dim * mlp_ratio))

        self.decoders = nn.ModuleList(
            [
                nn.Sequential(
                    *[
                        Block(
                            embed_dim,
                            num_heads,
                            mlp_ratio,
                            qkv_bias=qkv_bias,
                            norm_layer=norm_layer,
                        )
                        for _ in range(depth)
                    ]
                )
                for _ in range(self.num_modalities)
            ]
        )
        # NOTE: there is deliberately no final output norm; the authors
        # decided to drop the last layer norm.
        self.context_norm = norm_layer(embed_dim)
        self.query_norm = norm_layer(embed_dim)
        self.out_norm = norm_layer(embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a full-length token sequence.

        Args:
            x: tokens with the class token at position 0 followed by
                ``num_modalities * tokens_per_modality`` patch tokens.

        Returns:
            Tensor with the same layout as ``x``: the class token (with
            positional embedding added, otherwise untouched) followed by the
            decoded patch tokens of every modality in order.
        """
        modality_tokens_concat = torch.cat(
            [self.placeholder]  # placeholder for class token
            + [
                m_t.repeat(1, self.tokens_per_modality, 1)
                for m_t in self.modality_tokens
            ],
            dim=1,
        )
        # add pos and tiled modality tokens
        x = x + self.pos_embeddings + modality_tokens_concat
        x_ = x[:, 1:, :]  # no class token

        # The attention context is identical for every modality, so normalize
        # it once instead of once per loop iteration.
        context = self.context_norm(x_)

        decoded = []
        for m, decoder in enumerate(self.decoders):
            start = m * self.tokens_per_modality
            x_m = x_[:, start : start + self.tokens_per_modality, :]
            x_m = self.cross_attention(self.query_norm(x_m), context)
            x_m = x_m + self.mlp(self.out_norm(x_m))
            decoded.append(decoder(x_m))

        # Class token first, then all modalities' tokens in order.
        return torch.cat([x[:, :1, :], *decoded], dim=1)

    def forward_masked(
        self, x: torch.Tensor, ind_restore: torch.Tensor
    ) -> torch.Tensor:
        """Re-insert mask tokens, restore token order, then decode.

        Args:
            x: encoder output with the class token at position 0 followed by
                the tokens that survived masking.
            ind_restore: per-sample indices used to put the patch tokens back
                in their original order; its second dimension is the full
                number of patch tokens.

        Returns:
            The output of ``forward`` on the restored full-length sequence.
        """
        # +1 accounts for the class token, which is present in x but not in
        # ind_restore.
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ind_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # remove class token
        x_ = torch.gather(
            x_, dim=1, index=ind_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle
        x = torch.cat([x[:, :1, :], x_], dim=1)  # add class token
        return self.forward(x)