"""DeepSeeNet model definition."""

from torch import Tensor, nn

try:
    import timm
except ImportError:  # pragma: no cover - handled when timm is absent.
    timm = None


class DeepSeeNet(nn.Module):
    """DeepSeeNet risk-factor classifier in PyTorch.

    A timm backbone (created with ``num_classes=0`` and average pooling, so it
    emits pooled features) followed by a small MLP classifier head.

    Args:
        n_classes: Number of output classes.
        backbone: Any timm model name that supports ``num_classes=0``.
            The default uses InceptionV3.
        pretrained: Load ImageNet weights for the backbone.
        dropout: Dropout probability used by the classifier head.
        freeze_backbone: If true, keep the backbone frozen and train only the
            classifier head. A frozen backbone has its parameters excluded
            from gradient computation and is kept in eval mode, including
            after ``model.train()``. This keeps normalization layers from
            updating their running statistics.

    Raises:
        ValueError: If ``n_classes`` is less than 1.
        ImportError: If ``timm`` is not installed.
    """

    def __init__(
        self,
        n_classes: int = 2,
        backbone: str = "inception_v3",
        pretrained: bool = True,
        dropout: float = 0.5,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        if n_classes < 1:
            raise ValueError("n_classes must be positive")
        if timm is None:
            raise ImportError("timm is required to build DeepSeeNet")

        self.backbone_name = backbone
        self.freeze_backbone = freeze_backbone
        # num_classes=0 strips timm's own classification layer so the backbone
        # returns pooled features of width `num_features`.
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        in_features = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),
        )

        if freeze_backbone:
            self.backbone.requires_grad_(False)
            # Modules start in train mode. Without this, BatchNorm/Dropout in
            # the "frozen" backbone would still behave as in training and
            # update their running buffers.
            self.backbone.eval()

    def train(self, mode: bool = True) -> "DeepSeeNet":
        """Set train/eval mode, keeping a frozen backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # getattr keeps instances created without the attribute (e.g. via
        # __new__ or older pickles) behaving like an unfrozen model.
        if getattr(self, "freeze_backbone", False):
            self.backbone.eval()
        return self

    def forward(self, x: Tensor) -> Tensor:
        """Return classifier logits for the input batch.

        Args:
            x: Input batch passed straight to the backbone. Presumably an
                image tensor shaped (N, C, H, W) at the resolution the chosen
                backbone expects — TODO confirm against callers.

        Returns:
            The classifier head's output. For pooled (N, num_features)
            backbone features this is (N, n_classes) unnormalized logits.
        """
        features = self.backbone(x)
        return self.classifier(features)