Model Training#
Interfaces and support for model training.
- class lenskit.training.TrainingOptions(retrain=True, device=None, rng=None, environment=<factory>, torch_profiler=None)#
Bases: object
Options and context settings that govern model training.
- Parameters:
- retrain: bool = True#
Whether the model should retrain if it is already trained. If False, the component is allowed to skip training if it is already trained.
In the common case of training pipelines, this flag is examined by lenskit.pipeline.Pipeline.train(): if it is False, that method skips training any components that are already trained. Custom training code that wishes to avoid retraining models should check Trainable.is_trained() instead of assuming that individual components will respect this flag.
Note
This division of responsibility reduces repetitive code: in ordinary LensKit use, implementing components is more common than writing logic that directly trains components (as opposed to pipelines). Making the training code responsible for skipping retraining, instead of requiring it of every component implementation, keeps individual implementations slightly simpler without requiring separate options classes for pipeline and component training.
Changed in version 2026.1: Added the is_trained() method that implementers must now also provide.
- device: str | None = None#
The device on which to train (e.g. 'cuda'). May be ignored if the model does not support the specified device.
- rng: RNGInput = None#
Random number generator to use for any randomness in the training process. This option accepts any SPEC 7-compatible random number generator specification; random_generator() will convert it into a NumPy Generator.
- environment: dict[str, str]#
Additional training environment variables to control training behavior. Variables and their meanings are defined by individual components. Variables in this option override system environment variables when fetched with envvar().
- torch_profiler: torch.profiler.profile | None = None#
Torch profiler for profiling the training process.
- step_profiler()#
Signal to active profiler(s) that a new step has completed.
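A plausible reading of this method, sketched without importing torch: step_profiler() forwards to the configured profiler's step() when one is present and is otherwise a no-op. The Options and FakeProfiler classes below are illustrative stand-ins, not lenskit or torch code.

```python
# Hypothetical stand-in for the profiler-related part of TrainingOptions.
class Options:
    def __init__(self, torch_profiler=None):
        self.torch_profiler = torch_profiler

    def step_profiler(self) -> None:
        # Assumed behavior: signal the active profiler, if any,
        # that a training step has completed.
        if self.torch_profiler is not None:
            self.torch_profiler.step()

# Fake profiler that just counts step() calls, for demonstration.
class FakeProfiler:
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1

opts = Options(FakeProfiler())
for _ in range(3):
    opts.step_profiler()      # one signal per completed training step
assert opts.torch_profiler.steps == 3
Options().step_profiler()     # no profiler configured: safely a no-op
```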
- random_generator(*, type: Literal['numpy'] = 'numpy') Generator#
- random_generator(*, type: Literal['torch']) torch.Generator
Obtain a random generator from the configured RNG or seed.
Note
Each call to this method will return a fresh generator from the same seed. Components should call it once at the beginning of their training processes.
- Parameters:
type (Literal['numpy', 'torch'])
- Return type:
np.random.Generator | torch.Generator
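A minimal sketch of the documented contract using NumPy directly (the helper below is illustrative, not the lenskit implementation): each call yields a fresh Generator from the same seed, so repeated calls are reproducible.

```python
import numpy as np

# Hypothetical helper mirroring TrainingOptions.random_generator():
# convert a SPEC 7-style RNG specification (seed, SeedSequence, or
# Generator) into a NumPy Generator.
def random_generator(rng_spec=None) -> np.random.Generator:
    return np.random.default_rng(rng_spec)

# Two generators built from the same seed produce identical streams,
# which is why a component should call this once at the start of
# training rather than once per random draw.
a = random_generator(42)
b = random_generator(42)
assert a.integers(0, 100, size=5).tolist() == b.integers(0, 100, size=5).tolist()
```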
- configured_device(*, gpu_default=False)#
Get the configured device, consulting environment variables and defaults if necessary. It looks for a device in the following order:
The device, if specified on this object.
The LK_DEVICE environment variable.
If CUDA is enabled and gpu_default is True, 'cuda'.
The CPU.
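The lookup order above can be expressed as a small resolver. This is a hypothetical re-implementation for illustration; the cuda_available parameter stands in for the actual CUDA check.

```python
import os

def configured_device(device=None, *, gpu_default=False, cuda_available=False):
    # Hypothetical resolver following the documented order.
    if device is not None:
        return device                      # 1. device set on the options
    env = os.environ.get("LK_DEVICE")
    if env:
        return env                         # 2. LK_DEVICE environment variable
    if gpu_default and cuda_available:
        return "cuda"                      # 3. CUDA default, when requested
    return "cpu"                           # 4. fall back to the CPU

os.environ.pop("LK_DEVICE", None)          # deterministic demo environment
assert configured_device("mps") == "mps"
assert configured_device(gpu_default=True, cuda_available=True) == "cuda"
assert configured_device() == "cpu"
```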
- class lenskit.training.Trainable(*args, **kwargs)#
Bases:
Bases: Protocol
Interface for components and objects that can learn parameters from training data. It supports training and checking if a component has already been trained. This protocol only captures the concept of trainability; most trainable components should have other properties and behaviors as well:
They are usually components (Component), with an appropriate __call__ method.
They should be pickleable.
- Stability:
- Full (see Stability Levels).
- train(data, options)#
Train the model to learn its parameters from a training dataset.
- Parameters:
data (Dataset) – The training dataset.
options (TrainingOptions) – The training options.
- Return type:
None
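A component satisfying the shape of this protocol might look like the sketch below. The class and its simple dict-based data are hypothetical; a real component takes a lenskit Dataset and TrainingOptions.

```python
# Hypothetical component with the Trainable shape: train(),
# is_trained(), plus a __call__ as most components have.
class ItemMeanScorer:
    def __init__(self):
        self.means = None

    def is_trained(self) -> bool:
        return self.means is not None

    def train(self, data, options=None) -> None:
        # `data` here is a plain {item: [ratings]} dict standing in
        # for a lenskit Dataset.
        self.means = {i: sum(rs) / len(rs) for i, rs in data.items()}

    def __call__(self, item) -> float:
        return self.means.get(item, 0.0)

scorer = ItemMeanScorer()
assert not scorer.is_trained()
scorer.train({"a": [4, 5], "b": [2]})
assert scorer.is_trained()
assert scorer("a") == 4.5
```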
- class lenskit.training.UsesTrainer(config=None, **kwargs)#
Bases: Component, ABC, Trainable
Base class for models that implement Trainable via a ModelTrainer.
The component's configuration must have an epochs attribute defining the number of epochs to train.
- Stability:
- Full (see Stability Levels).
- Parameters:
config (Any)
kwargs (Any)
- property expected_training_epochs: int | None#
Get the number of training epochs expected to run. The default implementation looks for an epochs attribute on the configuration object (self.config).
- is_trained()#
Query if this component has already been trained.
- train(data, options=TrainingOptions(retrain=True, device=None, rng=None, environment={}, torch_profiler=None))#
Implementation of Trainable.train() that uses the model trainer.
- Parameters:
data (Dataset)
options (TrainingOptions)
- Return type:
None
- abstractmethod create_trainer(data, options)#
Create a model trainer to train this model.
- Parameters:
data (Dataset)
options (TrainingOptions)
- Return type:
ModelTrainer
- class lenskit.training.ModelTrainer#
Bases: ABC
Protocol implemented by iterative trainers for models. Models that implement UsesTrainer will return an object implementing this protocol from their create_trainer() method.
This protocol only defines the core aspects of training a model. Trainers should also implement ParameterContainer to allow training to be checkpointed and resumed. It is also a good idea for the trainer to be pickleable, but the parameter container interface is the primary mechanism for checkpointing.
- Stability:
- Full (see Stability Levels).
- abstractmethod train_epoch()#
Perform one epoch of the training process, optionally returning metrics on the training behavior. After each training iteration, the model must be usable.
- finalize()#
Finish the training process, cleaning up any unneeded data structures and doing any finalization steps to the model.
The default implementation does nothing.
- Return type:
None
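Putting the pieces together, the epoch-driven contract might be exercised as follows. The MeanTrainer and run_training names are illustrative sketches, not the lenskit implementation; a real UsesTrainer.train() drives the trainer for config.epochs epochs.

```python
# Hypothetical iterative trainer: each train_epoch() moves the model
# toward the data mean and leaves it usable, as the protocol requires.
class MeanTrainer:
    def __init__(self, data):
        self.data = data
        self.estimate = 0.0

    def train_epoch(self):
        target = sum(self.data) / len(self.data)
        self.estimate += 0.5 * (target - self.estimate)
        # Optionally return metrics on the training behavior.
        return {"error": abs(target - self.estimate)}

    def finalize(self):
        # Clean up data structures no longer needed after training.
        del self.data

def run_training(trainer, epochs: int):
    # Sketch of the loop a UsesTrainer-style train() would run.
    metrics = None
    for _ in range(epochs):
        metrics = trainer.train_epoch()
    trainer.finalize()
    return metrics

t = MeanTrainer([2.0, 4.0])
metrics = run_training(t, epochs=10)
assert abs(t.estimate - 3.0) < 0.01   # converged near the mean of the data
assert metrics["error"] < 0.01
```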