Spaces:
Sleeping
Sleeping
| """PyTorch datasets and dataloaders for AREDS fundus images.""" | |
| from pathlib import Path | |
| from typing import Callable, Optional, Tuple, Union | |
| import pandas as pd | |
| import torch | |
| from PIL import Image | |
| from torch import Tensor | |
| from torch.utils.data import Dataset | |
| from torchvision import transforms | |
# Label columns a dataset can be built for. AREDSDataset upper-cases the
# requested task and checks it against this tuple.
TASKS = ("ADVAMD", "DRUS", "PIG")

# Side length (pixels) of the square network input.
_INPUT_SIZE = 224
# Per-channel RGB mean/std applied after ToTensor. These are the standard
# ImageNet statistics; presumably chosen because the backbone is
# ImageNet-pretrained -- confirm against the model code.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)

# Default preprocessing: resize so the shorter edge is _INPUT_SIZE, take the
# central _INPUT_SIZE x _INPUT_SIZE crop, convert the PIL image to a float
# tensor, then normalize each channel.
DEFAULT_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.CenterCrop(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
class AREDSDataset(Dataset):
    """Map-style dataset of AREDS fundus images with one integer label each.

    Each CSV row is one sample. The ``pathname`` column holds the image path,
    resolved relative to ``image_root``. The column named by ``task`` holds
    the label.

    Args:
        csv_path: CSV file listing the samples.
        image_root: Directory that ``pathname`` entries are joined onto.
        task: Label column to use. Matched case-insensitively and must be one
            of ``TASKS``.
        transform: Callable applied to each RGB ``PIL.Image.Image``. When
            ``None``, ``DEFAULT_TRANSFORM`` is used.

    Raises:
        ValueError: If ``task`` is not one of ``TASKS``, or if the CSV lacks
            the ``pathname`` column or the selected task column.
    """

    def __init__(
        self,
        csv_path: Union[str, Path],
        image_root: Union[str, Path],
        task: str,
        transform: Optional[Callable[[Image.Image], Tensor]] = None,
    ) -> None:
        task = task.upper()
        if task not in TASKS:
            raise ValueError(f"task must be one of {TASKS}")
        self.image_root = Path(image_root)
        self.task = task
        # Compare against None explicitly: `transform or DEFAULT_TRANSFORM`
        # would also discard a user-supplied callable that happens to be
        # falsy (e.g. one defining __len__ that returns 0).
        self.transform = DEFAULT_TRANSFORM if transform is None else transform
        self.data = pd.read_csv(csv_path)
        # Fail fast on a malformed CSV rather than raising an opaque
        # AttributeError/KeyError on the first __getitem__ (possibly inside a
        # DataLoader worker).
        missing = [
            column
            for column in ("pathname", task)
            if column not in self.data.columns
        ]
        if missing:
            raise ValueError(
                f"{csv_path} is missing required column(s): {missing}"
            )

    def __len__(self) -> int:
        """Return the number of rows (samples) in the CSV."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Load sample ``index`` and return ``(image, label)``.

        The image at ``image_root / pathname`` is converted to RGB and passed
        through ``self.transform``. The label is the value in the task column,
        cast to ``int`` and wrapped in a 0-d ``torch.long`` tensor.
        """
        row = self.data.iloc[index]
        image_path = self.image_root / row["pathname"]
        # The context manager guarantees the file handle is released. Pillow
        # only auto-closes single-frame images after a successful load, so
        # multi-frame files, or a decode/convert that raises, would otherwise
        # keep the descriptor open until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image)
        label = torch.tensor(int(row[self.task]), dtype=torch.long)
        return image, label