deepextractor.utils.checkpoints
===============================

.. py:module:: deepextractor.utils.checkpoints


Module Contents
---------------

.. py:data:: logger

.. py:data:: HF_REPO_ID
   :value: 'tomdooney/deepextractor'

.. py:data:: CHECKPOINT_BILBY
   :value: 'checkpoint_best_bilby_noise_base.pth.tar'

.. py:data:: CHECKPOINT_REAL
   :value: 'checkpoint_best_real_noise_base.pth.tar'

.. py:function:: save_checkpoint(state, filename='my_checkpoint.pth.tar')

   Save model and optimizer state as a checkpoint.

.. py:function:: load_checkpoint(checkpoint, model)

   Load model state from a checkpoint.

.. py:function:: load_optimizer(checkpoint, optimizer)

   Load optimizer state from a checkpoint.

.. py:function:: load_torch_model(model_name, model_dict, checkpoint_dir=None, device='cpu', checkpoint_filename=CHECKPOINT_BILBY)

   Load a pretrained PyTorch model.

   Weights are resolved in this order:

   1. ``//`` (local)
   2. Hugging Face Hub (``tomdooney/deepextractor``), downloaded and cached automatically.

   Use the module-level constants ``CHECKPOINT_BILBY`` and ``CHECKPOINT_REAL``
   for ``checkpoint_filename`` to select the correct variant.

   :param model_name: Name of the model (must be a key in ``model_dict``).
   :type model_name: str
   :param model_dict: Mapping of model names to instantiated model objects.
   :type model_dict: dict
   :param checkpoint_dir: Local directory to search first. Pass ``None`` to skip the local lookup and always use the Hugging Face Hub.
   :type checkpoint_dir: str | None
   :param device: Device to load the model onto.
   :type device: str | torch.device
   :param checkpoint_filename: Checkpoint file name inside the model subdirectory.
   :type checkpoint_filename: str
   :returns: The model with loaded weights, in eval mode, or ``None`` on failure.
   :rtype: torch.nn.Module | None
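
A minimal sketch of choosing the ``checkpoint_filename`` variant for ``load_torch_model`` by name. The ``checkpoint_for`` helper and its ``"bilby"``/``"real"`` labels are hypothetical and not part of ``deepextractor``; the two filename values are copied from ``CHECKPOINT_BILBY`` and ``CHECKPOINT_REAL`` so the snippet runs on its own.

.. code-block:: python

   # Values copied from deepextractor.utils.checkpoints so this sketch is self-contained.
   CHECKPOINT_BILBY = "checkpoint_best_bilby_noise_base.pth.tar"
   CHECKPOINT_REAL = "checkpoint_best_real_noise_base.pth.tar"

   # Hypothetical label -> filename table; not part of the deepextractor API.
   _VARIANTS = {"bilby": CHECKPOINT_BILBY, "real": CHECKPOINT_REAL}


   def checkpoint_for(noise: str) -> str:
       """Return the checkpoint filename for a noise label ("bilby" or "real")."""
       key = noise.strip().lower()
       if key not in _VARIANTS:
           raise ValueError(
               f"unknown noise variant {noise!r}; expected one of {sorted(_VARIANTS)}"
           )
       return _VARIANTS[key]

The result would be passed as ``checkpoint_filename=checkpoint_for("real")``. Raising on an unknown label, rather than falling back, makes a typo visible instead of quietly loading the default ``CHECKPOINT_BILBY`` variant.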
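
The two-step weight lookup described for ``load_torch_model`` (local first, then the Hugging Face Hub) can be sketched as below. This is an illustration under assumptions, not the library's code: ``resolve_checkpoint_path`` is a hypothetical name, the local layout ``<checkpoint_dir>/<model_name>/<checkpoint_filename>`` is assumed from the "model subdirectory" wording, and ``download`` is an injected callable standing in for the Hub fetch.

.. code-block:: python

   import os
   from typing import Callable, Optional


   def resolve_checkpoint_path(
       model_name: str,
       checkpoint_filename: str,
       checkpoint_dir: Optional[str],
       download: Callable[[str, str], str],
   ) -> str:
       """Return a checkpoint path, preferring a local file over a Hub download.

       Hypothetical helper: the local layout
       <checkpoint_dir>/<model_name>/<checkpoint_filename> is an assumption.
       """
       # Step 1: local lookup, skipped entirely when checkpoint_dir is None.
       if checkpoint_dir is not None:
           local_path = os.path.join(checkpoint_dir, model_name, checkpoint_filename)
           if os.path.isfile(local_path):
               return local_path
       # Step 2: fall back to the Hub; `download` stands in for the real fetch
       # and is expected to return the path of the (cached) downloaded file.
       return download(model_name, checkpoint_filename)

Injecting ``download`` keeps the sketch testable offline. A real implementation could wrap ``huggingface_hub.hf_hub_download(repo_id=HF_REPO_ID, filename=...)``, which caches files locally so repeated calls do not re-download; the file path used inside the Hub repository is not documented on this page.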
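
The docstrings for ``save_checkpoint``, ``load_checkpoint`` and ``load_optimizer`` do not spell out the checkpoint layout. A common convention, assumed here and not confirmed by this page, is a dict with ``"state_dict"`` and ``"optimizer"`` keys written with ``torch.save``. The round trip below uses only PyTorch, with an in-memory buffer in place of ``filename`` and a toy ``nn.Linear`` model standing in for a real network.

.. code-block:: python

   import io

   import torch
   from torch import nn

   torch.manual_seed(0)

   # Toy model and optimizer standing in for a DeepExtractor network.
   model = nn.Linear(4, 2)
   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

   # Assumed layout: the "state_dict" / "optimizer" keys are a convention, not confirmed here.
   state = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}

   buffer = io.BytesIO()  # save_checkpoint would write to `filename` instead
   torch.save(state, buffer)
   buffer.seek(0)
   checkpoint = torch.load(buffer, map_location="cpu")

   # Roughly what load_checkpoint / load_optimizer would do with that layout.
   restored = nn.Linear(4, 2)
   restored.load_state_dict(checkpoint["state_dict"])
   restored_optimizer = torch.optim.Adam(restored.parameters(), lr=1e-3)
   restored_optimizer.load_state_dict(checkpoint["optimizer"])
   restored.eval()  # eval mode, matching what load_torch_model returns

Keeping the optimizer state next to the weights is what allows training to resume with the same optimizer statistics; for inference only the weights entry is needed.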