DeepSeeNet / dataloader.py
farrell236's picture
add src
b8c9192
"""PyTorch datasets and dataloaders for AREDS fundus images."""
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
import pandas as pd
import torch
from PIL import Image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
# Label columns a dataset may be built for; task names are matched against
# this tuple after upper-casing.
TASKS = ("ADVAMD", "DRUS", "PIG")

# Side length (pixels) of the square image produced by the default pipeline.
_INPUT_SIZE = 224

# Per-channel normalisation statistics. The values match the ImageNet
# statistics conventionally used with torchvision's pretrained models.
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)

# Preprocessing applied when the caller does not supply a transform:
# resize the shorter side, centre-crop to a square, convert to a tensor,
# then normalise each channel.
DEFAULT_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.CenterCrop(_INPUT_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
class AREDSDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs listed in a CSV file.

    Each CSV row must provide a ``pathname`` column (image path relative to
    ``image_root``) and a column named after the selected task that holds an
    integer-convertible label.
    """

    def __init__(
        self,
        csv_path: Union[str, Path],
        image_root: Union[str, Path],
        task: str,
        transform: Optional[Callable[[Image.Image], Tensor]] = None,
    ) -> None:
        """Load the CSV index and remember how to read and label each image.

        Args:
            csv_path: CSV file with one row per image.
            image_root: Directory that the ``pathname`` entries are relative to.
            task: Name of the label column, one of ``TASKS`` (case-insensitive).
            transform: Callable applied to the RGB PIL image. When ``None``,
                ``DEFAULT_TRANSFORM`` is used.

        Raises:
            ValueError: If ``task`` is not in ``TASKS`` or the CSV lacks the
                ``pathname`` column or the task column.
        """
        task = task.upper()
        if task not in TASKS:
            raise ValueError(f"task must be one of {TASKS}")
        self.image_root = Path(image_root)
        self.task = task
        # Compare against None explicitly: a user-supplied callable may be
        # falsy (e.g. an empty container module) and must not be swapped out.
        self.transform = DEFAULT_TRANSFORM if transform is None else transform
        self.data = pd.read_csv(csv_path)

        # Fail fast on a malformed CSV rather than on the first item access,
        # which typically happens inside a DataLoader worker.
        missing = [
            column
            for column in ("pathname", task)
            if column not in self.data.columns
        ]
        if missing:
            raise ValueError(
                f"{csv_path} is missing required column(s): {missing}"
            )

    def __len__(self) -> int:
        """Return the number of rows in the CSV index."""
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
        """Return the transformed image and its label for row ``index``.

        Args:
            index: Positional row index into the CSV.

        Returns:
            A tuple ``(image, label)``: the output of ``self.transform`` on
            the RGB image, and the task label as a scalar ``torch.long`` tensor.
        """
        row = self.data.iloc[index]
        image_path = self.image_root / row["pathname"]
        # The context manager releases the file handle deterministically;
        # convert() returns a fully loaded copy that outlives the handle.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transform(image)
        label = torch.tensor(int(row[self.task]), dtype=torch.long)
        return image, label