| import torch |
| from torch.utils.data import DataLoader, TensorDataset |
| import numpy as np |
| from model import CustomBERTModel |
| from config import Config |
| import pandas as pd |
|
|
def load_data(file_path):
    """Read a header-less CSV into a ``float32`` tensor.

    Every row of the file (including the first, since no header row is
    assumed) becomes one row of the returned tensor.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts as its source.

    Returns:
        A ``torch.Tensor`` of dtype ``torch.float32`` built from the table's
        values.
    """
    table = pd.read_csv(file_path, header=None)
    raw_values = table.values
    return torch.tensor(raw_values, dtype=torch.float32)
|
|
def get_embeddings(input_file, output_file, model_path="bert_mlm_model.pth"):
    """Encode the rows of a CSV file with the trained model and save them.

    The model is built from ``Config()``, its weights are restored from
    ``model_path``, and the encoder output of every batch is collected,
    concatenated along axis 0 and written with ``np.save``.

    Args:
        input_file: Path to the input CSV file; read via ``load_data``.
        output_file: Destination handed to ``np.save``. NumPy appends
            ``.npy`` to a filename that does not already end with it.
        model_path: Checkpoint containing the model's ``state_dict``.
            Defaults to the previously hard-coded ``"bert_mlm_model.pth"``.

    Raises:
        ValueError: If the input produces no batches to encode.
    """
    config = Config()
    model = CustomBERTModel(config).to(config.device)

    # map_location remaps the stored tensors onto config.device, so a
    # checkpoint saved on one device (e.g. CUDA) still loads on another.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = torch.load(model_path, map_location=config.device)
    model.load_state_dict(state_dict)
    model.eval()

    input_data = load_data(input_file)
    data_loader = DataLoader(
        TensorDataset(input_data), batch_size=config.batch_size
    )

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        batch_embeddings = [
            # TensorDataset yields 1-tuples; element 0 is the input batch.
            model.get_encoder_output(batch[0].to(config.device)).cpu().numpy()
            for batch in data_loader
        ]

    # np.concatenate rejects an empty list with an opaque ValueError; fail
    # with the same exception type but a message naming the real cause.
    if not batch_embeddings:
        raise ValueError(f"No rows to embed in input file: {input_file}")

    all_embeddings = np.concatenate(batch_embeddings, axis=0)
    print(f"Generated embeddings shape: {all_embeddings.shape}")

    np.save(output_file, all_embeddings)
    print(f"Embeddings saved as {output_file}")
|
|
if __name__ == "__main__":
    # Command-line entry point: two positional paths, input CSV then output file.
    import argparse

    cli = argparse.ArgumentParser(
        description="Generate embeddings for microbial growth curves"
    )
    cli.add_argument(
        "input_file",
        help="Path to the input CSV file containing growth curves",
    )
    cli.add_argument(
        "output_file",
        help="Path to save the output embeddings (as .npy file)",
    )
    options = cli.parse_args()

    get_embeddings(options.input_file, options.output_file)
|
|