CULTURE / get_embeddings.py
pranamanam's picture
Upload 7 files
a0e0ff1 verified
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from model import CustomBERTModel
from config import Config
import pandas as pd
def load_data(file_path):
    """Read a header-less CSV file into a float32 tensor.

    Args:
        file_path: Location of the CSV file. It is handed to
            ``pandas.read_csv`` with ``header=None``, so the first line of
            the file is treated as data rather than as column names.

    Returns:
        A ``torch.float32`` tensor built from the values of the parsed frame.
    """
    frame = pd.read_csv(file_path, header=None)
    raw_values = frame.values
    return torch.tensor(raw_values, dtype=torch.float32)
def get_embeddings(input_file, output_file, model_path="bert_mlm_model.pth"):
    """Encode every row of a CSV file with the trained model and save the result.

    Builds a ``CustomBERTModel`` from a fresh ``Config``, restores its weights
    from ``model_path``, runs ``model.get_encoder_output`` over the rows of
    ``input_file`` in batches of ``config.batch_size``, concatenates the
    per-batch outputs along axis 0 and writes them with ``np.save``.

    Args:
        input_file: Path to a header-less CSV file; passed to ``load_data``.
        output_file: Destination for the embeddings. ``np.save`` appends a
            ``.npy`` extension when the given file name does not already
            end with one.
        model_path: Checkpoint file holding the model ``state_dict``.
            Defaults to ``"bert_mlm_model.pth"``, the path that was
            previously hard-coded.

    Returns:
        None. The embeddings are written to disk and their shape is printed.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)

    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint that was saved on a GPU can still be restored on a CPU-only
    # machine instead of failing during deserialization.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()

    input_data = load_data(input_file)
    dataset = TensorDataset(input_data)
    # No shuffling is requested, so output rows stay in input-row order.
    data_loader = DataLoader(dataset, batch_size=config.batch_size)

    all_embeddings = []
    # Inference only: skip autograd bookkeeping while encoding.
    with torch.no_grad():
        for batch in data_loader:
            # TensorDataset wraps a single tensor, so each batch has one item.
            inputs = batch[0].to(config.device)
            embeddings = model.get_encoder_output(inputs)
            # Move to host memory before converting to a NumPy array.
            all_embeddings.append(embeddings.cpu().numpy())

    all_embeddings = np.concatenate(all_embeddings, axis=0)
    print(f"Generated embeddings shape: {all_embeddings.shape}")

    # Save embeddings
    np.save(output_file, all_embeddings)
    print(f"Embeddings saved as {output_file}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: two required positional paths, the input CSV
    # first and the output file for the embeddings second.
    cli = argparse.ArgumentParser(
        description="Generate embeddings for microbial growth curves"
    )
    positional_specs = (
        ("input_file", "Path to the input CSV file containing growth curves"),
        ("output_file", "Path to save the output embeddings (as .npy file)"),
    )
    for arg_name, arg_help in positional_specs:
        cli.add_argument(arg_name, help=arg_help)

    cli_args = cli.parse_args()
    get_embeddings(cli_args.input_file, cli_args.output_file)