ladder.scripts.metrics.get_reproduction_error
- ladder.scripts.metrics.get_reproduction_error(point_dataset, model, source=None, target=None, metric='corr', n_trials=10, lib_size=10000.0, batched=False, **kwargs)
Calculates the generative error of the model for the given metric.
Wraps gen_profile_reproduction(), passing the `_metric_func` that corresponds to the chosen metric.
- Parameters:
  - point_dataset – The Dataset object to evaluate.
  - model (Predictive) – The model generator function wrapped in Predictive.
  - source (Tensor, optional) – Optional source condition encoding tensor used to subset the provided Dataset input.
  - target (Tensor, optional) – Optional target condition encoding tensor used to subset the provided Dataset input.
  - metric (Literal["chamfer", "rmse", "swd", "corr"], default: "corr") – Metric used for evaluation.
  - n_trials (int, default: 10) – Number of times to repeat the generative process. Must be positive.
  - lib_size (float, default: 1e4) – Library size for the normalized profiles.
  - batched (bool, default: False) – If True, assumes the batch is concatenated to the inputs.
  - **kwargs (dict, optional) – Additional keyword arguments passed to gen_profile_reproduction().
- Returns:
  - preds_mean_error (float) – Mean of the metric value calculated across the n_trials repetitions.
  - preds_var_error (float) – Variance of the metric value calculated across the n_trials repetitions.
  - pred_profiles (Tensor) – Normalized pseudo-bulk profiles of the generated points.
  - preds (Tensor) – Points generated by the model.
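The `metric` argument selects one of four measures, and the function picks the matching `_metric_func` internally. The sketch below shows one plausible way to map the metric names to functions, written in plain Python. The definitions are textbook versions chosen as assumptions (Pearson correlation and RMSE on profile vectors, a symmetric Chamfer distance using Euclidean nearest-neighbour distances, and a Monte-Carlo sliced Wasserstein-2 distance on equal-size point sets); they are not necessarily what ladder computes, and `METRICS` and `get_metric_func` are hypothetical names.

```python
import math
import random
from typing import Callable, Dict, Sequence

Vector = Sequence[float]
Points = Sequence[Sequence[float]]


def corr(a: Vector, b: Vector) -> float:
    """Pearson correlation between two equal-length profile vectors."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def rmse(a: Vector, b: Vector) -> float:
    """Root-mean-square error between two equal-length profile vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)) / len(a))


def chamfer(p: Points, q: Points) -> float:
    """Symmetric Chamfer distance: mean nearest-neighbour distance in both directions."""
    def one_way(src: Points, dst: Points) -> float:
        return sum(min(math.dist(s, d) for d in dst) for s in src) / len(src)
    return one_way(p, q) + one_way(q, p)


def swd(p: Points, q: Points, n_proj: int = 50, seed: int = 0) -> float:
    """Monte-Carlo sliced Wasserstein-2 distance between equal-size point sets."""
    if len(p) != len(q):
        raise ValueError("this sketch assumes equal-size point sets")
    rng = random.Random(seed)
    dim = len(p[0])
    total = 0.0
    for _ in range(n_proj):
        # Draw a random unit direction.
        theta = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        norm = math.sqrt(sum(t * t for t in theta))
        theta = [t / norm for t in theta]
        # Project both sets; sorted matching is optimal transport in 1-D.
        pp = sorted(sum(t * x for t, x in zip(theta, pt)) for pt in p)
        qq = sorted(sum(t * x for t, x in zip(theta, pt)) for pt in q)
        total += sum((a - b) ** 2 for a, b in zip(pp, qq)) / len(pp)
    return math.sqrt(total / n_proj)


# Hypothetical registry mirroring Literal["chamfer", "rmse", "swd", "corr"].
METRICS: Dict[str, Callable[..., float]] = {
    "corr": corr,
    "rmse": rmse,
    "chamfer": chamfer,
    "swd": swd,
}


def get_metric_func(metric: str) -> Callable[..., float]:
    """Return the function for a metric name, rejecting unknown names."""
    try:
        return METRICS[metric]
    except KeyError:
        raise ValueError(
            f"unknown metric {metric!r}; expected one of {sorted(METRICS)}"
        ) from None
```

Raising on an unknown name, rather than silently falling back to "corr", keeps a misspelled `metric` from producing a plausible-looking but wrong error value.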
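The returned mean and variance summarize the metric over `n_trials` independent generations. A minimal sketch of that loop follows, assuming a hypothetical zero-argument `generate` callable that returns a cells-by-genes count matrix, pseudo-bulk profiles formed by summing counts over cells and rescaling them to total `lib_size`, and Pearson correlation as the default metric. `reproduction_error` and `pseudobulk_profile` are illustrative names, not part of ladder's API. Population variance is used here so that `n_trials=1` yields 0.0; the library may use a different estimator, and for "corr" the reported value is a similarity rather than a distance.

```python
import math
import statistics
from typing import Callable, List, Sequence, Tuple

Cells = List[List[float]]  # rows are cells, columns are genes


def pseudobulk_profile(cells: Cells, lib_size: float = 1e4) -> List[float]:
    """Sum counts over all cells, then rescale so the profile totals lib_size."""
    totals = [sum(gene_counts) for gene_counts in zip(*cells)]
    scale = lib_size / sum(totals)
    return [t * scale for t in totals]


def pearson(a: Sequence[float], b: Sequence[float]) -> float:
    """Pearson correlation between two equal-length profiles."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / math.sqrt(var_a * var_b)


def reproduction_error(
    generate: Callable[[], Cells],
    observed: Cells,
    metric: Callable[[Sequence[float], Sequence[float]], float] = pearson,
    n_trials: int = 10,
    lib_size: float = 1e4,
) -> Tuple[float, float]:
    """Hypothetical helper: mean and variance of `metric` over n_trials generations."""
    if n_trials <= 0:
        raise ValueError("n_trials must be positive")
    target = pseudobulk_profile(observed, lib_size)
    scores = []
    for _ in range(n_trials):
        # Each trial draws a fresh set of generated cells from the model.
        predicted = pseudobulk_profile(generate(), lib_size)
        scores.append(metric(predicted, target))
    # Population variance, so a single trial gives 0.0 instead of raising.
    return statistics.fmean(scores), statistics.pvariance(scores)
```

Normalizing both the generated and the observed pseudo-bulk profiles to the same `lib_size` makes the comparison insensitive to how many cells or total counts each trial happens to produce.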