lenskit.graphs.lightgcn.LightGCNTrainer
- class lenskit.graphs.lightgcn.LightGCNTrainer(scorer, data, options)
Bases: lenskit.training.ModelTrainer
Protocol implemented by iterative trainers for models. Models that implement UsesTrainer will return an object implementing this protocol from their create_trainer() method.
This protocol only defines the core aspects of training a model. Trainers should also implement ParameterContainer to allow training to be checkpointed and resumed.
It is also a good idea for the trainer to be pickleable, but the parameter container interface is the primary mechanism for checkpointing.
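To make the checkpoint-and-resume idea concrete, here is a minimal, library-free sketch. `HalfStepTrainer`, `get_state()`, and `set_state()` are hypothetical names chosen for illustration; they play the role `ParameterContainer` plays for real trainers but do not reproduce its actual methods. The toy trainer moves one weight halfway toward a target each epoch, is checkpointed after two epochs, is resumed in a fresh instance, and is finally round-tripped through `pickle` as the secondary mechanism.

```python
from __future__ import annotations

import pickle
from typing import Any


class HalfStepTrainer:
    """Hypothetical toy trainer: each epoch moves `weight` halfway to `target`."""

    def __init__(self, target: float, lr: float = 0.5) -> None:
        self.target = target
        self.lr = lr
        self.weight = 0.0
        self.epochs_done = 0

    def train_epoch(self) -> dict[str, float]:
        # One gradient step on 0.5 * (weight - target)^2.
        self.weight += self.lr * (self.target - self.weight)
        self.epochs_done += 1
        return {"loss": 0.5 * (self.weight - self.target) ** 2}

    def finalize(self) -> None:
        pass  # nothing to clean up in this toy

    # Hypothetical checkpoint hooks (stand-ins for a ParameterContainer).
    def get_state(self) -> dict[str, Any]:
        return {"weight": self.weight, "epochs_done": self.epochs_done}

    def set_state(self, state: dict[str, Any]) -> None:
        self.weight = state["weight"]
        self.epochs_done = state["epochs_done"]


trainer = HalfStepTrainer(target=8.0)
for _ in range(2):
    trainer.train_epoch()            # weight: 0.0 -> 4.0 -> 6.0
checkpoint = trainer.get_state()

resumed = HalfStepTrainer(target=8.0)
resumed.set_state(checkpoint)        # pick up where the first run stopped
resumed.train_epoch()                # weight: 6.0 -> 7.0

# Secondary path: pickle the whole trainer object.
clone = pickle.loads(pickle.dumps(resumed))
```

Only learned state goes into the checkpoint here; hyperparameters are supplied again when the fresh trainer is constructed.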
- Stability:
- Full (see Stability Levels).
- Parameters:
scorer (LightGCNScorer)
data (lenskit.data.Dataset)
options (lenskit.training.TrainingOptions)
- scorer: LightGCNScorer
- data: lenskit.data.Dataset
- options: lenskit.training.TrainingOptions
- model: torch_geometric.nn.LightGCN
- edges: torch.Tensor
- optimizer: torch.optim.Optimizer
- train_epoch()
Perform one epoch of the training process, optionally returning metrics on the training behavior. After each training iteration, the model must be usable.
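Because the model must be usable after every call to `train_epoch()`, a caller can run any number of epochs, inspect the optional metrics, and stop early. The sketch below is a generic driver written under that assumption; `EpochTrainer`, `fit`, `HalvingTrainer`, and the `"loss"` metric key are hypothetical and are not part of lenskit.

```python
from __future__ import annotations

from typing import Optional, Protocol


class EpochTrainer(Protocol):
    """Hypothetical minimal protocol covering the two trainer methods documented here."""

    def train_epoch(self) -> Optional[dict[str, float]]: ...

    def finalize(self) -> None: ...


def fit(trainer: EpochTrainer, max_epochs: int, tol: float = 0.0) -> list[dict[str, float]]:
    """Run up to max_epochs epochs, stopping early once the reported loss is <= tol."""
    history: list[dict[str, float]] = []
    for _ in range(max_epochs):
        metrics = trainer.train_epoch()
        if metrics is not None:  # returning metrics is optional
            history.append(metrics)
            if metrics.get("loss", float("inf")) <= tol:
                break  # stopping here is safe: the model is usable after any epoch
    trainer.finalize()
    return history


class HalvingTrainer:
    """Toy trainer whose loss halves every epoch."""

    def __init__(self) -> None:
        self.loss = 1.0
        self.finalized = False

    def train_epoch(self) -> dict[str, float]:
        self.loss /= 2
        return {"loss": self.loss}

    def finalize(self) -> None:
        self.finalized = True


trainer = HalvingTrainer()
history = fit(trainer, max_epochs=10, tol=0.1)  # losses 0.5, 0.25, 0.125, 0.0625
```

Calling `finalize()` once, after the loop, matches its role as the step that ends training rather than a per-epoch hook.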
- finalize()
Finish the training process, cleaning up any unneeded data structures and doing any finalization steps to the model.
The default implementation does nothing.
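As a sketch of what an overriding `finalize()` might do, the toy trainer below keeps a training-only edge list and optimizer state next to its model and drops them once training ends, leaving only what scoring needs. The attribute names echo the `edges` and `optimizer` attributes listed above, but `BaseTrainer` and `ToyGraphTrainer` are hypothetical and do not reproduce `LightGCNTrainer`'s actual implementation.

```python
from __future__ import annotations

from typing import Optional


class BaseTrainer:
    def train_epoch(self) -> Optional[dict[str, float]]:
        raise NotImplementedError

    def finalize(self) -> None:
        """Default hook: nothing to clean up."""


class ToyGraphTrainer(BaseTrainer):
    """Hypothetical trainer holding training-only structures next to the model."""

    def __init__(self, edges: list[tuple[int, int]]) -> None:
        self.model = {"n_edges_seen": 0}  # stands in for the trained model
        self.edges: Optional[list[tuple[int, int]]] = edges  # training-only data
        self.optimizer: Optional[dict[str, float]] = {"lr": 0.01}  # training-only state

    def train_epoch(self) -> dict[str, float]:
        if self.edges is None:
            raise RuntimeError("cannot train after finalize()")
        self.model["n_edges_seen"] += len(self.edges)
        return {"edges": float(len(self.edges))}

    def finalize(self) -> None:
        # Release structures only needed during training; keep the model.
        self.edges = None
        self.optimizer = None


trainer = ToyGraphTrainer([(0, 1), (1, 2), (2, 0)])
trainer.train_epoch()
trainer.finalize()
```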
- abstractmethod batch_loss(mb_edges, scores)
- Parameters:
mb_edges (torch.Tensor)
scores (torch.Tensor)
- Return type: