# DeepSeeNet / train.py
# Source: farrell236, commit b8c9192 ("add src")
import argparse
import random
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from augmentations import get_train_transforms, get_val_transforms
from dataloader import AREDSDataset
from model import DeepSeeNet
# Number of output classes for each supported grading task. The keys are also
# the accepted values of ``--task`` and select the model's output size.
N_CLASSES = dict(ADVAMD=2, DRUS=3, PIG=2)
class AlbumentationsTransform:
    """Adapt a keyword-style image pipeline to a plain ``transform(image)`` callable.

    The wrapped pipeline is invoked as ``transform(image=<ndarray>)`` and must
    return a mapping; the ``"image"`` entry of that mapping is returned. The
    incoming image is converted with ``np.asarray`` first (presumably a PIL
    image coming from AREDSDataset — confirm against the dataset code).
    """

    def __init__(self, transform):
        # Callable accepting ``image=`` and returning a dict with an "image" key.
        self.transform = transform

    def __call__(self, image):
        pixels = np.asarray(image)
        output = self.transform(image=pixels)
        return output["image"]
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU + every CUDA device)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
def get_class_weights(dataset, task, device, n_classes=None):
    """Compute inverse-frequency class weights for ``CrossEntropyLoss``.

    Class ``c`` receives ``total / (n_classes * count_c)``, so a perfectly
    balanced label column yields all-ones weights. Classes that never occur
    are counted as 1 to avoid division by zero (that 1 is also included in
    ``total``).

    Args:
        dataset: Object whose ``data[task]`` exposes ``to_numpy()``
            (presumably a pandas column of AREDSDataset — confirm).
        task: Label column name; also the key into ``N_CLASSES``.
        device: Device the returned tensor is moved to.
        n_classes: Number of classes. Defaults to ``N_CLASSES[task]``.

    Returns:
        1-D floating-point tensor of length ``n_classes`` on ``device``.

    Raises:
        ValueError: If any label lies outside ``[0, n_classes)``. Without this
            check an out-of-range label silently lengthens the weight vector and
            the mismatch only surfaces later inside the loss function.
    """
    if n_classes is None:
        n_classes = N_CLASSES[task]
    labels = torch.tensor(dataset.data[task].to_numpy(), dtype=torch.long)
    if labels.numel() > 0:
        low, high = int(labels.min()), int(labels.max())
        if low < 0 or high >= n_classes:
            raise ValueError(
                f"{task} labels must be integers in [0, {n_classes}); "
                f"found values in [{low}, {high}]"
            )
    counts = torch.bincount(labels, minlength=n_classes).clamp_min(1)
    # Integer / integer tensor division in torch is true division -> float result.
    weights = counts.sum() / (n_classes * counts)
    return weights.to(device)
def build_scheduler(optimizer, args):
    """Create the per-epoch LR scheduler selected by ``args.scheduler``.

    ``"cosine"`` -> CosineAnnealingLR over ``args.epochs`` down to ``args.min_lr``;
    ``"step"``   -> StepLR every ``args.step_size`` epochs by factor ``args.gamma``;
    anything else (e.g. ``"none"``) -> ``None`` (constant LR).
    """
    lr_sched = torch.optim.lr_scheduler
    # Factories are lazy so only the selected scheduler reads its own args.
    factories = {
        "cosine": lambda: lr_sched.CosineAnnealingLR(
            optimizer, T_max=args.epochs, eta_min=args.min_lr
        ),
        "step": lambda: lr_sched.StepLR(
            optimizer, step_size=args.step_size, gamma=args.gamma
        ),
    }
    factory = factories.get(args.scheduler)
    return None if factory is None else factory()
def train_one_epoch(
    model,
    loader,
    optimizer,
    scaler,
    criterion,
    device,
    use_amp=True,
    grad_clip=0.0,
):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module called as ``model(images)``; its output is treated as
            logits with classes along dim 1 (``argmax(dim=1)``).
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scaler: ``GradScaler`` for mixed precision, or ``None`` for a plain
            backward/step.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: ``torch.device`` batches are moved to; autocast is only
            enabled when its type is ``"cuda"``.
        use_amp: Request CUDA autocast (ignored on non-CUDA devices).
        grad_clip: Max gradient norm; values <= 0 disable clipping.

    Returns:
        ``(loss, accuracy)`` where ``loss`` is the batch-size-weighted average of
        the per-batch losses and ``accuracy`` is the fraction of correct argmax
        predictions.

    Raises:
        ValueError: If ``loader`` yields no samples (averages are undefined).
    """
    model.train()
    running_loss = 0.0
    running_correct = 0
    running_samples = 0
    # Loop-invariant: autocast is a no-op unless we are actually on CUDA.
    amp_enabled = use_amp and device.type == "cuda"
    pbar = tqdm(loader, desc="Train", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(images)
            loss = criterion(logits, labels)
        if scaler is not None:
            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Gradients must be unscaled before clipping so the norm
                # threshold applies to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        running_samples += batch_size
        pbar.set_postfix(
            loss=f"{running_loss / running_samples:.4f}",
            acc=f"{running_correct / running_samples:.4f}",
        )
    if running_samples == 0:
        raise ValueError("Training loader yielded no samples; cannot compute epoch metrics.")
    return running_loss / running_samples, running_correct / running_samples
@torch.no_grad()
def evaluate(model, loader, criterion, device, use_amp=True):
    """Compute loss and accuracy of ``model`` over ``loader`` without gradients.

    Args:
        model: Module called as ``model(images)``; its output is treated as
            logits with classes along dim 1 (``argmax(dim=1)``). Put in eval mode.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: ``torch.device`` batches are moved to; autocast is only
            enabled when its type is ``"cuda"``.
        use_amp: Request CUDA autocast (ignored on non-CUDA devices).

    Returns:
        ``(loss, accuracy)``: batch-size-weighted average of per-batch losses and
        fraction of correct argmax predictions.

    Raises:
        ValueError: If ``loader`` yields no samples (averages are undefined).
    """
    model.eval()
    running_loss = 0.0
    running_correct = 0
    running_samples = 0
    # Loop-invariant: autocast is a no-op unless we are actually on CUDA.
    amp_enabled = use_amp and device.type == "cuda"
    pbar = tqdm(loader, desc="Val", leave=False)
    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        with torch.amp.autocast("cuda", enabled=amp_enabled):
            logits = model(images)
            loss = criterion(logits, labels)
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        running_correct += (logits.argmax(dim=1) == labels).sum().item()
        running_samples += batch_size
        pbar.set_postfix(
            loss=f"{running_loss / running_samples:.4f}",
            acc=f"{running_correct / running_samples:.4f}",
        )
    if running_samples == 0:
        raise ValueError("Validation loader yielded no samples; cannot compute metrics.")
    return running_loss / running_samples, running_correct / running_samples
def save_checkpoint(path, model, optimizer, epoch, best_val_loss, args):
    """Serialize training state to ``path`` with ``torch.save``.

    The saved dict has keys ``epoch``, ``model`` (state dict), ``optimizer``
    (state dict), ``best_val_loss`` and ``args`` (``vars(args)``). Parent
    directories are created as needed.

    The data is first written to a sibling ``<name>.tmp`` file and then renamed
    over ``path``. An interruption mid-save therefore cannot leave a truncated
    file in place of the previous checkpoint. If saving fails, the temporary
    file is removed and the error re-raised.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    checkpoint = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_val_loss": best_val_loss,
        "args": vars(args),
    }
    tmp_path = path.with_name(path.name + ".tmp")
    try:
        torch.save(checkpoint, tmp_path)
        # Path.replace uses os.replace semantics: the rename is atomic on POSIX
        # and overwrites an existing destination.
        tmp_path.replace(path)
    except BaseException:
        tmp_path.unlink(missing_ok=True)
        raise
def _build_dataloaders(args, device):
    """Create the train/validation AREDS datasets and their DataLoaders.

    Returns:
        ``(train_dataset, val_dataset, train_loader, val_loader)``; only the
        training loader shuffles.
    """
    train_dataset = AREDSDataset(
        args.train_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_train_transforms(args.image_size)),
    )
    val_dataset = AREDSDataset(
        args.valid_csv,
        args.image_root,
        args.task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )
    # Pinned host memory is requested only when training on CUDA.
    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=pin_memory,
    )
    return train_dataset, val_dataset, train_loader, val_loader


def main(args):
    """Train DeepSeeNet on one grading task and write checkpoints.

    Each epoch does the following:
    - train on ``args.train_csv`` and evaluate on ``args.valid_csv``;
    - print the metrics and, if ``args.wandb`` is set, log them;
    - save ``best.pt`` whenever validation loss improves;
    - optionally save ``epoch_NNN.pt`` every ``args.save_every`` epochs;
    - step the LR scheduler.

    ``last.pt`` is written after the final epoch. All checkpoints go under
    ``args.output_dir``.

    Args:
        args: Namespace as produced by ``parse_args``.
    """
    set_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Single source of truth for mixed precision: it decides whether a
    # GradScaler exists and is passed to both train and eval autocast.
    use_amp = args.amp and device.type == "cuda"

    train_dataset, val_dataset, train_loader, val_loader = _build_dataloaders(args, device)

    model = DeepSeeNet(
        n_classes=N_CLASSES[args.task],
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        freeze_backbone=args.freeze_backbone,
    ).to(device)

    class_weights = None
    if not args.no_class_weights:
        class_weights = get_class_weights(train_dataset, args.task, device)
    criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
    # Validation loss is always plain (unweighted) cross-entropy, regardless of
    # --no-class-weights; built once rather than on every epoch.
    val_criterion = torch.nn.CrossEntropyLoss()

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    scheduler = build_scheduler(optimizer, args)
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    wandb = None
    if args.wandb:
        import wandb  # optional dependency, imported only when logging is requested

        wandb.init(project=args.wandb_project, config=vars(args))

    output_dir = Path(args.output_dir)
    best_val_loss = float("inf")

    print(f"Device: {device}")
    print(f"Task: {args.task}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")
    if class_weights is not None:
        print(f"Class weights: {class_weights.detach().cpu().tolist()}")

    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
        train_loss, train_acc = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=use_amp,
            grad_clip=args.grad_clip,
        )
        val_loss, val_acc = evaluate(
            model=model,
            loader=val_loader,
            criterion=val_criterion,
            device=device,
            use_amp=use_amp,
        )
        # LR actually used during this epoch (read before scheduler.step()).
        lr = optimizer.param_groups[0]["lr"]
        print(
            f"train_loss={train_loss:.4f} "
            f"train_acc={train_acc:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_acc={val_acc:.4f} "
            f"lr={lr:.2e}"
        )
        if wandb is not None:
            wandb.log(
                {
                    "epoch": epoch,
                    "lr": lr,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                }
            )
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_checkpoint(
                output_dir / "best.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )
            print(f"Saved best checkpoint: val_loss={best_val_loss:.4f}")
        if args.save_every > 0 and epoch % args.save_every == 0:
            save_checkpoint(
                output_dir / f"epoch_{epoch:03d}.pt",
                model,
                optimizer,
                epoch,
                best_val_loss,
                args,
            )
        if scheduler is not None:
            scheduler.step()

    save_checkpoint(
        output_dir / "last.pt",
        model,
        optimizer,
        args.epochs,
        best_val_loss,
        args,
    )
    print("Training complete.")
    print(f"Best val loss: {best_val_loss:.4f}")
    if wandb is not None:
        # Explicitly mark the run finished and upload any remaining data.
        wandb.finish()
def parse_args(argv=None):
    """Parse training command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            allows programmatic use and testing without patching ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with one attribute per option (dashes become
        underscores, e.g. ``--train-csv`` -> ``train_csv``).
    """
    parser = argparse.ArgumentParser(description="Train DeepSeeNet.")
    parser.add_argument("--train-csv", required=True, help="CSV listing the training samples.")
    parser.add_argument("--valid-csv", required=True, help="CSV listing the validation samples.")
    parser.add_argument("--image-root", required=True, help="Root directory of the images.")
    # type=str.upper makes the task name case-insensitive on the command line.
    parser.add_argument(
        "--task",
        required=True,
        type=str.upper,
        choices=N_CLASSES,
        help="Grading task to train (case-insensitive).",
    )
    parser.add_argument(
        "--output-dir",
        default="checkpoints/deepseenet",
        help="Directory for best.pt, last.pt and periodic checkpoints.",
    )
    parser.add_argument("--backbone", default="inception_v3", help="Backbone name passed to DeepSeeNet.")
    parser.add_argument("--image-size", type=int, default=1024, help="Image size given to the transforms.")
    parser.add_argument("--epochs", type=int, default=20, help="Number of training epochs.")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for train and validation.")
    parser.add_argument("--num-workers", type=int, default=4, help="DataLoader worker processes.")
    parser.add_argument("--lr", type=float, default=1e-4, help="AdamW learning rate.")
    parser.add_argument("--weight-decay", type=float, default=1e-4, help="AdamW weight decay.")
    parser.add_argument("--no-pretrained", action="store_true", help="Do not request pretrained weights.")
    parser.add_argument(
        "--freeze-backbone",
        action="store_true",
        help="Passed to DeepSeeNet as freeze_backbone.",
    )
    parser.add_argument(
        "--no-class-weights",
        action="store_true",
        help="Use an unweighted training loss instead of inverse-frequency class weights.",
    )
    parser.add_argument(
        "--scheduler",
        choices=("none", "cosine", "step"),
        default="cosine",
        help="Per-epoch learning-rate schedule.",
    )
    parser.add_argument("--min-lr", type=float, default=1e-6, help="Final LR of the cosine schedule.")
    parser.add_argument("--step-size", type=int, default=5, help="Epochs between decays (step schedule).")
    parser.add_argument("--gamma", type=float, default=0.5, help="LR decay factor (step schedule).")
    parser.add_argument("--amp", action="store_true", help="Use CUDA mixed precision (ignored on CPU).")
    parser.add_argument("--grad-clip", type=float, default=0.0, help="Max gradient norm; 0 disables.")
    parser.add_argument(
        "--save-every",
        type=int,
        default=0,
        help="Also save epoch_NNN.pt every N epochs; 0 disables.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument("--wandb", action="store_true", help="Log metrics to Weights & Biases.")
    parser.add_argument("--wandb-project", default="deepseenet", help="Weights & Biases project name.")
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: parse CLI options from sys.argv and start training.
    main(parse_args())