pyprobound.fitting.Fit

class Fit(rnd, dataset, prediction, observation=<function Fit.<lambda>>, update_construct=False, train_offset=False, train_posbias=False, train_hill=False, max_split=None, batch_size=None, checkpoint='valmodel.pt', output='/dev/null', device=None, sampler=None, optimizer=<class 'torch.optim.lbfgs.LBFGS'>, optim_args=None, sampler_args=None, name='')

Bases: BaseFit

Curve fitting to independent validation data in linear space.

\[\text{observation} (y) \sim m \times \text{prediction} (\log Z) + b\]
scale

The scaling factor \(m\) (1 if not train_offset).

Type:

Tensor

intercept

The intercept \(b\) (0 if not train_offset).

Type:

Tensor
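The linear relationship above can be illustrated with a minimal sketch in plain Python. It is not pyprobound code: `fit_line` and `apply_offset` are hypothetical helpers, `math.exp` stands in for a `prediction` callable, and the closed-form least-squares estimate is used only for illustration (Fit itself optimizes with a torch optimizer, LBFGS by default, and its loss is not described in this section).

```python
import math


def fit_line(pred, obs):
    """Ordinary least-squares estimate of (m, b) in obs ~ m * pred + b."""
    n = len(pred)
    mean_x = sum(pred) / n
    mean_y = sum(obs) / n
    sxx = sum((x - mean_x) ** 2 for x in pred)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(pred, obs))
    m = sxy / sxx
    b = mean_y - m * mean_x
    return m, b


def apply_offset(pred, scale=1.0, intercept=0.0):
    """Apply m * pred + b; the defaults mirror scale=1 and intercept=0."""
    return [scale * p + intercept for p in pred]


# Hypothetical log aggregates, mapped to linear space by the prediction callable.
log_z = [-2.0, -1.0, 0.0, 1.0]
pred = [math.exp(v) for v in log_z]

# Synthetic observations generated with m = 3 and b = 0.5.
obs = [3.0 * p + 0.5 for p in pred]

m, b = fit_line(pred, obs)
```

With the default `scale` and `intercept`, `apply_offset` returns the predictions unchanged, which corresponds to the documented values when train_offset is False.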

__init__(rnd, dataset, prediction, observation=<function Fit.<lambda>>, update_construct=False, train_offset=False, train_posbias=False, train_hill=False, max_split=None, batch_size=None, checkpoint='valmodel.pt', output='/dev/null', device=None, sampler=None, optimizer=<class 'torch.optim.lbfgs.LBFGS'>, optim_args=None, sampler_args=None, name='')

Initializes the curve fitting.

Parameters:
  • rnd (BaseRound | Aggregate) – A component containing an aggregate of different modes.

  • dataset (CountTable) – A CountTable with 1 to 3 columns, with the first column taken as the target; if 2 columns are provided, the second column is taken as a symmetrical error; if 3 columns are provided, the second is taken as the lower error and the third is taken as the upper error.

  • prediction (Callable[[Tensor], Tensor]) – A callable applied to the log aggregate \(\log Z\).

  • observation (Callable[[Tensor], Tensor]) – A callable applied to the target \(y\).

  • update_construct (bool) – Whether to reset experiment-specific parameters.

  • train_offset (bool) – Whether to train scaling and intercept parameters.

  • train_posbias (bool) – Whether to retrain positional bias profiles \(\omega\).

  • train_hill (bool) – Whether to train a Hill coefficient.

  • max_split (int | None) – Maximum number of sequences scored at a time (lower values reduce memory but increase computation time).

  • batch_size (int | None) – The number of sequences used to optimize the model at a time.

  • checkpoint (str | PathLike[str]) – The file to which the model is checkpointed.

  • output (str | PathLike[str]) – The file to which the optimization output is written.

  • device (str | None) – The device on which to perform optimization.

  • sampler (type[Sampler[CountBatch]] | None) – The sampler used when creating the dataloader.

  • optimizer (type[Optimizer]) – The optimizer used for optimization.

  • optim_args (MutableMapping[str, Any] | None) – Parameters passed to the optimizer (defaults to {"line_search_fn": "strong_wolfe"} if available).

  • sampler_args (MutableMapping[str, Any] | None) – Parameters passed to the sampler.

  • name (str) – A string used to describe the validation dataset.
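The memory trade-off described for max_split can be sketched as scoring sequences in chunks of bounded size. This is a plain-Python illustration, not the library's implementation: `score_in_chunks` and `fake_score` are hypothetical names, and treating `None` as "score everything at once" is an assumption.

```python
def score_in_chunks(seqs, score_fn, max_split=None):
    """Score seqs with score_fn, passing at most max_split items per call."""
    if max_split is None:
        # Assumed behavior: no limit, so everything is scored in one call.
        return list(score_fn(seqs))
    scores = []
    for start in range(0, len(seqs), max_split):
        # Each call sees a bounded slice, which caps peak memory use
        # at the cost of more calls.
        scores.extend(score_fn(seqs[start:start + max_split]))
    return scores


chunk_sizes = []


def fake_score(chunk):
    """Stand-in scorer: records the chunk size and returns sequence lengths."""
    chunk_sizes.append(len(chunk))
    return [len(s) for s in chunk]


seqs = ["ACGT", "AC", "ACG", "A", "ACGTA"]
scores = score_in_chunks(seqs, fake_score, max_split=2)
```

Here the five sequences are scored in three calls of at most two sequences each, and the results are concatenated in the original order.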

Methods

check_length_consistency()

Checks that input lengths of Binding components are consistent.

components()

Iterator of child components.

fit()

Fits experiment-specific parameters to the validation data.

forward(batches)

Calculates the multitask weighted loss and regularization.

freeze()

Turns off gradient calculation for all parameters.

get_setup_string()

A description used when printing the output of an optimizer.

log_aggregate(seqs)

Calculates the log aggregate \(\log Z_i\).

max_embedding_size()

The maximum number of bytes needed to encode a sequence.

negloglik(transform, batch)

Calculates the negative log-likelihood plus a normalization factor.

obs_pred(seqs, target)

Calculates the observed and predicted values used for the loss.

optim_procedure([ancestry, current_order])

The sequential optimization procedure for all Binding components.

plot([xlabel, ylabel, kernel, xlog, ylog, ...])

Plots predicted validation values with error bars and binning.

regularization(component)

Calculates parameter regularization.

reload(checkpoint)

Loads the model from a checkpoint file.

reload_from_state_dict(state_dict)

Loads the model from a state dict.

save(checkpoint[, flank_lengths])

Saves the model to a file with "state_dict" and "metadata" fields.

score(batch)

Wraps obs_pred, automatically managing devices.

unfreeze([parameter])

Turns on gradient calculation for the specified parameter.

Attributes

unfreezable

alias of Literal['all']

scale

intercept

Non-Inherited Members

obs_pred(seqs, target)

Calculates the observed and predicted values used for the loss.

Parameters:
  • seqs (Tensor) – A sequence tensor of shape \((\text{minibatch},\text{length})\) or \((\text{minibatch},\text{in_channels},\text{length})\).

  • target (Tensor) – A target tensor of shape \((\text{minibatch},k)\), where \(k\) is 1 to 3 (the number of dataset columns).

Return type:

tuple[Tensor, Tensor, Tensor | None, Tensor | None]

Returns:

A tuple of four tensors of shape \((\text{minibatch},)\), being \(\text{obs}(y)\), \(m \times \text{pred}(\log Z) + b\), \(\text{obs}(y - \text{lower error})\), and \(\text{obs}(y + \text{upper error})\).
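A minimal sketch of how such a four-element tuple could be assembled, using plain Python lists instead of tensors. `obs_pred_sketch` is a hypothetical helper, not the library method. It follows the dataset column description (second column as lower or symmetrical error, third as upper error) and assumes the error entries are `None` when only a target column is given, as the return type suggests.

```python
import math


def obs_pred_sketch(log_z, target, prediction, observation,
                    scale=1.0, intercept=0.0):
    """Return (obs, pred, obs_lower, obs_upper) for rows of 1 to 3 columns."""
    y = [row[0] for row in target]
    obs = [observation(v) for v in y]
    pred = [scale * prediction(z) + intercept for z in log_z]

    n_cols = len(target[0])
    if n_cols == 1:
        # No error columns: assumed to yield None for both bounds.
        return obs, pred, None, None

    lower = [row[1] for row in target]
    # Two columns: the second column is a symmetrical error.
    upper = [row[2] if n_cols == 3 else row[1] for row in target]
    obs_lower = [observation(v - e) for v, e in zip(y, lower)]
    obs_upper = [observation(v + e) for v, e in zip(y, upper)]
    return obs, pred, obs_lower, obs_upper


obs, pred, lo, hi = obs_pred_sketch(
    log_z=[0.0, 1.0],
    target=[[2.0, 0.5, 1.0], [4.0, 1.0, 2.0]],
    prediction=math.exp,          # hypothetical prediction callable
    observation=lambda v: v,      # identity, chosen for illustration
)
```

The observation callable is applied to the error bounds as well as to the target, so a nonlinear choice such as a logarithm would produce asymmetric bounds even from a symmetrical error column.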

plot(xlabel='Predicted', ylabel='Observed', kernel=1, xlog=True, ylog=True, labels=None, colors=None)

Plots predicted validation values with error bars and binning.

Parameters:
  • xlabel (str) – The x-axis label.

  • ylabel (str) – The y-axis label.

  • kernel (int) – The bin size used for average pooling of prediction-sorted sequences.

  • xlog (bool) – Whether to plot the x-axis in logarithmic scale.

  • ylog (bool) – Whether to plot the y-axis in logarithmic scale.

  • labels (list[str] | None) – The label for each data point drawn on the plot.

  • colors (list[str] | None) – The color for each data point drawn on the plot.

Return type:

None
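One plausible reading of the kernel parameter is sketched below: points are sorted by predicted value and averaged in non-overlapping bins of `kernel` consecutive points. `pool_sorted` is a hypothetical helper in plain Python, and keeping a trailing partial bin is an assumption; the library may handle that case differently.

```python
def pool_sorted(pred, obs, kernel=1):
    """Average (pred, obs) pairs in bins of `kernel` points, sorted by pred."""
    pairs = sorted(zip(pred, obs))  # sort by prediction (ties broken by obs)
    pooled_pred, pooled_obs = [], []
    for start in range(0, len(pairs), kernel):
        chunk = pairs[start:start + kernel]
        # A trailing partial bin is averaged over the points it contains.
        pooled_pred.append(sum(p for p, _ in chunk) / len(chunk))
        pooled_obs.append(sum(o for _, o in chunk) / len(chunk))
    return pooled_pred, pooled_obs


binned_pred, binned_obs = pool_sorted(
    pred=[3.0, 1.0, 2.0, 4.0],
    obs=[30.0, 10.0, 20.0, 40.0],
    kernel=2,
)
```

With `kernel=1`, every point is its own bin, so the sketch only sorts the data, which matches the default of plotting each sequence individually.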

fit()

Fits experiment-specific parameters to the validation data.

Return type:

None