pyprobound.base.Transform
- class Transform(name='')
Bases: Component
Component that applies a transformation to a tensor.
Includes improved typing and output caching, which avoids recomputation for transformations that appear multiple times in a loss module. See https://github.com/pytorch/pytorch/issues/45414 for typing information.
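The caching idea can be shown without PyTorch. The plain-Python sketch below is a hypothetical stand-in, not pyprobound's implementation. A single `Scale` instance is shared by two loss terms, and it remembers its last input (compared by object identity) together with the matching output, so the second term reuses the first term's result. The class name `Scale`, the `calls` counter, and the identity-based key are all assumptions made for illustration.

```python
class Scale:
    """Hypothetical transform that multiplies every element by a factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor
        self.calls = 0  # how many times the transform was actually computed
        self._last_input = None
        self._last_output = None

    def __call__(self, seqs: list[float]) -> list[float]:
        # Reuse the previous result if called again on the very same object.
        if seqs is self._last_input:
            return self._last_output
        self.calls += 1
        self._last_input = seqs
        self._last_output = [self.factor * x for x in seqs]
        return self._last_output


def loss(seqs: list[float], transform: Scale) -> float:
    # Two loss terms that both depend on the same shared transform.
    term_a = sum(transform(seqs))
    term_b = sum(y * y for y in transform(seqs))
    return term_a + term_b


shared = Scale(2.0)
batch = [1.0, 2.0, 3.0]
print(loss(batch, shared))  # 68.0
print(shared.calls)         # 1
```

Keying on identity is cheap but assumes inputs are not mutated in place between calls; a fresh input object always triggers a recomputation.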
- __init__(name='')
Initialize internal Module state, shared by both nn.Module and ScriptModule.
Methods
cache(fun): Decorator for a function to cache its output.
check_length_consistency(): Checks that input lengths of Binding components are consistent.
components(): Iterator of child components.
forward(seqs): A transformation applied to a sequence tensor.
freeze(): Turns off gradient calculation for all parameters.
max_embedding_size(): The maximum number of bytes needed to encode a sequence.
optim_procedure([ancestry, current_order]): The sequential optimization procedure for all Binding components.
reload(checkpoint): Loads the model from a checkpoint file.
reload_from_state_dict(state_dict): Loads the model from a state dict.
save(checkpoint[, flank_lengths]): Saves the model to a file with "state_dict" and "metadata" fields.
unfreeze([parameter]): Turns on gradient calculation for the specified parameter.
Attributes
unfreezable: alias of Literal['all']
Non-Inherited Members
- abstract forward(seqs)
A transformation applied to a sequence tensor.
- Return type:
Tensor
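Because `forward` is abstract, every concrete transform must override it. The sketch below mimics that contract with Python's `abc` module instead of `torch.nn.Module`, so it runs without PyTorch. `TransformSketch` and `ReverseComplement` are hypothetical names, and sequences are assumed to be lists of integer codes with A=0, C=1, G=2, T=3.

```python
from abc import ABC, abstractmethod


class TransformSketch(ABC):
    """Hypothetical stand-in for a transform base class with an abstract forward."""

    def __call__(self, seqs: list[list[int]]) -> list[list[int]]:
        # Calling the object dispatches to forward, as nn.Module does.
        return self.forward(seqs)

    @abstractmethod
    def forward(self, seqs: list[list[int]]) -> list[list[int]]:
        """A transformation applied to a batch of encoded sequences."""


class ReverseComplement(TransformSketch):
    """Reverse-complements sequences encoded as A=0, C=1, G=2, T=3."""

    def forward(self, seqs: list[list[int]]) -> list[list[int]]:
        # With this encoding, the complement of base b is 3 - b.
        return [[3 - b for b in reversed(seq)] for seq in seqs]


print(ReverseComplement()([[0, 1, 2], [3, 3, 0]]))  # [[1, 2, 3], [3, 0, 0]]

try:
    TransformSketch()  # abstract classes cannot be instantiated
except TypeError as err:
    print("cannot instantiate:", err)
```

Leaving `forward` unimplemented makes instantiation fail with a `TypeError`, which surfaces a missing override early rather than at the first call.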
- classmethod cache(fun)
Decorator for a function to cache its output.
The decorator must be applied to every function whose output is used by the cached function; in practice, this generally means all forward definitions.
- Return type:
Callable[[TypeVar(ComponentT, bound=Component), Tensor], Tensor]
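One way to see why every upstream function must be decorated is to assume the cache is keyed on the identity of the input; that keying scheme is an assumption here, not something this page states. In the plain-Python sketch below, the hypothetical `Component.cache` classmethod stores each method's last input and output per instance. A decorated upstream method returns the same output object for the same input, so the downstream cached method sees a repeated input and hits its cache. An undecorated upstream method returns a fresh object on every call, so the downstream cache always misses. `Component`, `Pipeline`, and the method names are illustrative only.

```python
import functools
from typing import Callable, TypeVar

ComponentT = TypeVar("ComponentT", bound="Component")


class Component:
    """Hypothetical base class; not pyprobound's implementation."""

    def __init__(self) -> None:
        self._cache: dict[str, tuple[object, list[float]]] = {}
        self.evaluations = 0  # counts real evaluations of `downstream`

    @classmethod
    def cache(
        cls, fun: Callable[[ComponentT, list[float]], list[float]]
    ) -> Callable[[ComponentT, list[float]], list[float]]:
        """Cache fun's output per instance, keyed on the identity of its input."""

        @functools.wraps(fun)
        def wrapper(self: ComponentT, seqs: list[float]) -> list[float]:
            hit = self._cache.get(fun.__name__)
            if hit is not None and hit[0] is seqs:
                return hit[1]  # same input object: reuse the stored output
            out = fun(self, seqs)
            self._cache[fun.__name__] = (seqs, out)
            return out

        return wrapper


class Pipeline(Component):
    @Component.cache
    def cached_step(self, seqs: list[float]) -> list[float]:
        return [x + 1 for x in seqs]

    def uncached_step(self, seqs: list[float]) -> list[float]:
        return [x + 1 for x in seqs]  # new list on every call

    @Component.cache
    def downstream(self, seqs: list[float]) -> list[float]:
        self.evaluations += 1
        return [x * 2 for x in seqs]


batch = [1.0, 2.0]

p = Pipeline()
for _ in range(3):
    p.downstream(p.cached_step(batch))
print(p.evaluations)  # 1

q = Pipeline()
for _ in range(3):
    q.downstream(q.uncached_step(batch))
print(q.evaluations)  # 3
```

The cache holds a reference to the last input, so an identity match cannot come from a reused address of a garbage-collected object.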