ladder.scripts.training.train_pyro
- ladder.scripts.training.train_pyro(model, train_loader, test_loader, num_epochs=500, convergence_threshold=0.001, convergence_window=15, verbose=True, device=device(type='cpu'), optim_args=None)
Runner for basic Pyro models.
Trains for up to num_epochs epochs, or until convergence_window consecutive epochs pass without a new loss minimum that is lower than the previous minimum by convergence_threshold.
- Parameters:
  - model (Module) – The model to train.
  - train_loader (DataLoader) – Data loader for the training set.
  - test_loader (DataLoader) – Data loader for the test set.
  - num_epochs (int, default: 500) – Maximum number of epochs to run.
  - convergence_threshold (float, default: 1e-3) – Minimum improvement over the previous minimum for an epoch to count as progress when deciding on convergence.
  - convergence_window (int, default: 15) – Patience window: the number of epochs without sufficient improvement after which training is considered converged.
  - verbose (bool, default: True) – If True, prints the loss at every epoch.
  - device (device, default: cpu) – Device to run the model on.
  - optim_args (dict, default: None) – Arguments passed to pyro.optim.MultiStepLR for fine tuning if needed. If None, {"optimizer": opt.Adam, "optim_args": {"lr": 1e-3, "eps": 1e-2}, "gamma": 1, "milestones": [1e10]} is used.
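The stopping rule described above can be sketched in plain Python. This is a minimal sketch, not the ladder implementation: the helper epochs_until_stop is hypothetical, it works on a precomputed list of per-epoch losses rather than training a model, and it assumes that convergence_threshold is an absolute margin and that improvements smaller than that margin do not update the stored minimum.

```python
from typing import Iterable


def epochs_until_stop(
    losses: Iterable[float],
    num_epochs: int = 500,
    convergence_threshold: float = 1e-3,
    convergence_window: int = 15,
) -> int:
    """Return how many epochs run before stopping (hypothetical helper).

    ASSUMPTION: an epoch counts as progress only if its loss is lower than
    the best loss so far by more than convergence_threshold (an absolute
    margin). Training stops after convergence_window consecutive epochs
    without progress, or after num_epochs epochs, whichever comes first.
    """
    best = float("inf")
    epochs_without_improvement = 0
    epochs_run = 0
    for loss in losses:
        # Hard cap on the number of epochs.
        if epochs_run >= num_epochs:
            break
        epochs_run += 1
        if loss < best - convergence_threshold:
            # New minimum that beats the old one by the required margin.
            best = loss
            epochs_without_improvement = 0
        else:
            # No sufficient improvement: consume one epoch of patience.
            epochs_without_improvement += 1
            if epochs_without_improvement >= convergence_window:
                break
    return epochs_run
```

For example, epochs_until_stop([10.0, 9.0, 8.0] + [8.0] * 20) returns 18: three improving epochs, then 15 flat epochs exhaust the default window. A steadily improving loss instead runs until the num_epochs cap.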
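With the default optim_args, gamma is 1 and the only milestone is 1e10, so the scheduler never changes the learning rate. The helper multistep_lr below is hypothetical (not part of ladder or Pyro); it restates the closed-form rule of PyTorch's MultiStepLR, which pyro.optim.MultiStepLR wraps: after n scheduler steps the rate is the base rate multiplied by gamma once for every milestone at or below n. How often train_pyro steps the scheduler is not stated here, so n is left abstract.

```python
import bisect
from typing import Sequence


def multistep_lr(
    base_lr: float, gamma: float, milestones: Sequence[float], step: int
) -> float:
    """Learning rate after `step` scheduler steps under a MultiStepLR-style rule."""
    # Count the milestones already reached (those <= step).
    passed = bisect.bisect_right(sorted(milestones), step)
    # Each reached milestone multiplies the rate by gamma once.
    return base_lr * gamma ** passed
```

With the defaults, multistep_lr(1e-3, 1, [1e10], 500) gives 0.001, a constant rate. Passing something like "gamma": 0.1 and "milestones": [100, 200] in optim_args would instead shrink the rate tenfold at each milestone.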
- Returns: