ladder.scripts.training.train_pyro_disjoint_param
- ladder.scripts.training.train_pyro_disjoint_param(model, train_loader, test_loader, num_epochs=500, convergence_threshold=0.001, convergence_window=15, verbose=True, device=device(type='cpu'), warmup=0, classifier_aggression=0, optim_args=None)
Training runner for Patches, which can also be used for other adversarial models.
Trains for up to num_epochs epochs, or until convergence_window consecutive epochs pass without reaching a new minimum loss that is lower than the previous minimum by convergence_threshold. Allows setting a different training routine for the adversarial loss.
- Parameters:
  - model (Module) – The model to train.
  - train_loader (DataLoader) – Data loader for the training set.
  - test_loader (DataLoader) – Data loader for the test set.
  - num_epochs (int, default: 500) – Maximum number of epochs to run.
  - convergence_threshold (float, default: 1e-3) – Minimum improvement over the previous minimum loss needed to count as a new minimum when deciding on convergence.
  - convergence_window (int, default: 15) – Patience window, in epochs, for deciding on convergence.
  - verbose (bool, default: True) – If True, prints the loss at every epoch.
  - device (device, default: device(type='cpu')) – Device object to run models on.
  - warmup (int, default: 0) – Number of epochs to train the classifier alone before training the entire model.
  - classifier_aggression (int, default: 0) – Number of epochs the classifier is trained on its own between jointly trained epochs.
  - optim_args (dict, default: {"optim_args": {"lr": 1e-3, "eps": 1e-2}, "gamma": 1, "milestones": [1e10]}) – Arguments passed to torch.optim.lr_scheduler.MultiStepLR for fine-tuning, if needed. With the default gamma of 1, the learning rate stays constant.
- Returns:
  - model (Module) – The model object post-training.
  - loss_track_train (ndarray of float) – Array containing the training loss per epoch.
  - loss_track_test (ndarray of float) – Array containing the test loss per epoch.
  - params_nonc_names (set of str) – Set containing the model parameter names, excluding those of the classifier.
  - params_c_names (set of str) – Set containing the parameter names of the classifier.
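The stopping rule described for this runner can be sketched in plain Python. This is a hypothetical helper (epochs_until_convergence is not part of ladder) and rests on assumptions: convergence_threshold is treated as an absolute margin on the monitored loss, the best loss is only updated by an improvement larger than that margin, and training stops once convergence_window consecutive epochs pass without such an improvement. Whether train_pyro_disjoint_param monitors the training or the test loss is not stated here.

```python
import math
from typing import Sequence


def epochs_until_convergence(
    losses: Sequence[float],
    convergence_threshold: float = 1e-3,
    convergence_window: int = 15,
) -> int:
    """Return how many epochs would run before the (assumed) stopping rule fires.

    Hypothetical sketch: an epoch counts as an improvement only if its loss
    is lower than the best loss so far by more than convergence_threshold.
    Stops once convergence_window consecutive epochs have no improvement.
    """
    best = math.inf
    epochs_since_improvement = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best - convergence_threshold:
            # New minimum that beats the old one by the required margin.
            best = loss
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
        if epochs_since_improvement >= convergence_window:
            return epoch
    # Never converged: all epochs (up to num_epochs) are used.
    return len(losses)
```

For example, with losses [1.0, 0.5, 0.4999, 0.4995, 0.4992], a threshold of 1e-3 and a window of 3, the helper stops at epoch 5, because the last three epochs each fail to beat 0.5 by more than 1e-3.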
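The interaction between warmup and classifier_aggression can be illustrated by listing which part of the model is trained in each epoch. This is a hypothetical sketch (epoch_schedule is not a ladder function) that assumes warmup epochs count toward num_epochs and that each jointly trained epoch is followed, rather than preceded, by classifier_aggression classifier-only epochs; the actual ordering inside train_pyro_disjoint_param may differ.

```python
from typing import List


def epoch_schedule(
    num_epochs: int, warmup: int = 0, classifier_aggression: int = 0
) -> List[str]:
    """Label each epoch as "classifier" (classifier only) or "joint" (entire model).

    Hypothetical sketch of the schedule implied by warmup and
    classifier_aggression; not the library's implementation.
    """
    phases: List[str] = []
    # Warm-up: only the classifier is trained.
    phases.extend(["classifier"] * min(warmup, num_epochs))
    # After warm-up: one joint epoch, then classifier_aggression
    # classifier-only epochs, repeated until num_epochs is reached.
    while len(phases) < num_epochs:
        phases.append("joint")
        for _ in range(classifier_aggression):
            if len(phases) >= num_epochs:
                break
            phases.append("classifier")
    return phases
```

For example, epoch_schedule(6, warmup=2, classifier_aggression=1) yields two classifier-only epochs followed by alternating joint and classifier-only epochs. With the defaults (warmup=0, classifier_aggression=0), every epoch is a joint epoch.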