DeepSeeNet / model.py
farrell236's picture
add src
b8c9192
"""DeepSeeNet model definition."""
from torch import Tensor, nn
try:
import timm
except ImportError: # pragma: no cover - handled when timm is absent.
timm = None
class DeepSeeNet(nn.Module):
    """DeepSeeNet risk-factor classifier in PyTorch.

    A timm backbone, created without its own classification layer, feeds
    its output into a small fully connected classifier head.

    Args:
        n_classes: Number of output classes. Must be at least 1.
        backbone: Any timm model name that supports ``num_classes=0``. The
            default uses InceptionV3.
        pretrained: Load ImageNet weights for the backbone.
        dropout: Dropout probability used by the classifier head. Must lie
            in ``[0, 1]``.
        freeze_backbone: If true, keep the backbone frozen and train only the
            classifier head. The backbone parameters stop requiring
            gradients and the backbone is held in evaluation mode, so
            mode-dependent layers inside it (e.g. batch normalisation) do
            not change while the head is being trained.

    Raises:
        ValueError: If ``n_classes`` is smaller than 1 or ``dropout`` is
            outside ``[0, 1]``.
        ImportError: If timm is not installed.
    """

    def __init__(
        self,
        n_classes: int = 2,
        backbone: str = "inception_v3",
        pretrained: bool = True,
        dropout: float = 0.5,
        freeze_backbone: bool = False,
    ) -> None:
        super().__init__()
        # Validate cheap arguments first so a bad value fails before the
        # backbone is built (which may download pretrained weights).
        if n_classes < 1:
            raise ValueError("n_classes must be positive")
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {dropout}"
            )
        if timm is None:
            raise ImportError("timm is required to build DeepSeeNet")

        self.backbone_name = backbone
        self.freeze_backbone = freeze_backbone

        # num_classes=0 drops the backbone's own classifier so it returns
        # pooled features instead of logits.
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        in_features = self.backbone.num_features

        self.classifier = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes),
        )

        if freeze_backbone:
            self.backbone.requires_grad_(False)
            # Disabling gradients alone is not enough: layers such as batch
            # normalisation still update their running statistics while in
            # training mode, so the backbone is switched to eval mode too.
            self.backbone.eval()

    def train(self, mode: bool = True) -> "DeepSeeNet":
        """Set the training mode, keeping a frozen backbone in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            The module itself, matching ``nn.Module.train``.
        """
        super().train(mode)
        if self.freeze_backbone:
            # ``super().train`` recursively flipped the backbone as well;
            # undo that so the frozen backbone never leaves eval mode.
            self.backbone.eval()
        return self

    def forward(self, x: Tensor) -> Tensor:
        """Run the backbone on ``x`` and return the classifier head output."""
        features = self.backbone(x)
        return self.classifier(features)