"""PyTorch datasets and dataloaders for AREDS fundus images.""" from pathlib import Path from typing import Callable, Optional, Tuple, Union import pandas as pd import torch from PIL import Image from torch import Tensor from torch.utils.data import Dataset from torchvision import transforms TASKS = ("ADVAMD", "DRUS", "PIG") DEFAULT_TRANSFORM = transforms.Compose( [ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), ), ] ) class AREDSDataset(Dataset): def __init__( self, csv_path: Union[str, Path], image_root: Union[str, Path], task: str, transform: Optional[Callable[[Image.Image], Tensor]] = None, ) -> None: task = task.upper() if task not in TASKS: raise ValueError(f"task must be one of {TASKS}") self.image_root = Path(image_root) self.task = task self.transform = transform or DEFAULT_TRANSFORM self.data = pd.read_csv(csv_path) def __len__(self) -> int: return len(self.data) def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: row = self.data.iloc[index] image_path = self.image_root / row.pathname image = Image.open(image_path).convert("RGB") image = self.transform(image) label = torch.tensor(int(row[self.task]), dtype=torch.long) return image, label