ladder.scripts.training.train_pyro

ladder.scripts.training.train_pyro(model, train_loader, test_loader, num_epochs=500, convergence_threshold=0.001, convergence_window=15, verbose=True, device=device(type='cpu'), optim_args=None)

Runner for basic Pyro models.

Trains for at most num_epochs epochs. Training stops early once convergence_window consecutive epochs pass without the loss reaching a new minimum that is lower than the previous minimum by at least convergence_threshold.
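The stopping rule described above can be sketched in plain Python. This is an illustrative reading of the description, not the library's code: the helper name epochs_until_stop is hypothetical, the threshold is assumed to be an absolute difference, and the page does not say whether the training or the test loss is the one monitored.

```python
def epochs_until_stop(losses, num_epochs=500,
                      convergence_threshold=1e-3, convergence_window=15):
    """Return how many epochs would run for a given per-epoch loss sequence.

    Hypothetical helper: mirrors the documented rule under the assumption
    that an "improvement" is an absolute drop of more than the threshold.
    """
    best = float("inf")   # lowest loss accepted as a new minimum so far
    stale = 0             # epochs since the last accepted minimum
    for epoch, loss in enumerate(losses[:num_epochs], start=1):
        if best - loss > convergence_threshold:
            best = loss   # sufficiently lower: record it and reset patience
            stale = 0
        else:
            stale += 1
            if stale >= convergence_window:
                return epoch  # patience exhausted: stop early
    return min(len(losses), num_epochs)  # ran out of epochs without converging


# Loss plateaus at 0.4 after three epochs; with a window of 3 the run stops
# three epochs after the last improvement.
plateau = [1.0, 0.5, 0.4] + [0.4] * 20
stopped_at = epochs_until_stop(plateau, convergence_window=3)
```

Under these assumptions, a larger convergence_window tolerates longer plateaus, while a larger convergence_threshold makes small improvements count as no improvement at all.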

Parameters:
  • model (Module) – The model to train.

  • train_loader (DataLoader) – Data loader for the training set.

  • test_loader (DataLoader) – Data loader for the test set.

  • num_epochs (int, default: 500) – Maximum number of epochs to run.

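  • convergence_threshold (float, default: 1e-3) – Minimum decrease below the previous minimum loss that counts as an improvement when checking for convergence.

  • convergence_window (int, default: 15) – Patience window: the number of epochs without such an improvement after which training is considered converged.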

  • verbose (bool, default: True) – If True, prints out the loss at every epoch.

  • device (device, default: device(type='cpu')) – Device object to run the model on.

  • optim_args (dict, default: {"optimizer": opt.Adam, "optim_args": {"lr": 1e-3, "eps": 1e-2}, "gamma": 1, "milestones": [1e10]}) – Arguments to be passed to pyro.optim.MultiStepLR for fine-tuning if needed.
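To see what the gamma and milestones entries of optim_args control, the sketch below reproduces the multi-step decay rule in plain Python: the learning rate is multiplied by gamma each time the epoch count reaches a milestone. The helper multistep_lr is hypothetical and only illustrates the schedule; it does not call Pyro or PyTorch.

```python
from bisect import bisect_right


def multistep_lr(base_lr, gamma, milestones, epoch):
    """Learning rate in effect at `epoch` under a multi-step decay rule.

    Hypothetical helper for illustration: lr = base_lr * gamma ** k, where
    k is the number of milestones that are <= epoch.
    """
    passed = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** passed


# Documented defaults: gamma of 1 and a single milestone at 1e10, so the
# learning rate stays at its initial value for the whole run.
default_lr = multistep_lr(1e-3, gamma=1, milestones=[1e10], epoch=499)

# An illustrative alternative: decay by 10x at epochs 100 and 200.
decayed_lr = multistep_lr(1e-3, gamma=0.1, milestones=[100, 200], epoch=150)
```

With the documented defaults the scheduler is effectively a no-op. Passing a smaller gamma together with reachable milestones is what turns on learning-rate decay.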

Returns:

  • model (Module) – The model object post-training.

  • loss_track_train (ndarray) – Float array containing the training loss per epoch.

  • loss_track_test (ndarray) – Float array containing the test loss per epoch.