ladder.scripts.metrics.get_reproduction_error

ladder.scripts.metrics.get_reproduction_error(point_dataset, model, source=None, target=None, metric='corr', n_trials=10, lib_size=10000.0, batched=False, **kwargs)

Calculates the model's generative error for the chosen metric.

Wraps gen_profile_reproduction(), passing the `_metric_func` that corresponds to the chosen metric.
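The wrapper selects a scoring function by metric name and hands it to gen_profile_reproduction(). The sketch below shows one way such a name-to-function table could look. It is written in NumPy for brevity, while ladder itself works with Tensors. The names `corr_metric`, `rmse_metric`, `chamfer_metric`, `swd_metric`, `METRIC_FUNCS` and `get_metric_func` are hypothetical, and ladder's actual `_metric_func` implementations (input types, normalization, sign conventions) may differ. The bodies use standard definitions of each metric. Note that "corr" is a similarity (higher is better), while the other three are distances (lower is better).

```python
import numpy as np

# Hypothetical per-metric functions; NOT ladder's own `_metric_func` code.


def corr_metric(pred_profile, true_profile) -> float:
    """Pearson correlation between two 1-D profiles (higher is better)."""
    return float(np.corrcoef(pred_profile, true_profile)[0, 1])


def rmse_metric(pred_profile, true_profile) -> float:
    """Root-mean-square error between two 1-D profiles (lower is better)."""
    diff = np.asarray(pred_profile, dtype=float) - np.asarray(true_profile, dtype=float)
    return float(np.sqrt(np.mean(diff ** 2)))


def chamfer_metric(pred_points, true_points) -> float:
    """Symmetric Chamfer distance between point clouds of shape (n, d) and (m, d)."""
    a = np.asarray(pred_points, dtype=float)
    b = np.asarray(true_points, dtype=float)
    dists = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (n, m) pairwise
    # Average nearest-neighbour distance in both directions.
    return float(dists.min(axis=1).mean() + dists.min(axis=0).mean())


def swd_metric(pred_points, true_points, n_projections=64, seed=0) -> float:
    """Approximate sliced Wasserstein-1 distance via random 1-D projections."""
    a = np.asarray(pred_points, dtype=float)
    b = np.asarray(true_points, dtype=float)
    rng = np.random.default_rng(seed)
    dirs = rng.normal(size=(n_projections, a.shape[1]))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)  # random unit directions
    qs = np.linspace(0.0, 1.0, 101)
    # Compare the quantile functions of the projected clouds, per direction.
    qa = np.quantile(a @ dirs.T, qs, axis=0)
    qb = np.quantile(b @ dirs.T, qs, axis=0)
    return float(np.abs(qa - qb).mean())


METRIC_FUNCS = {
    "corr": corr_metric,
    "rmse": rmse_metric,
    "chamfer": chamfer_metric,
    "swd": swd_metric,
}


def get_metric_func(metric: str):
    """Look up the scoring function for one of the documented metric names."""
    try:
        return METRIC_FUNCS[metric]
    except KeyError:
        raise ValueError(
            f"unknown metric {metric!r}; expected one of {sorted(METRIC_FUNCS)}"
        ) from None
```

A dictionary lookup keeps the valid names in one place and turns a typo in `metric` into an explicit error rather than a silent fallback.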

Parameters:
  • point_dataset – The Dataset object to evaluate the model against.

  • model (Predictive) – The model generator function wrapped in Predictive.

  • source (Tensor, optional) – Optional source condition encoding tensor to subset the provided Dataset input.

  • target (Tensor, optional) – Optional target condition encoding tensor to subset the provided Dataset input.

  • metric (Literal["chamfer", "rmse", "swd", "corr"], default: “corr”) – Metric to be considered for evaluation.

  • n_trials (int, default: 10) – Number of times to repeat the generative process. Must be positive.

  • lib_size (float, default: 1e4) – Library size (total count) to which the pseudo-bulk profiles are normalized.

  • batched (bool, default: False) – If True, assumes the batch information is concatenated to the inputs.

  • **kwargs (dict, optional) – Additional keyword arguments to be passed to gen_profile_reproduction().
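The lib_size parameter controls how generated points are turned into normalized pseudo-bulk profiles. A minimal sketch of library-size normalization follows. It assumes the generated points form a non-negative (cells × genes) matrix that is summed over cells and rescaled to a total of lib_size; `pseudo_bulk_profile` is a hypothetical helper, and ladder may apply additional steps (for example a log transform) that are not shown here.

```python
import numpy as np


def pseudo_bulk_profile(points, lib_size: float = 1e4) -> np.ndarray:
    """Hypothetical helper: sum cells into one profile and rescale it to `lib_size`."""
    points = np.asarray(points, dtype=float)
    if points.ndim != 2:
        raise ValueError("expected a (cells, genes) matrix")
    bulk = points.sum(axis=0)  # aggregate over cells -> shape (genes,)
    total = bulk.sum()
    if total <= 0:
        raise ValueError("profile has no counts to normalize")
    return bulk / total * lib_size  # the normalized profile sums to lib_size
```

For example, the two cells `[[1, 0, 3], [1, 2, 3]]` sum to `[2, 2, 6]`, which rescales to `[2000, 2000, 6000]` under the default lib_size of 1e4. Normalizing to a fixed total makes profiles comparable even when the generated and observed data contain different numbers of cells.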

Returns:
  • preds_mean_error (float) – Mean of the metric value across the n_trials repetitions.

  • preds_var_error (float) – Variance of the metric value across the n_trials repetitions.

  • pred_profiles (Tensor) – Normalized pseudo-bulk profiles of the generated points.

  • preds (Tensor) – Points generated by the model.
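The mean and variance returned by the function summarize a metric that is recomputed over n_trials independent generative passes. The sketch below illustrates that aggregation pattern only; it is not ladder's implementation. `reproduction_error_sketch` and its `generate` callable (one generative pass per trial, returning a (cells × genes) array) are hypothetical stand-ins for the Predictive model and gen_profile_reproduction(). It also assumes the population variance (`ddof=0`), which ladder may or may not use.

```python
import numpy as np


def reproduction_error_sketch(generate, true_profile, metric_func, n_trials=10, lib_size=1e4):
    """Hypothetical sketch: repeat generation, score each pass, aggregate the scores."""
    if n_trials < 1:
        raise ValueError("n_trials must be positive")
    scores, profiles, preds = [], [], []
    for trial in range(n_trials):
        points = np.asarray(generate(trial), dtype=float)  # one generative pass
        bulk = points.sum(axis=0)
        profile = bulk / bulk.sum() * lib_size  # library-size-normalized pseudo-bulk
        scores.append(metric_func(profile, true_profile))
        profiles.append(profile)
        preds.append(points)
    scores = np.asarray(scores, dtype=float)
    # Mean and (population) variance of the metric across trials.
    return float(scores.mean()), float(scores.var()), np.stack(profiles), np.stack(preds)
```

Returning the per-trial profiles and points alongside the summary statistics lets a caller inspect individual generative passes, which mirrors the pred_profiles and preds entries listed above.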