Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Callable | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| try: | |
| from sklearn.metrics import ( | |
| accuracy_score, | |
| cohen_kappa_score, | |
| confusion_matrix, | |
| recall_score, | |
| roc_auc_score, | |
| ) | |
| except ImportError as exc: | |
| raise ImportError( | |
| "This evaluation script needs scikit-learn. Install with: pip install scikit-learn" | |
| ) from exc | |
| from augmentations import get_val_transforms | |
| from dataloader import AREDSDataset | |
| from model import DeepSeeNet | |
# Number of output classes for each supported task. The keys double as the
# accepted values of the --task command-line option.
N_CLASSES: dict[str, int] = {"ADVAMD": 2, "DRUS": 3, "PIG": 2}

# Class index treated as "positive" when a task's labels are collapsed to a
# binary endpoint; used whenever --positive-class is not given.
DEFAULT_POSITIVE_CLASS: dict[str, int] = {"ADVAMD": 1, "DRUS": 2, "PIG": 1}

# Name of the binary endpoint for each task; it appears in the printed header,
# in the JSON metadata and in the prediction CSV column names.
ENDPOINT_NAME: dict[str, str] = {
    "ADVAMD": "late_amd",
    "DRUS": "large_drusen",
    "PIG": "pigmentary_abnormality",
}
def parse_args() -> argparse.Namespace:
    """Parse the evaluation script's command-line options.

    Returns:
        The populated namespace. ``task`` is upper-cased by argparse and must
        be one of the keys of ``N_CLASSES``.
    """
    cli = argparse.ArgumentParser()

    # Mandatory string options.
    for required_flag in ("--test-csv", "--image-root", "--checkpoint"):
        cli.add_argument(required_flag, required=True)
    cli.add_argument("--task", required=True, type=str.upper, choices=N_CLASSES)

    cli.add_argument("--backbone", default="inception_v3")

    # Integer options paired with their defaults (None means "not set").
    int_options = (
        ("--image-size", 1024),
        ("--batch-size", 32),
        ("--num-workers", 16),
        ("--positive-class", None),
        ("--bootstrap-iters", 2000),
        ("--seed", 123),
    )
    for flag, default in int_options:
        cli.add_argument(flag, type=int, default=default)

    # Optional string options that are off unless supplied.
    for optional_flag in ("--bootstrap-unit-column", "--output-dir"):
        cli.add_argument(optional_flag, default=None)

    return cli.parse_args()
class AlbumentationsTransform:
    """Adapter that lets a keyword-style transform be called with one image.

    The wrapped callable is invoked as ``transform(image=<ndarray>)`` and is
    expected to return a mapping; the value under ``"image"`` is returned.
    """

    def __init__(self, transform) -> None:
        self.transform = transform

    def __call__(self, image):
        # The wrapped transform receives a NumPy array, so convert first.
        as_array = np.asarray(image)
        result = self.transform(image=as_array)
        return result["image"]
def collect_predictions(model: torch.nn.Module, loader: DataLoader, device: torch.device) -> dict[str, np.ndarray | float]:
    """Run ``model`` over every batch of ``loader`` and gather its outputs.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A dict with:
            ``loss``   - cross-entropy averaged over all samples,
            ``labels`` - integer labels, concatenated over batches,
            ``logits`` - raw model outputs, concatenated over batches,
            ``probs``  - softmax of the logits along axis 1,
            ``preds``  - argmax of ``probs`` along axis 1.

    Raises:
        ValueError: if the loader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    all_labels: list[np.ndarray] = []
    all_logits: list[np.ndarray] = []

    # Evaluation never backpropagates. Without no_grad every forward pass
    # records an autograd graph and keeps activations alive, which only wastes
    # memory and time.
    with torch.no_grad():
        for images, labels in tqdm(loader, desc="test"):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            # Some models return several outputs; the first one is used as logits.
            if isinstance(logits, (tuple, list)):
                logits = logits[0]

            loss = F.cross_entropy(logits, labels)
            batch_size = labels.size(0)
            # cross_entropy returns the batch mean, so weight it by the batch
            # size to get a correct per-sample average over uneven batches.
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            all_labels.append(labels.detach().cpu().numpy())
            all_logits.append(logits.detach().cpu().numpy())

    if not all_labels:
        # np.concatenate would fail on an empty list with an opaque message.
        raise ValueError("loader produced no batches; nothing to evaluate")

    labels_np = np.concatenate(all_labels).astype(int)
    logits_np = np.concatenate(all_logits, axis=0)
    probs_np = torch.softmax(torch.from_numpy(logits_np), dim=1).numpy()
    preds_np = probs_np.argmax(axis=1).astype(int)

    return {
        "loss": float(total_loss / max(total_samples, 1)),
        "labels": labels_np,
        "logits": logits_np,
        "probs": probs_np,
        "preds": preds_np,
    }
def specificity_score(y_true_bin: np.ndarray, y_pred_bin: np.ndarray) -> float:
    """Return TN / (TN + FP) for binary 0/1 label arrays.

    Returns NaN when there are no true negatives or false positives to count,
    i.e. when the ratio is undefined.
    """
    actual_negative = y_true_bin == 0
    true_neg = np.count_nonzero(actual_negative & (y_pred_bin == 0))
    false_pos = np.count_nonzero(actual_negative & (y_pred_bin == 1))

    total = true_neg + false_pos
    if not total:
        return float("nan")
    return float(true_neg / total)
def safe_auc(y_true_bin: np.ndarray, y_score: np.ndarray) -> float:
    """Return the ROC AUC, or NaN when ``y_true_bin`` has fewer than two distinct values."""
    has_both_classes = np.unique(y_true_bin).size >= 2
    if has_both_classes:
        return float(roc_auc_score(y_true_bin, y_score))
    return float("nan")
def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    probs: np.ndarray,
    n_classes: int,
    positive_class: int,
) -> dict[str, float]:
    """Compute multi-class and positive-vs-rest metrics for one set of predictions.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        probs: Per-class scores; column ``positive_class`` is the AUC score.
        n_classes: Total number of classes of the task.
        positive_class: Class index that defines the binary endpoint.

    Returns:
        A dict with ``loss`` (always NaN here; a placeholder for the caller),
        ``exact_accuracy`` / ``exact_kappa`` on the original labels, and
        ``overall_accuracy``, ``sensitivity``, ``specificity``, ``kappa`` and
        ``auc`` on the binarized labels. When ``n_classes > 2`` and ``y_true``
        holds more than one class, ``macro_ovr_auc`` is added (NaN if
        scikit-learn raises ``ValueError``).
    """
    # Collapse to positive-vs-rest for the endpoint metrics.
    true_is_pos = (y_true == positive_class).astype(int)
    pred_is_pos = (y_pred == positive_class).astype(int)
    positive_scores = probs[:, positive_class]

    results: dict[str, float] = {"loss": float("nan")}
    results["exact_accuracy"] = float(accuracy_score(y_true, y_pred))
    results["exact_kappa"] = float(cohen_kappa_score(y_true, y_pred))
    results["overall_accuracy"] = float(accuracy_score(true_is_pos, pred_is_pos))
    results["sensitivity"] = float(recall_score(true_is_pos, pred_is_pos, pos_label=1, zero_division=0))
    results["specificity"] = specificity_score(true_is_pos, pred_is_pos)
    results["kappa"] = float(cohen_kappa_score(true_is_pos, pred_is_pos))
    results["auc"] = safe_auc(true_is_pos, positive_scores)

    # The macro one-vs-rest AUC only applies to genuinely multi-class tasks
    # whose ground truth contains more than one class.
    if n_classes <= 2 or len(np.unique(y_true)) <= 1:
        return results

    try:
        results["macro_ovr_auc"] = float(
            roc_auc_score(y_true, probs, labels=list(range(n_classes)), multi_class="ovr", average="macro")
        )
    except ValueError:
        results["macro_ovr_auc"] = float("nan")
    return results
def make_bootstrap_indices(
    n: int,
    n_iters: int,
    rng: np.random.Generator,
    units: np.ndarray | None = None,
) -> list[np.ndarray]:
    """Draw bootstrap resamples as arrays of row indices.

    Without ``units`` each resample is ``n`` row indices drawn uniformly with
    replacement. With ``units`` whole units are resampled with replacement
    (as many draws as there are distinct units) and every draw contributes all
    rows of that unit, so rows sharing a unit id are kept together.

    Args:
        n: Number of rows in the data being resampled.
        n_iters: Number of resamples; ``<= 0`` yields an empty list.
        rng: Random generator used for all draws.
        units: Optional per-row unit id (one entry per row).

    Returns:
        A list of ``n_iters`` index arrays.

    Raises:
        ValueError: if ``units`` does not have ``n`` entries or contains
            missing values.
    """
    if n_iters <= 0:
        return []
    if units is None:
        return [rng.integers(0, n, size=n) for _ in range(n_iters)]

    units = np.asarray(units)
    if len(units) != n:
        raise ValueError(f"units has {len(units)} entries but n={n}")
    # A missing id cannot be matched back to its rows (NaN != NaN), so such
    # rows would either be dropped silently or trigger a lookup error later.
    if pd.isna(units).any():
        raise ValueError("units contains missing values; every row needs a unit id")

    # Encode units as 0..U-1 in order of first appearance, then group the row
    # indices once instead of scanning all rows separately for every unit.
    codes, uniques = pd.factorize(units)
    n_units = len(uniques)
    order = np.argsort(codes, kind="stable")  # stable: rows stay ascending within a unit
    boundaries = np.cumsum(np.bincount(codes, minlength=n_units))[:-1]
    rows_by_code = np.split(order, boundaries)

    resamples = []
    for _ in range(n_iters):
        sampled_codes = rng.choice(n_units, size=n_units, replace=True)
        resamples.append(np.concatenate([rows_by_code[code] for code in sampled_codes]))
    return resamples
def bootstrap_ci(
    metric_fn: Callable[[np.ndarray], dict[str, float]],
    indices: list[np.ndarray],
) -> dict[str, dict[str, float]]:
    """Turn bootstrap resamples into 2.5th-97.5th percentile intervals per metric.

    Args:
        metric_fn: Maps one index array to a ``{metric name: value}`` dict.
        indices: One index array per bootstrap resample.

    Returns:
        ``{metric name: {"ci_low": ..., "ci_high": ...}}``; NaN samples are
        ignored when taking percentiles. Empty when ``indices`` is empty.
    """
    if not indices:
        return {}

    # Gather every resample's value under its metric name.
    samples: dict[str, list[float]] = {}
    for resample in tqdm(indices, desc="bootstrap", leave=False):
        for name, val in metric_fn(resample).items():
            if name not in samples:
                samples[name] = []
            samples[name].append(val)

    def _bounds(values: list[float]) -> dict[str, float]:
        arr = np.asarray(values, dtype=float)
        return {
            "ci_low": float(np.nanpercentile(arr, 2.5)),
            "ci_high": float(np.nanpercentile(arr, 97.5)),
        }

    return {name: _bounds(values) for name, values in samples.items()}
def combine_with_ci(metrics: dict[str, float], ci: dict[str, dict[str, float]]) -> dict[str, Any]:
    """Merge point estimates with their intervals.

    Each metric becomes ``{"value": <float>}``, extended with the entries of
    ``ci[name]`` when an interval exists for that metric.
    """
    return {
        name: {"value": float(estimate), **ci.get(name, {})}
        for name, estimate in metrics.items()
    }
def print_metric_table(metrics_with_ci: dict[str, Any]) -> None:
    """Print the endpoint metrics followed by the classifier metrics.

    The five endpoint metrics must all be present (a missing one raises
    ``KeyError``); classifier metrics that are absent are skipped. A metric
    that carries ``ci_low`` is printed with its interval in parentheses.
    """

    def _row(name: str, item: dict[str, float]) -> str:
        if "ci_low" in item:
            return f"{name:20s} {item['value']:.4f} ({item['ci_low']:.4f}-{item['ci_high']:.4f})"
        return f"{name:20s} {item['value']:.4f}"

    print("\nMetrics")
    print("-------")
    for name in ("overall_accuracy", "sensitivity", "specificity", "kappa", "auc"):
        print(_row(name, metrics_with_ci[name]))

    print("\nClassifier metrics")
    print("------------------")
    for name in ("loss", "exact_accuracy", "exact_kappa", "macro_ovr_auc"):
        if name in metrics_with_ci:
            print(_row(name, metrics_with_ci[name]))
def _load_bootstrap_units(csv_path: str, column: str) -> np.ndarray:
    """Return the values of ``column`` from the CSV at ``csv_path``.

    Raises:
        KeyError: if the CSV has no such column.
    """
    frame = pd.read_csv(csv_path)
    if column not in frame.columns:
        raise KeyError(
            f"--bootstrap-unit-column {column!r} not found in {csv_path}. "
            f"Available columns: {list(frame.columns)}"
        )
    return frame[column].to_numpy()


def _save_outputs(
    output_dir: Path,
    test_csv: str,
    meta: dict[str, Any],
    metrics_with_ci: dict[str, Any],
    cm: np.ndarray,
    endpoint_cm: np.ndarray,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    probs: np.ndarray,
    n_classes: int,
    positive_class: int,
    endpoint: str,
) -> None:
    """Write metrics.json, both confusion matrices and predictions.csv to ``output_dir``."""
    output_dir.mkdir(parents=True, exist_ok=True)

    with (output_dir / "metrics.json").open("w") as f:
        json.dump({"meta": meta, "metrics": metrics_with_ci}, f, indent=2)

    pd.DataFrame(cm).to_csv(output_dir / "confusion_matrix.csv", index=False)
    pd.DataFrame(endpoint_cm, index=["true_neg", "true_pos"], columns=["pred_neg", "pred_pos"]).to_csv(
        output_dir / "endpoint_confusion_matrix.csv"
    )

    # Predictions are appended to the test CSV's own columns when the row
    # counts agree; otherwise only the prediction columns are written.
    pred_df = pd.read_csv(test_csv)
    if len(pred_df) != len(y_true):
        pred_df = pd.DataFrame(index=np.arange(len(y_true)))
    pred_df["y_true"] = y_true
    pred_df["y_pred"] = y_pred
    pred_df[f"y_true_{endpoint}"] = (y_true == positive_class).astype(int)
    pred_df[f"y_pred_{endpoint}"] = (y_pred == positive_class).astype(int)
    for c in range(n_classes):
        pred_df[f"prob_class_{c}"] = probs[:, c]
    pred_df.to_csv(output_dir / "predictions.csv", index=False)


def main() -> None:
    """Evaluate a checkpoint on the test CSV and report metrics with bootstrap intervals.

    Steps: parse options, validate them, run the model over the test set,
    compute point metrics, bootstrap percentile intervals (optionally
    resampling whole units given by --bootstrap-unit-column), print the
    results and, when --output-dir is set, save them to disk.

    Raises:
        ValueError: if the positive class is out of range for the task, or the
            unit column's length differs from the number of predictions.
        KeyError: if --bootstrap-unit-column is not a column of the test CSV.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    task = args.task.upper()
    n_classes = N_CLASSES[task]
    positive_class = DEFAULT_POSITIVE_CLASS[task] if args.positive_class is None else args.positive_class
    if not 0 <= positive_class < n_classes:
        raise ValueError(f"positive_class={positive_class} is invalid for task={task} with {n_classes} classes")
    endpoint = ENDPOINT_NAME[task]

    # Validate the unit column before the expensive inference pass so that a
    # mistyped column name fails immediately rather than after the model ran.
    units = None
    if args.bootstrap_unit_column:
        units = _load_bootstrap_units(args.test_csv, args.bootstrap_unit_column)

    dataset = AREDSDataset(
        args.test_csv,
        args.image_root,
        task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep predictions in dataset order
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    model = DeepSeeNet(
        n_classes=n_classes,
        backbone=args.backbone,
        pretrained=False,
    ).to(device)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model"])

    pred_dict = collect_predictions(model, loader, device)
    y_true = pred_dict["labels"]
    y_pred = pred_dict["preds"]
    probs = pred_dict["probs"]

    metrics = compute_metrics(y_true, y_pred, probs, n_classes=n_classes, positive_class=positive_class)
    # compute_metrics leaves "loss" as a placeholder; fill in the measured value.
    metrics["loss"] = float(pred_dict["loss"])

    # Units are matched to predictions by position, so the counts must agree.
    if units is not None and len(units) != len(y_true):
        raise ValueError(
            "CSV length does not match dataset length. "
            f"CSV rows={len(units)}, dataset rows={len(y_true)}"
        )

    rng = np.random.default_rng(args.seed)
    bs_indices = make_bootstrap_indices(
        n=len(y_true),
        n_iters=args.bootstrap_iters,
        rng=rng,
        units=units,
    )

    def metric_fn(idx: np.ndarray) -> dict[str, float]:
        out = compute_metrics(
            y_true[idx],
            y_pred[idx],
            probs[idx],
            n_classes=n_classes,
            positive_class=positive_class,
        )
        # The loss is not recomputed per resample, so it gets no interval.
        out.pop("loss", None)
        return out

    ci = bootstrap_ci(metric_fn, bs_indices)
    metrics_with_ci = combine_with_ci(metrics, ci)

    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    endpoint_cm = confusion_matrix(
        (y_true == positive_class).astype(int),
        (y_pred == positive_class).astype(int),
        labels=[0, 1],
    )

    meta = {
        "task": task,
        "endpoint": endpoint,
        "positive_class": int(positive_class),
        "n_classes": int(n_classes),
        "n_samples": int(len(y_true)),
        "bootstrap_iters": int(args.bootstrap_iters),
        "bootstrap_unit_column": args.bootstrap_unit_column,
    }

    print(f"\nTask: {task} | endpoint: {endpoint} | positive_class={positive_class}")
    print_metric_table(metrics_with_ci)
    print("\nConfusion matrix (rows=true, cols=pred):")
    print(cm)
    print("\nBinary confusion matrix (rows=true, cols=pred):")
    print(endpoint_cm)

    if args.output_dir:
        output_dir = Path(args.output_dir)
        _save_outputs(
            output_dir,
            args.test_csv,
            meta,
            metrics_with_ci,
            cm,
            endpoint_cm,
            y_true,
            y_pred,
            probs,
            n_classes,
            positive_class,
            endpoint,
        )
        print(f"\nSaved outputs to: {output_dir}")


if __name__ == "__main__":
    main()