"""Classifier operations (``recognite.model.classifier_ops``): locate, replace and split off the final fully-connected layer of a model."""

from typing import List

from torch import nn
from torch.nn import Sequential, Linear


def update_classifier(
    model: nn.Module,
    num_classes: int,
    bias: bool = False
) -> nn.Module:
    """Updates the classifier according to the given number of classes.

    The classifier (fully connected layer) contained in the given model is
    updated in-place such that it outputs ``num_classes`` elements. This
    function supports all models that are available from
    ``torchvision.models``.

    The new layer keeps the ``in_features`` of the layer it replaces. It is
    created on the same device and with the same dtype as that layer, so a
    model can be updated after it has been moved to another device or
    converted to another precision.

    If ``model`` itself is an ``nn.Linear`` without children, there is no
    parent module in which it could be replaced in-place. In that case the
    new ``nn.Linear`` is returned instead of ``model``.

    Args:
        model: The model to update. We assume that the last layer of the
            model is a classifier. This can be a single ``nn.Linear`` (like
            ResNet), or an ``nn.Sequential`` with an ``nn.Linear`` as final
            layer (like AlexNet).
        num_classes: The new number of classes for the classifier.
        bias: If ``True``, use a bias in the updated classifier. Else, don't.

    Returns:
        The updated model, or the new classifier when ``model`` is itself
        the classifier.

    Raises:
        ValueError: If no final fully-connected layer can be found.
    """
    clf_path = get_path_to_ultimate_classifier(model)
    clf_module = get_module_at_path(model, clf_path)
    if not isinstance(clf_module, Linear):
        raise ValueError('Cannot find a final fully-connected layer in '
                         'the model. Please use a different model.')

    new_clf_module = Linear(
        in_features=clf_module.in_features,
        out_features=num_classes,
        bias=bias,
        # Match the replaced layer; otherwise a model that already lives on
        # a GPU or in reduced precision gets a CPU/float32 classifier.
        device=clf_module.weight.device,
        dtype=clf_module.weight.dtype,
    )

    if not clf_path:
        # An empty path means ``model`` is the classifier itself: it has no
        # parent, so it cannot be swapped out in-place.
        return new_clf_module

    set_module_at_path(model, clf_path, new_clf_module)
    return model
def split_backbone_classifier(
    model: nn.Module,
):
    """Splits the given model into a backbone and a classifier module.

    Note that ``model`` is modified in-place. Its final fully-connected
    layer is replaced by ``nn.Identity()``, and ``model`` itself is returned
    as the backbone. Keep a copy (e.g. via ``copy.deepcopy``) if the
    unmodified model is still needed.

    If ``model`` is itself an ``nn.Linear`` without children, it is left
    untouched. It is returned as the classifier, together with a fresh
    ``nn.Identity()`` as the backbone.

    Args:
        model: The model to split. We assume that the last layer of the
            model is a classifier. This can be a single ``nn.Linear`` (like
            ResNet), or an ``nn.Sequential`` with an ``nn.Linear`` as final
            layer (like AlexNet).

    Returns:
        A ``(backbone, classifier)`` tuple. ``backbone`` is ``model`` with
        its final fully-connected layer replaced by ``nn.Identity()``.
        ``classifier`` is the layer that was removed.

    Raises:
        ValueError: If no final fully-connected layer can be found.
    """
    clf_path = get_path_to_ultimate_classifier(model)
    classifier = get_module_at_path(model, clf_path)

    if not clf_path:
        # ``model`` is the classifier itself and has no parent to patch;
        # indexing ``clf_path[-1]`` downstream would raise IndexError.
        return nn.Identity(), classifier

    backbone = model
    set_module_at_path(backbone, clf_path, nn.Identity())
    return backbone, classifier
def get_path_to_ultimate_classifier(
    model: nn.Module,
) -> List[str]:
    """Returns the path to the ultimate fully-connected layer.

    The path is returned as a list of strings, where each element is the
    attribute to get from the parent module to retrieve the corresponding
    module. To get the module from this path, use
    ``get_module_at_path(model, path)``. To change this module with another
    module, use ``set_module_at_path(module, path, new_module)``.

    Only the last direct child of ``model`` is inspected. If that child is
    an ``nn.Sequential``, only its own last child is inspected as well. A
    model without children is inspected itself, so a bare ``nn.Linear``
    yields the empty path ``[]``.

    Args:
        model: The model. We assume that the last layer of the model is a
            classifier. This can be a single ``nn.Linear`` (like ResNet) or
            an ``nn.Sequential`` with an ``nn.Linear`` as final layer (like
            AlexNet).

    Returns:
        The list of attribute names leading from ``model`` to the final
        ``nn.Linear`` layer.

    Raises:
        ValueError: If no final fully-connected layer can be found. This
            includes the case of an empty ``nn.Sequential``.
    """
    named_children = list(model.named_children())
    if named_children:
        ult_name, ult_layer = named_children[-1]
        path_to_clf_layer = [ult_name]
    else:
        # A model without children may itself be the classifier.
        ult_layer = model
        path_to_clf_layer = []

    if isinstance(ult_layer, Linear):
        return path_to_clf_layer

    if isinstance(ult_layer, Sequential):
        sub_children = list(ult_layer.named_children())
        # Guard against an empty nn.Sequential: indexing ``[-1]`` would
        # otherwise raise IndexError instead of the documented ValueError.
        if sub_children:
            ult_subname, ult_sublayer = sub_children[-1]
            if isinstance(ult_sublayer, Linear):
                return path_to_clf_layer + [ult_subname]

    raise ValueError('Cannot find a final fully-connected layer in '
                     'the model. Please use a different model.')
def get_ultimate_classifier(
    model: nn.Module,
):
    """Looks up and returns the final fully-connected layer of ``model``.

    This is a shorthand for resolving the path from
    ``get_path_to_ultimate_classifier`` with ``get_module_at_path``.

    Args:
        model: The model. We assume that the last layer of the model is a
            classifier. This can be a single ``nn.Linear`` (like ResNet),
            or an ``nn.Sequential`` with an ``nn.Linear`` as final layer
            (like AlexNet).

    Returns:
        The final fully-connected layer. The module is returned itself, not
        a copy.
    """
    return get_module_at_path(model, get_path_to_ultimate_classifier(model))
def get_module_at_path(
    model: nn.Module,
    path: List[str]
) -> nn.Module:
    """Resolves a (possibly nested) submodule of ``model``.

    Starting from ``model``, each entry of ``path`` is looked up with
    ``getattr`` on the module reached so far. An empty path therefore
    yields ``model`` itself.

    Args:
        model: The root module to start the lookup from.
        path: Attribute names to follow, outermost first.

    Returns:
        The module reached after following every entry of ``path``.
    """
    current = model
    for attr_name in path:
        current = getattr(current, attr_name)
    return current
def set_module_at_path(
    model: nn.Module,
    path: List[str],
    new_module: nn.Module
):
    """Replaces the submodule at ``path`` inside ``model`` with ``new_module``.

    All entries of ``path`` except the last one locate the parent module.
    The last entry is the attribute of that parent that gets overwritten.
    ``path`` must be non-empty; an empty path raises ``IndexError`` because
    the root module itself cannot be replaced this way.

    Args:
        model: The model to modify in-place.
        path: List of strings, where each element is the attribute to get
            from the parent module to retrieve the corresponding module.
        new_module: The module to put at the given path.
    """
    attr_name = path[-1]
    parent_path = path[:-1]
    owner = get_module_at_path(model, parent_path)
    setattr(owner, attr_name, new_module)