DeepSeeNet / test.py
farrell236's picture
add src
b8c9192
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any, Callable
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
try:
from sklearn.metrics import (
accuracy_score,
cohen_kappa_score,
confusion_matrix,
recall_score,
roc_auc_score,
)
except ImportError as exc:
raise ImportError(
"This evaluation script needs scikit-learn. Install with: pip install scikit-learn"
) from exc
from augmentations import get_val_transforms
from dataloader import AREDSDataset
from model import DeepSeeNet
# Number of classes per task; also the set of accepted --task values.
N_CLASSES: dict[str, int] = {"ADVAMD": 2, "DRUS": 3, "PIG": 2}

# Class index treated as "positive" for the binary endpoint metrics when
# --positive-class is not given on the command line.
DEFAULT_POSITIVE_CLASS: dict[str, int] = {"ADVAMD": 1, "DRUS": 2, "PIG": 1}

# Endpoint label per task, used in the printed summary, the saved metadata
# and the endpoint column names of predictions.csv.
ENDPOINT_NAME: dict[str, str] = {
    "ADVAMD": "late_amd",
    "DRUS": "large_drusen",
    "PIG": "pigmentary_abnormality",
}
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is what the script entry point uses;
            passing a list allows programmatic use.

    Returns:
        The populated namespace. ``task`` is upper-cased before it is
        validated against the known task names.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a DeepSeeNet checkpoint on a test CSV.",
    )
    parser.add_argument("--test-csv", required=True, help="CSV file describing the test set.")
    parser.add_argument("--image-root", required=True, help="Image root directory passed to the dataset.")
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Checkpoint file whose 'model' entry holds the weights to evaluate.",
    )
    parser.add_argument(
        "--task",
        required=True,
        type=str.upper,
        choices=N_CLASSES,
        help="Classification task to evaluate (case-insensitive).",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help="Backbone name passed to the model (default: %(default)s).",
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=1024,
        help="Image size passed to the validation transforms (default: %(default)s).",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size of the test data loader (default: %(default)s).",
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        help="Worker processes of the test data loader (default: %(default)s).",
    )
    parser.add_argument(
        "--positive-class",
        type=int,
        default=None,
        help="Class index treated as positive for the binary endpoint metrics "
        "(default: task-specific).",
    )
    parser.add_argument(
        "--bootstrap-iters",
        type=int,
        default=2000,
        help="Number of bootstrap resamples for the confidence intervals; "
        "0 or less disables them (default: %(default)s).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Seed of the bootstrap random generator (default: %(default)s).",
    )
    parser.add_argument(
        "--bootstrap-unit-column",
        default=None,
        help="Test-CSV column whose values group rows into resampling units; "
        "rows are resampled individually when omitted.",
    )
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Directory for metrics.json, the confusion matrices and "
        "predictions.csv; nothing is saved when omitted.",
    )
    return parser.parse_args(argv)
class AlbumentationsTransform:
    """Adapter that lets a keyword-style transform be called with one image.

    The wrapped callable is invoked as ``transform(image=<ndarray>)`` and must
    return a mapping; the value stored under ``"image"`` is handed back.
    """

    def __init__(self, transform) -> None:
        self.transform = transform

    def __call__(self, image):
        # The wrapped transform receives an array, whatever the input type was.
        pixels = np.asarray(image)
        result = self.transform(image=pixels)
        return result["image"]
@torch.no_grad()
def collect_predictions(
    model: torch.nn.Module,
    loader: DataLoader,
    device: torch.device,
) -> dict[str, np.ndarray | float]:
    """Run ``model`` over ``loader`` and gather labels, logits and predictions.

    The model is switched to eval mode and evaluated without gradients.

    Args:
        model: Network to evaluate. If it returns a tuple or list, only the
            first element is treated as the logits.
        loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        A dict with:
            ``"loss"``: cross-entropy averaged over all samples,
            ``"labels"``: integer labels,
            ``"logits"``: raw model outputs,
            ``"probs"``: softmax of the logits over dim 1,
            ``"preds"``: argmax of ``probs`` over axis 1.

    Raises:
        ValueError: If ``loader`` yields no batches, since there is nothing
            to concatenate or score.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    all_labels: list[np.ndarray] = []
    all_logits: list[np.ndarray] = []

    for images, labels in tqdm(loader, desc="test"):
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        if isinstance(logits, (tuple, list)):
            logits = logits[0]

        loss = F.cross_entropy(logits, labels)
        batch_size = labels.size(0)
        # cross_entropy returns the batch mean; weight it by the batch size so
        # a smaller final batch does not skew the overall average.
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        all_labels.append(labels.detach().cpu().numpy())
        all_logits.append(logits.detach().cpu().numpy())

    if not all_labels:
        # np.concatenate([]) would fail with an unrelated-looking message.
        raise ValueError("The data loader produced no batches; the test set appears to be empty.")

    labels_np = np.concatenate(all_labels).astype(int)
    logits_np = np.concatenate(all_logits, axis=0)
    probs_np = torch.softmax(torch.from_numpy(logits_np), dim=1).numpy()
    preds_np = probs_np.argmax(axis=1).astype(int)

    return {
        "loss": float(total_loss / max(total_samples, 1)),
        "labels": labels_np,
        "logits": logits_np,
        "probs": probs_np,
        "preds": preds_np,
    }
def specificity_score(y_true_bin: np.ndarray, y_pred_bin: np.ndarray) -> float:
    """Return TN / (TN + FP) for 0/1 labels, or NaN when that sum is zero."""
    is_negative = y_true_bin == 0
    true_neg = np.sum(is_negative & (y_pred_bin == 0))
    false_pos = np.sum(is_negative & (y_pred_bin == 1))
    total = true_neg + false_pos
    if not total:
        return float("nan")
    return float(true_neg / total)
def safe_auc(y_true_bin: np.ndarray, y_score: np.ndarray) -> float:
    """Return the ROC AUC, or NaN when fewer than two distinct labels occur."""
    n_distinct = np.unique(y_true_bin).size
    if n_distinct >= 2:
        return float(roc_auc_score(y_true_bin, y_score))
    return float("nan")
def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    probs: np.ndarray,
    n_classes: int,
    positive_class: int,
) -> dict[str, float]:
    """Compute multi-class and binary-endpoint metrics for one set of predictions.

    The binary endpoint collapses labels and predictions to
    "is ``positive_class``" versus "is anything else"; its score is the
    ``positive_class`` column of ``probs``.

    Returns:
        A dict whose ``"loss"`` entry is a NaN placeholder (the loss is not
        derivable from these inputs), followed by the exact (multi-class)
        accuracy and kappa, the endpoint accuracy, sensitivity, specificity,
        kappa and AUC. ``"macro_ovr_auc"`` is added only when ``n_classes > 2``
        and ``y_true`` holds more than one distinct label; it is NaN if
        scikit-learn rejects the inputs with ``ValueError``.
    """
    true_is_pos = (y_true == positive_class).astype(int)
    pred_is_pos = (y_pred == positive_class).astype(int)
    endpoint_score = probs[:, positive_class]

    # Keys are inserted in a fixed order because it carries through to the
    # saved JSON.
    results: dict[str, float] = {"loss": float("nan")}
    results["exact_accuracy"] = float(accuracy_score(y_true, y_pred))
    results["exact_kappa"] = float(cohen_kappa_score(y_true, y_pred))
    results["overall_accuracy"] = float(accuracy_score(true_is_pos, pred_is_pos))
    results["sensitivity"] = float(
        recall_score(true_is_pos, pred_is_pos, pos_label=1, zero_division=0)
    )
    results["specificity"] = specificity_score(true_is_pos, pred_is_pos)
    results["kappa"] = float(cohen_kappa_score(true_is_pos, pred_is_pos))
    results["auc"] = safe_auc(true_is_pos, endpoint_score)

    if not n_classes > 2 or len(np.unique(y_true)) <= 1:
        return results

    try:
        macro_auc = float(
            roc_auc_score(
                y_true,
                probs,
                labels=list(range(n_classes)),
                multi_class="ovr",
                average="macro",
            )
        )
    except ValueError:
        macro_auc = float("nan")
    results["macro_ovr_auc"] = macro_auc
    return results
def make_bootstrap_indices(
    n: int,
    n_iters: int,
    rng: np.random.Generator,
    units: np.ndarray | None = None,
) -> list[np.ndarray]:
    """Draw row-index arrays for ``n_iters`` bootstrap resamples.

    Args:
        n: Number of rows in the data being resampled.
        n_iters: Number of resamples; zero or negative yields an empty list.
        rng: Random generator used for all draws.
        units: Optional per-row unit identifiers (one per row). When given,
            whole units are resampled with replacement and every row of a
            drawn unit is included, so a resample's length may differ from
            ``n``. When ``None``, individual rows are resampled.

    Returns:
        One integer index array per resample. With ``units``, the rows of each
        drawn unit appear in ascending row order, and units are numbered in
        order of first appearance.

    Raises:
        ValueError: If ``units`` does not have exactly ``n`` entries, or if it
            contains missing values (a missing identifier cannot be assigned
            to a unit).
    """
    if n_iters <= 0:
        return []
    if units is None:
        return [rng.integers(0, n, size=n) for _ in range(n_iters)]

    unit_values = np.asarray(units)
    if unit_values.shape[0] != n:
        raise ValueError(f"units has {unit_values.shape[0]} entries but n={n}")
    if pd.isna(unit_values).any():
        raise ValueError("units contains missing values; every row needs a unit identifier")

    # Group rows by unit in a single pass instead of scanning the whole array
    # once per unit. factorize numbers units by first appearance, and a stable
    # argsort keeps the rows of each unit in ascending order.
    codes, uniques = pd.factorize(unit_values)
    n_units = len(uniques)
    order = np.argsort(codes, kind="stable")
    counts = np.bincount(codes, minlength=n_units)
    rows_by_unit = np.split(order, np.cumsum(counts)[:-1])

    resamples = []
    for _ in range(n_iters):
        # Sample unit positions rather than unit values: no dict lookups keyed
        # by NumPy scalars are needed.
        drawn = rng.choice(n_units, size=n_units, replace=True)
        resamples.append(np.concatenate([rows_by_unit[u] for u in drawn]))
    return resamples
def bootstrap_ci(
    metric_fn: Callable[[np.ndarray], dict[str, float]],
    indices: list[np.ndarray],
) -> dict[str, dict[str, float]]:
    """Compute percentile confidence bounds for every metric ``metric_fn`` reports.

    ``metric_fn`` is called once per index array in ``indices``. For each
    metric name, the 2.5th and 97.5th percentiles of the collected values
    (NaNs ignored) are returned as ``"ci_low"`` and ``"ci_high"``. An empty
    ``indices`` gives an empty dict.
    """
    if not indices:
        return {}

    samples: dict[str, list[float]] = {}
    for resample in tqdm(indices, desc="bootstrap", leave=False):
        for name, val in metric_fn(resample).items():
            # A metric may be absent from some resamples; collect what exists.
            if name not in samples:
                samples[name] = []
            samples[name].append(val)

    bounds: dict[str, dict[str, float]] = {}
    for name in samples:
        draws = np.asarray(samples[name], dtype=float)
        lower = float(np.nanpercentile(draws, 2.5))
        upper = float(np.nanpercentile(draws, 97.5))
        bounds[name] = {"ci_low": lower, "ci_high": upper}
    return bounds
def combine_with_ci(metrics: dict[str, float], ci: dict[str, dict[str, float]]) -> dict[str, Any]:
    """Attach confidence bounds to point estimates.

    Every metric becomes ``{"value": <float>}``, extended with the entries of
    ``ci[name]`` when that metric has an interval. Names present only in
    ``ci`` are not carried over.
    """
    merged: dict[str, Any] = {}
    for name in metrics:
        entry = {"value": float(metrics[name])}
        entry.update(ci.get(name, {}))
        merged[name] = entry
    return merged
def print_metric_table(metrics_with_ci: dict[str, Any]) -> None:
    """Print the endpoint metrics and then the classifier metrics to stdout.

    Each row shows the value with four decimals, followed by the
    ``(low-high)`` interval when the entry has a ``"ci_low"`` key. All five
    endpoint metrics must be present; classifier metrics that are missing
    are skipped.
    """

    def _show(name: str) -> None:
        entry = metrics_with_ci[name]
        row = f"{name:20s} {entry['value']:.4f}"
        if "ci_low" in entry:
            row += f" ({entry['ci_low']:.4f}-{entry['ci_high']:.4f})"
        print(row)

    print("\nMetrics")
    print("-------")
    for name in ("overall_accuracy", "sensitivity", "specificity", "kappa", "auc"):
        _show(name)

    print("\nClassifier metrics")
    print("------------------")
    for name in ("loss", "exact_accuracy", "exact_kappa", "macro_ovr_auc"):
        if name in metrics_with_ci:
            _show(name)
def _require_unit_column(test_df: pd.DataFrame, column: str, csv_path: str) -> None:
    """Raise ``KeyError`` when ``column`` is not a column of the test CSV."""
    if column not in test_df.columns:
        raise KeyError(
            f"--bootstrap-unit-column {column!r} not found in {csv_path}. "
            f"Available columns: {list(test_df.columns)}"
        )


def _bootstrap_units(test_df: pd.DataFrame, column: str, n_samples: int) -> np.ndarray:
    """Return the per-row unit identifiers, checking the CSV matches the predictions."""
    if len(test_df) != n_samples:
        raise ValueError(
            "CSV length does not match dataset length. "
            f"CSV rows={len(test_df)}, dataset rows={n_samples}"
        )
    return test_df[column].to_numpy()


def _save_outputs(
    output_dir: Path,
    *,
    meta: dict[str, Any],
    metrics_with_ci: dict[str, Any],
    cm: np.ndarray,
    endpoint_cm: np.ndarray,
    test_df: pd.DataFrame,
    y_true: np.ndarray,
    y_pred: np.ndarray,
    probs: np.ndarray,
    n_classes: int,
    positive_class: int,
    endpoint: str,
) -> None:
    """Write metrics.json, both confusion matrices and predictions.csv to ``output_dir``."""
    output_dir.mkdir(parents=True, exist_ok=True)

    with (output_dir / "metrics.json").open("w") as f:
        json.dump({"meta": meta, "metrics": metrics_with_ci}, f, indent=2)

    pd.DataFrame(cm).to_csv(output_dir / "confusion_matrix.csv", index=False)
    pd.DataFrame(
        endpoint_cm,
        index=["true_neg", "true_pos"],
        columns=["pred_neg", "pred_pos"],
    ).to_csv(output_dir / "endpoint_confusion_matrix.csv")

    if len(test_df) == len(y_true):
        pred_df = test_df.copy()
    else:
        # The CSV columns cannot be lined up with the predictions; say so
        # instead of silently writing a file without them.
        print(
            f"\nWarning: CSV rows={len(test_df)} but predictions={len(y_true)}; "
            "predictions.csv will not include the CSV columns."
        )
        pred_df = pd.DataFrame(index=np.arange(len(y_true)))

    pred_df["y_true"] = y_true
    pred_df["y_pred"] = y_pred
    pred_df[f"y_true_{endpoint}"] = (y_true == positive_class).astype(int)
    pred_df[f"y_pred_{endpoint}"] = (y_pred == positive_class).astype(int)
    for c in range(n_classes):
        pred_df[f"prob_class_{c}"] = probs[:, c]
    pred_df.to_csv(output_dir / "predictions.csv", index=False)

    print(f"\nSaved outputs to: {output_dir}")


def main() -> None:
    """Evaluate a checkpoint on the test set and report metrics with bootstrap CIs.

    Steps: parse the options, build the dataset/loader/model, load the
    checkpoint's ``"model"`` state dict, collect predictions, compute the point
    metrics and their bootstrap intervals, print the summary and, when
    ``--output-dir`` is given, save the results.

    Raises:
        ValueError: If the positive class is out of range for the task, or the
            CSV row count differs from the number of predictions while
            ``--bootstrap-unit-column`` is in use.
        KeyError: If ``--bootstrap-unit-column`` is not a column of the CSV.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    task = args.task.upper()
    n_classes = N_CLASSES[task]
    endpoint = ENDPOINT_NAME[task]
    positive_class = DEFAULT_POSITIVE_CLASS[task] if args.positive_class is None else args.positive_class
    if not 0 <= positive_class < n_classes:
        raise ValueError(f"positive_class={positive_class} is invalid for task={task} with {n_classes} classes")

    dataset = AREDSDataset(
        args.test_csv,
        args.image_root,
        task,
        transform=AlbumentationsTransform(get_val_transforms(args.image_size)),
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    # Read the CSV once, and check the unit column before running the model so
    # a misspelled column name does not cost a full inference pass.
    test_df = None
    if args.bootstrap_unit_column or args.output_dir:
        test_df = pd.read_csv(args.test_csv)
    if args.bootstrap_unit_column:
        _require_unit_column(test_df, args.bootstrap_unit_column, args.test_csv)

    model = DeepSeeNet(
        n_classes=n_classes,
        backbone=args.backbone,
        pretrained=False,
    ).to(device)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only evaluate checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint["model"])

    pred_dict = collect_predictions(model, loader, device)
    y_true = pred_dict["labels"]
    y_pred = pred_dict["preds"]
    probs = pred_dict["probs"]

    metrics = compute_metrics(y_true, y_pred, probs, n_classes=n_classes, positive_class=positive_class)
    # compute_metrics only leaves a placeholder for the loss.
    metrics["loss"] = float(pred_dict["loss"])

    units = None
    if args.bootstrap_unit_column:
        units = _bootstrap_units(test_df, args.bootstrap_unit_column, len(y_true))

    rng = np.random.default_rng(args.seed)
    bs_indices = make_bootstrap_indices(
        n=len(y_true),
        n_iters=args.bootstrap_iters,
        rng=rng,
        units=units,
    )

    def metric_fn(idx: np.ndarray) -> dict[str, float]:
        out = compute_metrics(
            y_true[idx],
            y_pred[idx],
            probs[idx],
            n_classes=n_classes,
            positive_class=positive_class,
        )
        # The loss is not recomputed per resample, so it gets no interval.
        out.pop("loss", None)
        return out

    ci = bootstrap_ci(metric_fn, bs_indices)
    metrics_with_ci = combine_with_ci(metrics, ci)

    cm = confusion_matrix(y_true, y_pred, labels=list(range(n_classes)))
    endpoint_cm = confusion_matrix(
        (y_true == positive_class).astype(int),
        (y_pred == positive_class).astype(int),
        labels=[0, 1],
    )

    meta = {
        "task": task,
        "endpoint": endpoint,
        "positive_class": int(positive_class),
        "n_classes": int(n_classes),
        "n_samples": int(len(y_true)),
        "bootstrap_iters": int(args.bootstrap_iters),
        "bootstrap_unit_column": args.bootstrap_unit_column,
    }

    print(f"\nTask: {task} | endpoint: {endpoint} | positive_class={positive_class}")
    print_metric_table(metrics_with_ci)
    print("\nConfusion matrix (rows=true, cols=pred):")
    print(cm)
    print("\nBinary confusion matrix (rows=true, cols=pred):")
    print(endpoint_cm)

    if args.output_dir:
        _save_outputs(
            Path(args.output_dir),
            meta=meta,
            metrics_with_ci=metrics_with_ci,
            cm=cm,
            endpoint_cm=endpoint_cm,
            test_df=test_df,
            y_true=y_true,
            y_pred=y_pred,
            probs=probs,
            n_classes=n_classes,
            positive_class=positive_class,
            endpoint=endpoint,
        )


if __name__ == "__main__":
    main()