"""Command-line training script for DeepSeeNet classifiers."""

import argparse
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from augmentations import get_train_transforms, get_val_transforms
from dataloader import AREDSDataset
from model import DeepSeeNet

# Number of output classes for each supported task (also the valid --task values).
N_CLASSES = {
    "ADVAMD": 2,
    "DRUS": 3,
    "PIG": 2,
}


class AlbumentationsTransform:
    """Adapt a keyword-style transform to a plain ``transform(image)`` callable.

    The wrapped transform is invoked as ``transform(image=<ndarray>)`` and is
    expected to return a mapping with an ``"image"`` entry.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        """Convert *image* to a NumPy array, apply the transform, return the image."""
        return self.transform(image=np.asarray(image))["image"]


def set_seed(seed):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_class_weights(dataset, task, device):
    """Compute inverse-frequency class weights for *task*.

    Labels are read from ``dataset.data[task]``. Each class ``c`` receives
    ``total / (n_classes * count_c)``; counts are clamped to at least 1 so a
    class that never occurs does not cause a division by zero.

    Returns:
        A 1-D tensor of length ``N_CLASSES[task]`` placed on *device*.
    """
    labels = torch.tensor(dataset.data[task].to_numpy(), dtype=torch.long)
    counts = torch.bincount(labels, minlength=N_CLASSES[task]).clamp_min(1)
    weights = counts.sum() / (len(counts) * counts)
    return weights.to(device)


def build_scheduler(optimizer, args):
    """Build the LR scheduler selected by ``args.scheduler``.

    Returns:
        ``CosineAnnealingLR`` for ``"cosine"``, ``StepLR`` for ``"step"``,
        and ``None`` for any other value (e.g. ``"none"``).
    """
    if args.scheduler == "cosine":
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=args.epochs,
            eta_min=args.min_lr,
        )

    if args.scheduler == "step":
        return torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=args.step_size,
            gamma=args.gamma,
        )

    return None


def _epoch_averages(total_loss, total_correct, total_samples, stage):
    """Return ``(mean loss, accuracy)``, failing clearly if no samples were seen."""
    if total_samples == 0:
        raise ValueError(
            f"{stage} loader yielded no samples; cannot average epoch metrics."
        )
    return total_loss / total_samples, total_correct / total_samples


def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Train *model* for one pass over *loader*.

    Args:
        model: Module called as ``model(images)`` to produce class logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: Gradient scaler used for the backward pass and optimizer step,
            or ``None`` to run a plain backward/step.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: ``torch.device`` the batches are moved to.
        use_amp: Request autocast; only honoured when *device* is CUDA.
        grad_clip: Max gradient norm; values ``<= 0`` disable clipping.

    Returns:
        ``(loss, accuracy)`` where loss is the batch-size-weighted average of
        the per-batch loss values and accuracy is the fraction of correct
        argmax predictions.

    Raises:
        ValueError: If *loader* yields no samples.
    """
    model.train()

    # Loop-invariant: autocast is only ever enabled on CUDA devices.
    autocast_enabled = use_amp and device.type == "cuda"

    running_loss = 0.0
    running_correct = 0
    running_samples = 0

    pbar = tqdm(loader, desc="Train", leave=False)

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast("cuda", enabled=autocast_enabled):
            logits = model(images)
            loss = criterion(logits, labels)

        if scaler is not None:
            scaler.scale(loss).backward()

            if grad_clip > 0:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()

            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        running_samples += batch_size

        pbar.set_postfix(
            loss=f"{running_loss / running_samples:.4f}",
            acc=f"{running_correct / running_samples:.4f}",
        )

    return _epoch_averages(running_loss, running_correct, running_samples, "Train")


@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp=True):
    """Evaluate *model* over *loader* without tracking gradients.

    Args:
        model: Module called as ``model(images)`` to produce class logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: ``torch.device`` the batches are moved to.
        use_amp: Request autocast; only honoured when *device* is CUDA.

    Returns:
        ``(loss, accuracy)`` computed the same way as in ``train_one_epoch``.

    Raises:
        ValueError: If *loader* yields no samples.
    """
    model.eval()

    # Loop-invariant: autocast is only ever enabled on CUDA devices.
    autocast_enabled = use_amp and device.type == "cuda"

    running_loss = 0.0
    running_correct = 0
    running_samples = 0

    pbar = tqdm(loader, desc="Val", leave=False)

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=autocast_enabled):
            logits = model(images)
            loss = criterion(logits, labels)

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        running_samples += batch_size

        pbar.set_postfix(
            loss=f"{running_loss / running_samples:.4f}",
            acc=f"{running_correct / running_samples:.4f}",
        )

    return _epoch_averages(running_loss, running_correct, running_samples, "Val")


def save_checkpoint(path, model, optimizer, epoch, best_val_loss, args):
    """Write a training checkpoint to *path*, creating parent directories.

    The checkpoint is a dict with keys ``"epoch"``, ``"model"``,
    ``"optimizer"``, ``"best_val_loss"`` and ``"args"`` (``vars(args)``).
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    torch.save(
        {
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "best_val_loss": best_val_loss,
            "args": vars(args),
        },
        path,
    )


def main(args):
    """Run the full training loop described by the parsed CLI *args*.

    Builds datasets, loaders, model, loss, optimizer and scheduler; trains for
    ``args.epochs`` epochs; saves ``best.pt`` whenever validation loss
    improves, optional periodic ``epoch_XXX.pt`` files, and a final
    ``last.pt`` under ``args.output_dir``.
    """
    set_seed(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # AMP is only used when requested AND a CUDA device is present.
    use_amp = args.amp and device.type == "cuda"

    train_dataset = AREDSDataset(
        args.train_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
    )
    val_dataset = AREDSDataset(
        args.valid_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    model = DeepSeeNet(
        n_classes=N_CLASSES[args.task],
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    ).to(device)

    class_weights = None
    if not args.no_class_weights:
        class_weights = get_class_weights(train_dataset, args.task, device)

    # Training loss may be class-weighted; validation loss is always
    # unweighted. The validation criterion is built once, not per epoch.
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    val_criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = build_scheduler(optimizer, args)
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    # wandb is imported lazily so it is only required when --wandb is passed.
    wandb = None
    if args.wandb:
        import wandb

        wandb.init(project=args.wandb_project, config=vars(args))

    output_dir = Path(args.output_dir)
    best_val_loss = float("inf")

    print(f"Device: {device}")
    print(f"Task: {args.task}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")
    if class_weights is not None:
        print(f"Class weights: {class_weights.detach().cpu().tolist()}")

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")

        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )
        val_loss, val_acc = evaluate(
            model=model,
            loader=val_loader,
            criterion=val_criterion,
            device=device,
            use_amp=use_amp,
        )

        # Read before scheduler.step() so the logged LR is the one this epoch used.
        lr = optimizer.param_groups[0]["lr"]

        print(
            f"train_loss={train_loss:.4f} "
            f"train_acc={train_acc:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_acc={val_acc:.4f} "
            f"lr={lr:.2e}"
        )

        if wandb is not None:
            wandb.log(
                {
                    "epoch": epoch,
                    "lr": lr,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                }
            )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                output_dir / "best.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )
            print(f"Saved best checkpoint: val_loss={best_val_loss:.4f}")

        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )

        if scheduler is not None:
            scheduler.step()

    save_checkpoint(
        output_dir / "last.pt",
        model,
        optimizer,
        args.epochs,
        best_val_loss,
        args,
    )

    print("Training complete.")
    print(f"Best val loss: {best_val_loss:.4f}")

    # Close the run only on success; if training raised, the run is left to
    # wandb's own exit handling rather than being marked as cleanly finished.
    if wandb is not None:
        wandb.finish()


def parse_args(argv=None):
    """Parse command-line options for training.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train DeepSeeNet.")

    parser.add_argument("--train-csv", required=True)
    parser.add_argument("--valid-csv", required=True)
    parser.add_argument("--image-root", required=True)
    # type=str.upper lets the task be given in any letter case.
    parser.add_argument("--task", required=True, type=str.upper, choices=N_CLASSES)

    parser.add_argument("--output-dir", default="checkpoints/deepseenet")
    parser.add_argument("--backbone", default="inception_v3")
    parser.add_argument("--image-size", type=int, default=1024)

    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=4)

    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--weight-decay", type=float, default=1e-4)

    parser.add_argument("--no-pretrained", action="store_true")
    parser.add_argument("--freeze-backbone", action="store_true")
    parser.add_argument("--no-class-weights", action="store_true")

    parser.add_argument("--scheduler", choices=("none", "cosine", "step"), default="cosine")
    parser.add_argument("--min-lr", type=float, default=1e-6)
    parser.add_argument("--step-size", type=int, default=5)
    parser.add_argument("--gamma", type=float, default=0.5)

    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--grad-clip", type=float, default=0.0)
    parser.add_argument("--save-every", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)

    parser.add_argument("--wandb", action="store_true")
    parser.add_argument("--wandb-project", default="deepseenet")

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_args()
    main(args)