deepextractor.utils.checkpoints
Module Contents
- deepextractor.utils.checkpoints.CHECKPOINT_BILBY = 'checkpoint_best_bilby_noise_base.pth.tar'
- deepextractor.utils.checkpoints.CHECKPOINT_REAL = 'checkpoint_best_real_noise_base.pth.tar'
- deepextractor.utils.checkpoints.save_checkpoint(state, filename='my_checkpoint.pth.tar')
Save model and optimizer state as a checkpoint.
- deepextractor.utils.checkpoints.load_checkpoint(checkpoint, model)
Load model state from a checkpoint.
- deepextractor.utils.checkpoints.load_optimizer(checkpoint, optimizer)
Load optimizer state from a checkpoint.
- deepextractor.utils.checkpoints.load_torch_model(model_name, model_dict, checkpoint_dir=None, device='cpu', checkpoint_filename=CHECKPOINT_BILBY)
Load a pretrained PyTorch model.
Weights are resolved in this order:
1. <checkpoint_dir>/<model_name>/<checkpoint_filename> (local)
2. Hugging Face Hub (tomdooney/deepextractor), downloaded and cached automatically.
Use the module-level constants CHECKPOINT_BILBY and CHECKPOINT_REAL for checkpoint_filename to select the correct variant.
- Parameters:
model_name (str) – Name of the model (must be a key in model_dict).
model_dict (dict) – Mapping of model names to instantiated model objects.
checkpoint_dir (str | None) – Local directory to search first. Pass None to skip local lookup and always use Hugging Face Hub.
device (str | torch.device) – Device to load the model onto.
checkpoint_filename (str) – Checkpoint file name inside the model subdirectory.
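The local half of the resolution order described above can be sketched with the standard library alone. This is an illustration under stated assumptions, not the library's implementation: `find_local_checkpoint` is a hypothetical helper name, and only the path layout `<checkpoint_dir>/<model_name>/<checkpoint_filename>` and the two constant values are taken from this page.

```python
from pathlib import Path
from typing import Optional

# Values copied from the module-level constants documented on this page.
CHECKPOINT_BILBY = "checkpoint_best_bilby_noise_base.pth.tar"
CHECKPOINT_REAL = "checkpoint_best_real_noise_base.pth.tar"


def find_local_checkpoint(
    model_name: str,
    checkpoint_dir: Optional[str] = None,
    checkpoint_filename: str = CHECKPOINT_BILBY,
) -> Optional[Path]:
    """Hypothetical helper: return the local checkpoint path, or None.

    A None result means no local file was found; per the documentation,
    load_torch_model would then fall back to the Hugging Face Hub.
    """
    if checkpoint_dir is None:
        # Documented behaviour: None skips the local lookup entirely.
        return None
    # Documented layout: <checkpoint_dir>/<model_name>/<checkpoint_filename>
    candidate = Path(checkpoint_dir) / model_name / checkpoint_filename
    return candidate if candidate.is_file() else None
```

Passing `checkpoint_filename=CHECKPOINT_REAL` instead of the default selects the other variant of the weights. How the real function performs the Hub download is not shown on this page, so that step is left out of the sketch.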
- Returns:
The model with loaded weights in eval mode, or None on failure.
- Return type: