"""Evaluate a DeepSeeNet checkpoint on a test CSV and report endpoint metrics.

The script runs the model over the test set, collapses the predicted class to
a binary endpoint (``class == positive_class``), and reports accuracy,
sensitivity, specificity, Cohen's kappa and AUC together with percentile
bootstrap confidence intervals.  Results are printed and, when
``--output-dir`` is given, written to disk.
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Callable

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from sklearn.metrics import (
        accuracy_score,
        cohen_kappa_score,
        confusion_matrix,
        recall_score,
        roc_auc_score,
    )
except ImportError as exc:
    raise ImportError(
        "This evaluation script needs scikit-learn. Install with: pip install scikit-learn"
    ) from exc

from augmentations import get_val_transforms
from dataloader import AREDSDataset
from model import DeepSeeNet

# Number of output classes of the classifier for each task.
N_CLASSES = {
    "ADVAMD": 2,
    "DRUS": 3,
    "PIG": 2,
}

# Class index treated as "positive" when a task is collapsed to a binary
# endpoint, unless overridden with --positive-class.
DEFAULT_POSITIVE_CLASS = {
    "ADVAMD": 1,
    "DRUS": 2,
    "PIG": 1,
}

# Endpoint label used in printed output, metadata and prediction column names.
ENDPOINT_NAME = {
    "ADVAMD": "late_amd",
    "DRUS": "large_drusen",
    "PIG": "pigmentary_abnormality",
}


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    ``--task`` is upper-cased by argparse before being validated against the
    keys of ``N_CLASSES``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test-csv", required=True)
    parser.add_argument("--image-root", required=True)
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--task", required=True, type=str.upper, choices=N_CLASSES)
    parser.add_argument("--backbone", default="inception_v3")
    parser.add_argument("--image-size", type=int, default=1024)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--num-workers", type=int, default=16)
    parser.add_argument("--positive-class", type=int, default=None)
    parser.add_argument("--bootstrap-iters", type=int, default=2000)
    parser.add_argument("--seed", type=int, default=123)
    parser.add_argument("--bootstrap-unit-column", default=None)
    parser.add_argument("--output-dir", default=None)
    return parser.parse_args()


class AlbumentationsTransform:
    """Adapt a keyword-style image transform to a plain ``f(image)`` callable.

    The wrapped transform is invoked as ``transform(image=<ndarray>)`` and is
    expected to return a mapping whose ``"image"`` entry is the result.
    """

    def __init__(self, transform: Callable[..., Any]) -> None:
        self.transform = transform

    def __call__(self, image: Any) -> Any:
        # The input is converted with np.asarray before the transform sees it.
        return self.transform(image=np.asarray(image))["image"]


@torch.no_grad()
def collect_predictions(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> dict[str, np.ndarray | float]:
    """Run ``model`` over ``loader`` and gather labels, logits and predictions.

    Args:
        model: Classifier; put into eval mode here.  If it returns a tuple or
            list, the first element is taken as the logits.
        loader: Yields ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A dict with ``"loss"`` (sample-weighted mean cross-entropy),
        ``"labels"``, ``"logits"``, ``"probs"`` (softmax over dim 1) and
        ``"preds"`` (argmax of the probabilities).

    Raises:
        ValueError: If the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    all_labels: list[np.ndarray] = []
    all_logits: list[np.ndarray] = []

    for images, labels in tqdm(loader, desc="test"):
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        if isinstance(logits, (tuple, list)):
            logits = logits[0]
        loss = F.cross_entropy(logits, labels)
        batch_size = labels.size(0)
        # cross_entropy averages over the batch; re-weight by batch size so
        # a smaller final batch does not skew the overall mean.
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        all_labels.append(labels.detach().cpu().numpy())
        all_logits.append(logits.detach().cpu().numpy())

    if not all_labels:
        # np.concatenate([]) would otherwise fail with an unhelpful message.
        raise ValueError("The test loader produced no batches; there is nothing to evaluate.")

    labels_np = np.concatenate(all_labels).astype(int)
    logits_np = np.concatenate(all_logits, axis=0)
    probs_np = torch.softmax(torch.from_numpy(logits_np), dim=1).numpy()
    preds_np = probs_np.argmax(axis=1).astype(int)
    return {
        "loss": float(total_loss / max(total_samples, 1)),
        "labels": labels_np,
        "logits": logits_np,
        "probs": probs_np,
        "preds": preds_np,
    }


def specificity_score(y_true_bin: np.ndarray, y_pred_bin: np.ndarray) -> float:
    """Return TN / (TN + FP) for 0/1 arrays, or NaN when there are no negatives."""
    tn = np.sum((y_true_bin == 0) & (y_pred_bin == 0))
    fp = np.sum((y_true_bin == 0) & (y_pred_bin == 1))
    denom = tn + fp
    return float(tn / denom) if denom else float("nan")


def safe_auc(y_true_bin: np.ndarray, y_score: np.ndarray) -> float:
    """Return the ROC AUC, or NaN when ``y_true_bin`` holds fewer than two classes."""
    if len(np.unique(y_true_bin)) < 2:
        return float("nan")
    return float(roc_auc_score(y_true_bin, y_score))


def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    probs: np.ndarray,
    n_classes: int,
    positive_class: int,
) -> dict[str, float]:
    """Compute multi-class and binary-endpoint metrics for one set of predictions.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        probs: Per-class probabilities; column ``positive_class`` is used as
            the endpoint score.
        n_classes: Number of classes of the classifier.
        positive_class: Class index that defines the binary endpoint.

    Returns:
        A dict with ``exact_*`` metrics on the original classes and
        ``overall_accuracy``, ``sensitivity``, ``specificity``, ``kappa`` and
        ``auc`` on the binarised endpoint.  ``loss`` is a NaN placeholder for
        the caller to fill in.  ``macro_ovr_auc`` is only present when
        ``n_classes > 2`` and ``y_true`` contains more than one class.
    """
    y_true_bin = (y_true == positive_class).astype(int)
    y_pred_bin = (y_pred == positive_class).astype(int)
    pos_score = probs[:, positive_class]

    metrics = {
        "loss": float("nan"),
        "exact_accuracy": float(accuracy_score(y_true, y_pred)),
        "exact_kappa": float(cohen_kappa_score(y_true, y_pred)),
        "overall_accuracy": float(accuracy_score(y_true_bin, y_pred_bin)),
        "sensitivity": float(recall_score(y_true_bin, y_pred_bin, pos_label=1, zero_division=0)),
        "specificity": specificity_score(y_true_bin, y_pred_bin),
        "kappa": float(cohen_kappa_score(y_true_bin, y_pred_bin)),
        "auc": safe_auc(y_true_bin, pos_score),
    }

    if n_classes > 2 and len(np.unique(y_true)) > 1:
        try:
            metrics["macro_ovr_auc"] = float(
                roc_auc_score(y_true, probs, labels=list(range(n_classes)), multi_class="ovr", average="macro")
            )
        except ValueError:
            # roc_auc_score rejected this sample; report the metric as undefined.
            metrics["macro_ovr_auc"] = float("nan")
    return metrics


def make_bootstrap_indices(
    n: int,
    n_iters: int,
    rng: np.random.Generator,
    units: np.ndarray | None = None,
) -> list[np.ndarray]:
    """Draw row-index arrays for ``n_iters`` bootstrap resamples.

    Args:
        n: Number of rows in the evaluated sample.
        n_iters: Number of resamples; values ``<= 0`` yield an empty list.
        rng: Random generator used for all draws.
        units: Optional per-row unit identifier.  When given, whole units are
            resampled with replacement and every row of a drawn unit is
            included, so resamples may differ in length.

    Returns:
        One integer index array per resample.

    Raises:
        ValueError: If ``units`` does not have ``n`` entries or contains
            missing values.
    """
    if n_iters <= 0:
        return []
    if units is None:
        return [rng.integers(0, n, size=n) for _ in range(n_iters)]

    units = np.asarray(units)
    if len(units) != n:
        raise ValueError(f"units has {len(units)} entries but n={n}")

    # Encode each row's unit as an integer code (codes follow order of first
    # appearance); missing identifiers are encoded as -1.
    codes, uniques = pd.factorize(units)
    if (codes < 0).any():
        raise ValueError(
            "units contains missing values; every row needs a bootstrap unit identifier"
        )
    n_units = len(uniques)

    # Group row indices per unit in a single pass instead of scanning the
    # whole array once per unit.  The stable sort keeps rows of a unit in
    # ascending order.
    order = np.argsort(codes, kind="stable")
    counts = np.bincount(codes, minlength=n_units)
    rows_by_unit = np.split(order, np.cumsum(counts)[:-1])

    out = []
    for _ in range(n_iters):
        # Sample unit positions rather than unit values, so the lookup never
        # depends on hashing or comparing the identifiers themselves.
        sampled_units = rng.choice(n_units, size=n_units, replace=True)
        out.append(np.concatenate([rows_by_unit[u] for u in sampled_units]))
    return out


def bootstrap_ci(
    metric_fn: Callable[[np.ndarray], dict[str, float]],
    indices: list[np.ndarray],
) -> dict[str, dict[str, float]]:
    """Compute 95% percentile bootstrap intervals for every metric.

    Args:
        metric_fn: Maps a row-index array to a ``{metric: value}`` dict.
        indices: Bootstrap resamples; an empty list yields an empty result.

    Returns:
        ``{metric: {"ci_low": ..., "ci_high": ...}}`` using the 2.5th and
        97.5th percentiles, ignoring NaN values.  A metric that is NaN in
        every resample gets NaN bounds.
    """
    if not indices:
        return {}

    values_by_metric: dict[str, list[float]] = {}
    for idx in tqdm(indices, desc="bootstrap", leave=False):
        vals = metric_fn(idx)
        for key, value in vals.items():
            values_by_metric.setdefault(key, []).append(value)

    intervals: dict[str, dict[str, float]] = {}
    for key, values in values_by_metric.items():
        arr = np.asarray(values, dtype=float)
        if np.isnan(arr).all():
            # Same result nanpercentile would give, without its
            # "All-NaN slice" RuntimeWarning.
            ci_low = ci_high = float("nan")
        else:
            ci_low = float(np.nanpercentile(arr, 2.5))
            ci_high = float(np.nanpercentile(arr, 97.5))
        intervals[key] = {"ci_low": ci_low, "ci_high": ci_high}
    return intervals


def combine_with_ci(metrics: dict[str, float], ci: dict[str, dict[str, float]]) -> dict[str, Any]:
    """Merge point estimates and intervals into ``{metric: {"value", "ci_low", "ci_high"}}``.

    Metrics without an entry in ``ci`` only carry ``"value"``.
    """
    out: dict[str, Any] = {}
    for key, value in metrics.items():
        out[key] = {"value": float(value)}
        if key in ci:
            out[key].update(ci[key])
    return out


def _format_metric_row(key: str, item: dict[str, float]) -> str:
    """Format one table row, appending the interval when ``item`` has one."""
    if "ci_low" in item:
        return f"{key:20s} {item['value']:.4f} ({item['ci_low']:.4f}-{item['ci_high']:.4f})"
    return f"{key:20s} {item['value']:.4f}"


def print_metric_table(metrics_with_ci: dict[str, Any]) -> None:
    """Print the endpoint metrics followed by the classifier-level metrics.

    The five endpoint metrics are required (a missing one raises ``KeyError``);
    classifier metrics that are absent are skipped.
    """
    print("\nMetrics")
    print("-------")
    for key in ["overall_accuracy", "sensitivity", "specificity", "kappa", "auc"]:
        print(_format_metric_row(key, metrics_with_ci[key]))

    print("\nClassifier metrics")
    print("------------------")
    for key in ["loss", "exact_accuracy", "exact_kappa", "macro_ovr_auc"]:
        if key not in metrics_with_ci:
            continue
        print(_format_metric_row(key, metrics_with_ci[key]))


def main() -> None:
    """Evaluate the checkpoint, print the metrics and optionally save them.

    Raises:
        ValueError: If the positive class is out of range for the task, or
            the CSV length differs from the number of evaluated samples when
            a bootstrap unit column is used.
        KeyError: If ``--bootstrap-unit-column`` is not a column of the CSV.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    task = args.task.upper()
    n_classes = N_CLASSES[task]
    positive_class = DEFAULT_POSITIVE_CLASS[task] if args.positive_class is None else args.positive_class
    if not 0 <= positive_class < n_classes:
        raise ValueError(f"positive_class={positive_class} is invalid for task={task} with {n_classes} classes")

    # Read the CSV once, and only when something needs it: the unit column
    # for the bootstrap and/or the per-row predictions file.
    test_df = None
    if args.bootstrap_unit_column or args.output_dir:
        test_df = pd.read_csv(args.test_csv)
    # Validate the unit column before running inference so that a typo does
    # not cost a full pass over the test set.
    if args.bootstrap_unit_column and args.bootstrap_unit_column not in test_df.columns:
        raise KeyError(
            f"--bootstrap-unit-column {args.bootstrap_unit_column!r} not found in {args.test_csv}. "
            f"Available columns: {list(test_df.columns)}"
        )

    dataset = AREDSDataset(
        args.test_csv,
        args.image_root,
        task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    model = DeepSeeNet(
        n_classes=n_classes,
        backbone=args.backbone,
        pretrained=False,
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model"])

    pred_dict = collect_predictions(model, loader, device)
    y_true = pred_dict["labels"]
    y_pred = pred_dict["preds"]
    probs = pred_dict["probs"]

    metrics = compute_metrics(y_true, y_pred, probs, n_classes=n_classes, positive_class=positive_class)
    metrics["loss"] = float(pred_dict["loss"])

    units = None
    if args.bootstrap_unit_column:
        # Units are matched to predictions by row position, so the CSV must
        # have exactly one row per evaluated sample.
        # NOTE(review): this also assumes the dataset yields samples in CSV
        # row order - confirm against AREDSDataset.
        if len(test_df) != len(y_true):
            raise ValueError(
                "CSV length does not match dataset length. "
                f"CSV rows={len(test_df)}, dataset rows={len(y_true)}"
            )
        units = test_df[args.bootstrap_unit_column].to_numpy()

    rng = np.random.default_rng(args.seed)
    bs_indices = make_bootstrap_indices(
        n=len(y_true),
        n_iters=args.bootstrap_iters,
        rng=rng,
        units=units,
    )

    def metric_fn(idx: np.ndarray) -> dict[str, float]:
        out = compute_metrics(
            y_true[idx],
            y_pred[idx],
            probs[idx],
            n_classes=n_classes,
            positive_class=positive_class,
        )
        # The loss is only a placeholder in compute_metrics; no interval is
        # reported for it.
        out.pop("loss", None)
        return out

    ci = bootstrap_ci(metric_fn, bs_indices)
    metrics_with_ci = combine_with_ci(metrics, ci)

    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    endpoint_cm = confusion_matrix(
        (y_true == positive_class).astype(int),
        (y_pred == positive_class).astype(int),
        labels=[0, 1],
    )

    meta = {
        "task": task,
        "endpoint": ENDPOINT_NAME[task],
        "positive_class": int(positive_class),
        "n_classes": int(n_classes),
        "n_samples": int(len(y_true)),
        "bootstrap_iters": int(args.bootstrap_iters),
        "bootstrap_unit_column": args.bootstrap_unit_column,
    }

    print(f"\nTask: {task} | endpoint: {ENDPOINT_NAME[task]} | positive_class={positive_class}")
    print_metric_table(metrics_with_ci)
    print("\nConfusion matrix (rows=true, cols=pred):")
    print(cm)
    print("\nBinary confusion matrix (rows=true, cols=pred):")
    print(endpoint_cm)

    if args.output_dir:
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        # NOTE: undefined metrics are NaN, which json.dump writes as the
        # non-standard token ``NaN``; strict JSON parsers may reject the file.
        with (output_dir / "metrics.json").open("w") as f:
            json.dump({"meta": meta, "metrics": metrics_with_ci}, f, indent=2)

        pd.DataFrame(cm).to_csv(output_dir / "confusion_matrix.csv", index=False)
        pd.DataFrame(endpoint_cm, index=["true_neg", "true_pos"], columns=["pred_neg", "pred_pos"]).to_csv(
            output_dir / "endpoint_confusion_matrix.csv"
        )

        # Keep the CSV columns next to the predictions when rows line up
        # one-to-one; otherwise fall back to a bare positional index.
        if len(test_df) == len(y_true):
            pred_df = test_df.copy()
        else:
            pred_df = pd.DataFrame(index=np.arange(len(y_true)))
        pred_df["y_true"] = y_true
        pred_df["y_pred"] = y_pred
        pred_df[f"y_true_{ENDPOINT_NAME[task]}"] = (y_true == positive_class).astype(int)
        pred_df[f"y_pred_{ENDPOINT_NAME[task]}"] = (y_pred == positive_class).astype(int)
        for c in range(n_classes):
            pred_df[f"prob_class_{c}"] = probs[:, c]
        pred_df.to_csv(output_dir / "predictions.csv", index=False)

        print(f"\nSaved outputs to: {output_dir}")


if __name__ == "__main__":
    main()