Spaces:
Sleeping
Sleeping
| import argparse | |
| import random | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from augmentations import get_train_transforms, get_val_transforms | |
| from dataloader import AREDSDataset | |
| from model import DeepSeeNet | |
# Number of output classes per supported task. The keys double as the valid
# ``--task`` choices on the command line (see parse_args), and their insertion
# order is the order argparse lists them in.
N_CLASSES = dict(ADVAMD=2, DRUS=3, PIG=2)
class AlbumentationsTransform:
    """Adapt a keyword-style augmentation pipeline to a ``transform(image)`` call.

    The wrapped pipeline (presumably an Albumentations ``Compose``) is invoked
    as ``pipeline(image=<ndarray>)`` and must return a mapping holding the
    augmented result under the ``"image"`` key.
    """

    def __init__(self, transform):
        # Keep the wrapped pipeline under its original attribute name.
        self.transform = transform

    def __call__(self, image):
        # Convert the incoming image (e.g. a PIL image) to an ndarray first,
        # since the pipeline operates on arrays.
        array = np.asarray(image)
        augmented = self.transform(image=array)
        return augmented["image"]
def set_seed(seed):
    """Seed every RNG the training script draws from.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's CPU RNG
    and the RNGs of all CUDA devices, in that order.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def get_class_weights(dataset, task, device, n_classes=None):
    """Compute inverse-frequency class weights for a weighted cross-entropy loss.

    The weight of class ``c`` is ``total / (n_classes * count_c)``, so a
    perfectly balanced dataset gets all-ones weights. Classes that never occur
    are counted as 1 to avoid division by zero.

    Args:
        dataset: object whose ``data[task]`` column supports ``.to_numpy()``
            (presumably a pandas DataFrame — confirm against AREDSDataset).
        task: label column name; also the key into ``N_CLASSES``.
        device: device the returned tensor is moved to.
        n_classes: number of classes. Defaults to ``N_CLASSES[task]``.

    Returns:
        Float tensor of shape ``(n_classes,)`` on ``device``.

    Raises:
        ValueError: if any label lies outside ``[0, n_classes)``. Without this
            check ``torch.bincount`` would silently return a longer vector (or
            raise on negatives), and the mismatch would only surface later
            inside the loss function.
    """
    if n_classes is None:
        n_classes = N_CLASSES[task]
    labels = torch.tensor(dataset.data[task].to_numpy(), dtype=torch.long)
    if labels.numel() > 0:
        low, high = int(labels.min()), int(labels.max())
        if low < 0 or high >= n_classes:
            raise ValueError(
                f"Labels for task {task!r} must lie in [0, {n_classes}), "
                f"found range [{low}, {high}]."
            )
    # clamp_min(1): absent classes get a large-but-finite weight instead of inf.
    counts = torch.bincount(labels, minlength=n_classes).clamp_min(1)
    weights = counts.sum() / (n_classes * counts)
    return weights.to(device)
def build_scheduler(optimizer, args):
    """Create the learning-rate scheduler selected by ``args.scheduler``.

    ``"cosine"`` anneals from the base LR down to ``args.min_lr`` over
    ``args.epochs`` scheduler steps. ``"step"`` multiplies the LR by
    ``args.gamma`` every ``args.step_size`` steps. Any other value (including
    ``"none"``) yields ``None``, meaning no scheduling.
    """
    lr_sched = torch.optim.lr_scheduler
    # Factories are lazy so only the chosen scheduler's args are read.
    factories = {
        "cosine": lambda: lr_sched.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=args.min_lr
        ),
        "step": lambda: lr_sched.StepLR(
            optimizer, step_size=args.step_size, gamma=args.gamma
        ),
    }
    factory = factories.get(args.scheduler)
    return None if factory is None else factory()
def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Run one training pass over ``loader`` and return ``(loss, accuracy)``.

    Args:
        model: network to train; switched to train mode here. Assumed to
            return logits with classes on dim 1 (``argmax(dim=1)`` below).
        loader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer updating ``model``'s parameters.
        scaler: ``torch.amp.GradScaler`` for mixed-precision training, or
            ``None`` for a plain ``backward()`` / ``step()``.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device each batch is moved to.
        use_amp: enable CUDA autocast; has no effect on non-CUDA devices.
        grad_clip: max total gradient norm; values ``<= 0`` disable clipping.

    Returns:
        ``(mean_loss, accuracy)``: the batch-size-weighted average of the
        per-batch loss, and the fraction of correctly classified samples.

    Raises:
        ValueError: if ``loader`` yields no samples (previously this surfaced
            as a bare ZeroDivisionError).
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    running_samples = 0
    # Loop-invariant: autocast is only meaningful on CUDA.
    autocast_enabled = use_amp and device.type == "cuda"
    pbar = tqdm(loader, desc="Train", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=autocast_enabled):
            logits = model(images)
            loss = criterion(logits, labels)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Gradients must be unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        running_samples += batch_size
        pbar.set_postfix(
            loss=f"{running_loss / running_samples:.4f}",
            acc=f"{running_correct / running_samples:.4f}",
        )
    if running_samples == 0:
        raise ValueError(
            "Training loader yielded no samples; check the training CSV and loader settings."
        )
    return running_loss / running_samples, running_correct / running_samples
def evaluate(model, loader, criterion, device, use_amp=True):
    """Evaluate ``model`` on ``loader`` and return ``(loss, accuracy)``.

    The forward passes run under ``torch.no_grad()``. Without it every
    validation batch would build and retain an autograd graph, wasting memory
    and time.

    Args:
        model: network to evaluate; switched to eval mode here. Assumed to
            return logits with classes on dim 1 (``argmax(dim=1)`` below).
        loader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``.
        device: device each batch is moved to.
        use_amp: enable CUDA autocast; has no effect on non-CUDA devices.

    Returns:
        ``(mean_loss, accuracy)``: the batch-size-weighted average of the
        per-batch loss, and the fraction of correctly classified samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    running_correct = 0
    running_samples = 0
    autocast_enabled = use_amp and device.type == "cuda"
    pbar = tqdm(loader, desc="Val", leave=False)
    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with torch.amp.autocast("cuda", enabled=autocast_enabled):
                logits = model(images)
                loss = criterion(logits, labels)
            batch_size = labels.size(0)
            running_loss += loss.item() * batch_size
            running_correct += (logits.argmax(dim=1) == labels).sum().item()
            running_samples += batch_size
            pbar.set_postfix(
                loss=f"{running_loss / running_samples:.4f}",
                acc=f"{running_correct / running_samples:.4f}",
            )
    if running_samples == 0:
        raise ValueError(
            "Validation loader yielded no samples; check the validation CSV and loader settings."
        )
    return running_loss / running_samples, running_correct / running_samples
def save_checkpoint(path, model, optimizer, epoch, best_val_loss, args):
    """Serialize training state to ``path`` with ``torch.save``.

    The checkpoint is a dict with keys ``epoch``, ``model`` (state_dict),
    ``optimizer`` (state_dict), ``best_val_loss`` and ``args`` (``vars(args)``,
    so ``args`` must have a ``__dict__``, e.g. an argparse Namespace). Missing
    parent directories are created.

    The data is first written to a sibling ``<name>.tmp`` file and then
    renamed over ``path``. An interrupted or failed save therefore never
    replaces a previously good checkpoint with a truncated one.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    state = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
        "args": vars(args),
    }
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(state, tmp_path)
        # Path.replace is os.replace: an atomic rename on the same filesystem.
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a partial temp file behind; re-raise the original error.
        tmp_path.unlink(missing_ok=True)
        raise
def _build_loaders(args, device):
    """Build the train/validation datasets and their DataLoaders.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader)``.
    """
    train_dataset = AREDSDataset(
        args.train_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
    )
    val_dataset = AREDSDataset(
        args.valid_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )
    # Pinned host memory is only requested when training on CUDA.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    return train_dataset, val_dataset, train_loader, val_loader


def _init_wandb(args):
    """Start a Weights & Biases run if ``--wandb`` was given; return the module or None."""
    if not args.wandb:
        return None
    # Imported lazily so wandb is only required when logging is enabled.
    import wandb

    wandb.init(project=args.wandb_project, config=vars(args))
    return wandb


def _print_run_summary(args, device, train_dataset, val_dataset, class_weights):
    """Print the key settings of this training run to stdout."""
    print(f"Device: {device}")
    print(f"Task: {args.task}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")
    if class_weights is not None:
        print(f"Class weights: {class_weights.detach().cpu().tolist()}")


def main(args):
    """Train DeepSeeNet for ``args.epochs`` epochs and write checkpoints.

    All checkpoints go under ``args.output_dir``:
    - ``best.pt`` whenever the validation loss improves;
    - ``epoch_NNN.pt`` every ``args.save_every`` epochs (when > 0);
    - ``last.pt`` after the final epoch.
    """
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision is only used on CUDA; resolve the flag once.
    use_amp = args.amp and device.type == "cuda"

    train_dataset, val_dataset, train_loader, val_loader = _build_loaders(args, device)

    model = DeepSeeNet(
        n_classes=N_CLASSES[args.task],
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    ).to(device)

    class_weights = None
    if not args.no_class_weights:
        class_weights = get_class_weights(train_dataset, args.task, device)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    # NOTE(review): the validation loss is unweighted while the training loss
    # uses class weights, so best.pt is selected on plain cross-entropy.
    # Behavior kept as-is (only hoisted out of the epoch loop); confirm that
    # this is intended.
    val_criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = build_scheduler(optimizer, args)
    # Loss scaling is only paired with CUDA autocast.
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    wandb = _init_wandb(args)

    output_dir = Path(args.output_dir)
    best_val_loss = float("inf")
    _print_run_summary(args, device, train_dataset, val_dataset, class_weights)

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )
        val_loss, val_acc = evaluate(
            model=model,
            loader=val_loader,
            criterion=val_criterion,
            device=device,
            use_amp=use_amp,
        )
        # LR actually used during this epoch (the scheduler steps at the end).
        lr = optimizer.param_groups[0]["lr"]
        print(
            f"train_loss={train_loss:.4f} "
            f"train_acc={train_acc:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_acc={val_acc:.4f} "
            f"lr={lr:.2e}"
        )
        if wandb is not None:
            wandb.log(
                {
                    "epoch": epoch,
                    "lr": lr,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                }
            )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                output_dir / "best.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )
            print(f"Saved best checkpoint: val_loss={best_val_loss:.4f}")
        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )
        if scheduler is not None:
            scheduler.step()

    save_checkpoint(
        output_dir / "last.pt",
        model,
        optimizer,
        args.epochs,
        best_val_loss,
        args,
    )
    print("Training complete.")
    print(f"Best val loss: {best_val_loss:.4f}")
    if wandb is not None:
        # Explicitly close the W&B run once training has finished.
        wandb.finish()
def parse_args():
    """Parse the command-line options for a DeepSeeNet training run.

    ``--task`` is upper-cased before validation, so ``drus`` and ``DRUS`` are
    equivalent; its valid values are the keys of ``N_CLASSES``.
    """
    parser = argparse.ArgumentParser(description="Train DeepSeeNet.")
    # (flag, add_argument keyword options), registered in this exact order.
    option_specs = (
        # Data locations and task.
        ("--train-csv", {"required": True}),
        ("--valid-csv", {"required": True}),
        ("--image-root", {"required": True}),
        ("--task", {"required": True, "type": str.upper, "choices": N_CLASSES}),
        # Output and model configuration.
        ("--output-dir", {"default": "checkpoints/deepseenet"}),
        ("--backbone", {"default": "inception_v3"}),
        ("--image-size", {"type": int, "default": 1024}),
        # Training loop and optimizer.
        ("--epochs", {"type": int, "default": 20}),
        ("--batch-size", {"type": int, "default": 32}),
        ("--num-workers", {"type": int, "default": 4}),
        ("--lr", {"type": float, "default": 1e-4}),
        ("--weight-decay", {"type": float, "default": 1e-4}),
        # Boolean switches.
        ("--no-pretrained", {"action": "store_true"}),
        ("--freeze-backbone", {"action": "store_true"}),
        ("--no-class-weights", {"action": "store_true"}),
        # Learning-rate schedule.
        ("--scheduler", {"choices": ("none", "cosine", "step"), "default": "cosine"}),
        ("--min-lr", {"type": float, "default": 1e-6}),
        ("--step-size", {"type": int, "default": 5}),
        ("--gamma", {"type": float, "default": 0.5}),
        # Numerics, checkpointing and reproducibility.
        ("--amp", {"action": "store_true"}),
        ("--grad-clip", {"type": float, "default": 0.0}),
        ("--save-every", {"type": int, "default": 0}),
        ("--seed", {"type": int, "default": 42}),
        # Experiment tracking.
        ("--wandb", {"action": "store_true"}),
        ("--wandb-project", {"default": "deepseenet"}),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: parse the CLI options and start training.
    main(parse_args())