ladder.scripts.training.train_pyro_disjoint_param

ladder.scripts.training.train_pyro_disjoint_param(model, train_loader, test_loader, num_epochs=500, convergence_threshold=0.001, convergence_window=15, verbose=True, device=device(type='cpu'), warmup=0, classifier_aggression=0, optim_args=None)

Runner for Patches, but can be used for other adversarial models.

Trains for up to num_epochs epochs, or stops early once the loss has failed to reach a new minimum that is lower than the previous minimum by at least convergence_threshold for convergence_window epochs. The warmup and classifier_aggression arguments allow a separate training routine for the adversarial (classifier) loss.
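The stopping rule described above can be sketched in plain Python. This is a minimal illustration of a threshold-plus-patience check, not the library's implementation: the function name epochs_until_convergence and the losses list (a stand-in for per-epoch loss values) are hypothetical, and which loss the real function monitors is not stated here.

```python
def epochs_until_convergence(losses, convergence_threshold=1e-3, convergence_window=15):
    """Return how many epochs run before the patience rule stops training.

    `losses` is a hypothetical list of per-epoch loss values standing in
    for a real training loop.
    """
    best = float("inf")  # lowest loss accepted as a new minimum so far
    stale = 0            # epochs since the last sufficient improvement
    for epoch, loss in enumerate(losses, start=1):
        if best - loss > convergence_threshold:
            # Improvement is large enough: record the new minimum.
            best = loss
            stale = 0
        else:
            stale += 1
        if stale >= convergence_window:
            return epoch  # patience exhausted, stop early
    return len(losses)    # ran through every available epoch
```

With a small convergence_window, a loss curve that plateaus stops training soon after the plateau begins, while a steadily decreasing curve runs for the full number of epochs.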

Parameters:
  • model (Module) – The model to train.

  • train_loader (DataLoader) – Data loader for the training set.

  • test_loader (DataLoader) – Data loader for the test set.

  • num_epochs (int, default: 500) – Maximum number of epochs to run.

  • convergence_threshold (float, default: 1e-3) – Minimum improvement to decide on convergence.

  • convergence_window (int, default: 15) – Patience window for deciding on convergence.

  • verbose (bool, default: True) – If True, prints out the loss at every epoch.

  • device (device, default: device(type='cpu')) – Device object to run the model on.

  • warmup (int, default: 0) – Number of epochs to run the classifier before running the entire model.

  • classifier_aggression (int, default: 0) – Number of epochs the classifier takes independently between jointly trained epochs.

  • optim_args (dict, default: {"optim_args": {"lr": 1e-3, "eps": 1e-2}, "gamma": 1, "milestones": [1e10]}) – Arguments to be passed to torch.optim.lr_scheduler.MultiStepLR, for fine-tuning if needed.
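To make the scheduling parameters concrete, the sketch below shows one plausible reading of warmup and classifier_aggression, together with the learning rate a MultiStepLR schedule yields at a given epoch. Both helpers (epoch_schedule and multistep_lr) are hypothetical names written for illustration. The ordering of classifier-only and joint epochs, and whether classifier-only epochs count toward num_epochs, are assumptions that the description above does not settle.

```python
from bisect import bisect_right

def epoch_schedule(num_epochs, warmup=0, classifier_aggression=0):
    """Label each epoch 'classifier' or 'joint' (assumed ordering)."""
    # Assumption: warmup epochs train the classifier alone, up front.
    schedule = ["classifier"] * min(warmup, num_epochs)
    while len(schedule) < num_epochs:
        schedule.append("joint")
        # Assumption: each joint epoch is followed by
        # `classifier_aggression` classifier-only epochs.
        extra = min(classifier_aggression, num_epochs - len(schedule))
        schedule.extend(["classifier"] * extra)
    return schedule

def multistep_lr(base_lr, gamma, milestones, epoch):
    """Learning rate of a MultiStepLR schedule at a 0-based epoch.

    The rate is multiplied by `gamma` once for every milestone reached.
    """
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)

# Default optim_args: gamma = 1 and a milestone at 1e10 keep the rate constant.
default_lr = multistep_lr(1e-3, 1, [1e10], epoch=499)
```

Under these assumptions, epoch_schedule(6, warmup=2, classifier_aggression=1) gives two classifier-only epochs followed by alternating joint and classifier-only epochs. Passing a gamma below 1 with reachable milestones, for example "gamma": 0.1 and "milestones": [100, 200], would decay the rate at those epochs instead of holding it fixed.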

Returns:

  • model (Module) – The model object post-training.

  • loss_track_train (ndarray) – float array containing the training loss per epoch.

  • loss_track_test (ndarray) – float array containing the test loss per epoch.

  • params_nonc_names (set) – str set containing the model parameter names, excluding those of the classifier.

  • params_c_names (set) – str set containing the model parameter names of the classifier.
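The two returned name sets partition the model's parameters into classifier and non-classifier groups. The sketch below illustrates such a partition. It assumes classifier parameters can be recognised by a name prefix. The helper split_param_names, the prefix "classifier", and the example names are all hypothetical and do not reflect how the library actually assigns names.

```python
def split_param_names(param_names, classifier_prefix="classifier"):
    """Split parameter names into (non-classifier, classifier) sets.

    Assumption: classifier parameters share a common name prefix.
    """
    names = set(param_names)
    params_c_names = {n for n in names if n.startswith(classifier_prefix)}
    params_nonc_names = names - params_c_names
    return params_nonc_names, params_c_names

# Hypothetical parameter names, for illustration only.
example_names = [
    "encoder.fc.weight",
    "decoder.fc.weight",
    "classifier.fc.weight",
    "classifier.fc.bias",
]
nonc, c = split_param_names(example_names)
```

Because the two sets are disjoint and together cover every parameter, each group can be handed to its own optimisation step.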