deepextractor.utils.checkpoints

Module Contents

deepextractor.utils.checkpoints.logger
deepextractor.utils.checkpoints.HF_REPO_ID = 'tomdooney/deepextractor'
deepextractor.utils.checkpoints.CHECKPOINT_BILBY = 'checkpoint_best_bilby_noise_base.pth.tar'
deepextractor.utils.checkpoints.CHECKPOINT_REAL = 'checkpoint_best_real_noise_base.pth.tar'
deepextractor.utils.checkpoints.save_checkpoint(state, filename='my_checkpoint.pth.tar')

Save model and optimizer state as a checkpoint.

deepextractor.utils.checkpoints.load_checkpoint(checkpoint, model)

Load model state from a checkpoint.

deepextractor.utils.checkpoints.load_optimizer(checkpoint, optimizer)

Load optimizer state from a checkpoint.
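The three helpers above look like the common PyTorch pattern of writing a dictionary with torch.save and restoring it with load_state_dict. The following is a minimal round-trip sketch under that assumption, not the module's own code: the key names "state_dict" and "optimizer" are guesses about the checkpoint layout, and the function bodies are hypothetical stand-ins that only reuse the documented names and signatures.

```python
import os
import tempfile

import torch
from torch import nn


# Hypothetical stand-ins for the documented helpers. The "state_dict" and
# "optimizer" keys are ASSUMPTIONS about the checkpoint layout.
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize a dict of model/optimizer state to disk."""
    torch.save(state, filename)


def load_checkpoint(checkpoint, model):
    """Copy saved weights into an already-constructed model."""
    model.load_state_dict(checkpoint["state_dict"])


def load_optimizer(checkpoint, optimizer):
    """Restore optimizer buffers (e.g. Adam moment estimates)."""
    optimizer.load_state_dict(checkpoint["optimizer"])


# Round trip: take one training step, save, then restore into fresh objects.
model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "my_checkpoint.pth.tar")
save_checkpoint(
    {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
    filename=path,
)

restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.Adam(restored_model.parameters(), lr=1e-3)
checkpoint = torch.load(path, map_location="cpu")  # load once, reuse below
load_checkpoint(checkpoint, restored_model)
load_optimizer(checkpoint, restored_optimizer)
```

If, as the parameter name suggests, checkpoint is the already-loaded dictionary rather than a file path, a single torch.load call can feed both loaders.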

deepextractor.utils.checkpoints.load_torch_model(model_name, model_dict, checkpoint_dir=None, device='cpu', checkpoint_filename=CHECKPOINT_BILBY)

Load a pretrained PyTorch model.

Weights are resolved in this order:

  1. <checkpoint_dir>/<model_name>/<checkpoint_filename> (local)

  2. Hugging Face Hub (repository tomdooney/deepextractor, i.e. HF_REPO_ID); the file is downloaded and cached automatically.
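The two-step lookup can be sketched as follows. This is a hypothetical illustration, not the module's code: the function name resolve_checkpoint and the injectable fetch_from_hub callable are assumptions, and the path the real implementation requests inside the Hub repository is not documented here. In practice the remote step would plausibly be something like huggingface_hub.hf_hub_download(repo_id=HF_REPO_ID, filename=...), which returns the path of a locally cached copy.

```python
import os
import tempfile
from pathlib import Path
from typing import Callable, Optional

HF_REPO_ID = "tomdooney/deepextractor"
CHECKPOINT_BILBY = "checkpoint_best_bilby_noise_base.pth.tar"


def resolve_checkpoint(
    model_name: str,
    checkpoint_dir: Optional[str],
    fetch_from_hub: Callable[[str, str], str],
    checkpoint_filename: str = CHECKPOINT_BILBY,
) -> str:
    """Return a local path to the checkpoint, preferring checkpoint_dir.

    Hypothetical sketch: fetch_from_hub(repo_id, filename) stands in for a
    Hub download that returns the path of the cached file.
    """
    # Step 1: <checkpoint_dir>/<model_name>/<checkpoint_filename>,
    # skipped entirely when checkpoint_dir is None.
    if checkpoint_dir is not None:
        local = Path(checkpoint_dir) / model_name / checkpoint_filename
        if local.is_file():
            return str(local)
    # Step 2: fall back to the Hub. The in-repo path is an ASSUMPTION.
    return fetch_from_hub(HF_REPO_ID, f"{model_name}/{checkpoint_filename}")


# Usage with a fake fetcher that records calls instead of downloading.
calls = []


def fake_fetch(repo_id, filename):
    calls.append((repo_id, filename))
    return "/hub-cache/" + filename


root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "my_model"))  # "my_model" is illustrative
open(os.path.join(root, "my_model", CHECKPOINT_BILBY), "wb").close()

local_hit = resolve_checkpoint("my_model", root, fake_fetch)
hub_hit = resolve_checkpoint("other_model", root, fake_fetch)
skipped = resolve_checkpoint("my_model", None, fake_fetch)
```

Injecting the fetcher keeps the ordering logic testable offline; only the two fallbacks reach fake_fetch.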

Pass one of the module-level constants CHECKPOINT_BILBY (the default) or CHECKPOINT_REAL as checkpoint_filename to select which pretrained variant is loaded.

Parameters:
  • model_name (str) – Name of the model (must be a key in model_dict).

  • model_dict (dict) – Mapping of model names to instantiated model objects.

  • checkpoint_dir (str | None) – Local directory to search first. Pass None to skip the local lookup and always use the Hugging Face Hub.

  • device (str | torch.device) – Device to load the model onto.

  • checkpoint_filename (str) – Checkpoint file name inside the <model_name> subdirectory. Defaults to CHECKPOINT_BILBY.

Returns:

The model from model_dict[model_name] with its weights loaded onto device and set to eval mode, or None if loading fails.

Return type:

torch.nn.Module | None
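The return contract (weights loaded, eval mode, None on failure) can be illustrated with a small hypothetical helper. finalize_model is an invented name; it reads from an explicit file path rather than performing the local/Hub lookup, and it assumes the weights are stored under a "state_dict" key. The real load_torch_model may differ in all three respects.

```python
import logging
import os
import tempfile
from typing import Optional, Union

import torch
from torch import nn

logger = logging.getLogger(__name__)


def finalize_model(
    model: nn.Module, path: str, device: Union[str, torch.device] = "cpu"
) -> Optional[nn.Module]:
    """Load weights from path into model; return it in eval mode, or None."""
    try:
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["state_dict"])  # ASSUMED key name
        model.to(device)
        model.eval()  # disables dropout, uses running batch-norm statistics
        return model
    except Exception as exc:  # any failure maps to None, per the contract
        logger.error("Could not load checkpoint %s: %s", path, exc)
        return None


# Usage: one valid checkpoint, one missing file.
tmp = tempfile.mkdtemp()
good_path = os.path.join(tmp, "good.pth.tar")
torch.save({"state_dict": nn.Linear(3, 1).state_dict()}, good_path)

model = finalize_model(nn.Linear(3, 1), good_path)
missing = finalize_model(nn.Linear(3, 1), os.path.join(tmp, "missing.pth.tar"))

if model is not None:  # always check before inference
    with torch.no_grad():
        out = model(torch.zeros(1, 3))
```

Because failure is signalled by None rather than an exception, callers need the explicit check shown in the usage lines before running inference.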