"""Generate encoder embeddings for growth curves stored in a header-less CSV."""

import argparse

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset

from model import CustomBERTModel
from config import Config


def load_data(file_path):
    """Read a header-less numeric CSV into a 2-D float32 tensor.

    Args:
        file_path: Path to a CSV file with no header row. Each row of the
            file becomes one row of the returned tensor.

    Returns:
        A ``torch.float32`` tensor with one row per CSV row and one column
        per CSV column.

    Raises:
        ValueError: If a cell cannot be converted to float32.
    """
    df = pd.read_csv(file_path, header=None)
    # Convert straight to float32 instead of going through ``df.values``
    # (float64 or object dtype) and casting afterwards.
    return torch.tensor(df.to_numpy(dtype=np.float32))


def get_embeddings(input_file, output_file, model_path="bert_mlm_model.pth"):
    """Run the trained encoder over every CSV row and save the outputs.

    Args:
        input_file: Header-less CSV read by ``load_data``.
        output_file: Destination passed to ``np.save``. ``np.save`` appends
            ``.npy`` when the name does not already end with it.
        model_path: State-dict checkpoint for ``CustomBERTModel``. The
            default keeps the previously hard-coded file name.

    Returns:
        The NumPy array that was saved: the per-batch outputs of
        ``model.get_encoder_output`` concatenated along axis 0, in input
        row order (the DataLoader does not shuffle).
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)
    # map_location remaps the stored tensors onto config.device. Without it,
    # torch.load restores them to the device they were saved from, which
    # fails for a GPU-saved checkpoint on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider weights_only=True if the minimum supported torch
    # version allows it.
    state_dict = torch.load(model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()

    input_data = load_data(input_file)
    data_loader = DataLoader(TensorDataset(input_data), batch_size=config.batch_size)

    batch_outputs = []
    with torch.no_grad():
        for (inputs,) in data_loader:
            embeddings = model.get_encoder_output(inputs.to(config.device))
            batch_outputs.append(embeddings.cpu().numpy())

    all_embeddings = np.concatenate(batch_outputs, axis=0)
    print(f"Generated embeddings shape: {all_embeddings.shape}")

    # Save embeddings
    np.save(output_file, all_embeddings)
    # np.save adds ".npy" to names lacking it; report the file actually written.
    saved_path = str(output_file)
    if not saved_path.endswith(".npy"):
        saved_path += ".npy"
    print(f"Embeddings saved as {saved_path}")
    return all_embeddings


def main(argv=None):
    """Parse command-line arguments and run ``get_embeddings``."""
    parser = argparse.ArgumentParser(description="Generate embeddings for microbial growth curves")
    parser.add_argument("input_file", help="Path to the input CSV file containing growth curves")
    parser.add_argument("output_file", help="Path to save the output embeddings (as .npy file)")
    parser.add_argument(
        "--model-path",
        default="bert_mlm_model.pth",
        help="Path to the trained model state dict (default: bert_mlm_model.pth)",
    )
    args = parser.parse_args(argv)
    get_embeddings(args.input_file, args.output_file, model_path=args.model_path)


if __name__ == "__main__":
    main()