Spaces:
Running
Running
| """DeepSeeNet model definition.""" | |
| from torch import Tensor, nn | |
| try: | |
| import timm | |
| except ImportError: # pragma: no cover - handled when timm is absent. | |
| timm = None | |
class DeepSeeNet(nn.Module):
    """DeepSeeNet risk-factor classifier in PyTorch.

    A timm backbone, created with its classification head removed
    (``num_classes=0``) and global average pooling, produces a feature
    vector. That vector is fed to a small MLP head:
    ``Linear(num_features, 256) -> ReLU -> Dropout -> Linear(256, 128) ->
    ReLU -> Dropout -> Linear(128, n_classes)``.

    Args:
        n_classes: Number of output classes.
        backbone: Any timm model name that supports ``num_classes=0``. The
            default uses InceptionV3.
        pretrained: Load ImageNet weights for the backbone.
        dropout: Dropout probability used by the classifier head.
        freeze_backbone: If true, keep the backbone frozen and train only the
            classifier head. Frozen backbone parameters get
            ``requires_grad=False``. The backbone is also kept in eval mode,
            even after ``model.train()``, so normalization-layer running
            statistics are not updated while the head trains.

    Raises:
        ValueError: If ``n_classes`` is less than 1.
        ImportError: If ``timm`` is not installed.
    """

    def __init__(
        self,
        n_classes: int = 2,
        backbone: str = "inception_v3",
        pretrained: bool = True,
        dropout: float = 0.5,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        if n_classes < 1:
            raise ValueError("n_classes must be positive")
        if timm is None:
            raise ImportError("timm is required to build DeepSeeNet")
        self.backbone_name = backbone
        # Remembered so ``train()`` can keep a frozen backbone in eval mode.
        self.freeze_backbone = freeze_backbone
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,  # drop timm's classifier; we attach our own head
            global_pool="avg",
        )
        in_features = self.backbone.num_features
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),
        )
        if freeze_backbone:
            self.backbone.requires_grad_(False)
            # requires_grad alone does not stop BatchNorm running-stat
            # updates; eval mode does.
            self.backbone.eval()

    def train(self, mode: bool = True) -> "DeepSeeNet":
        """Set training mode, keeping a frozen backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for eval mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        # getattr guards instances created without ``__init__`` (e.g. via
        # ``__new__``) that never set the flag.
        if getattr(self, "freeze_backbone", False):
            self.backbone.eval()
        return self

    def forward(self, x: Tensor) -> Tensor:
        """Return classifier logits for the input batch ``x``.

        ``x`` is passed straight to the timm backbone. Its expected shape
        depends on the chosen backbone (presumably ``(batch, C, H, W)``
        images — TODO confirm input size for the configured backbone).
        """
        features = self.backbone(x)
        return self.classifier(features)