# Source code for ml4chem.atomistic.models.base

from ml4chem.atomistic import Potentials
from abc import ABC, abstractmethod


class DeepLearningModel(ABC):
    """Abstract interface for ml4chem deep learning models.

    Concrete subclasses must implement ``name``, ``__init__``,
    ``prepare_model`` and ``forward``. Because every one of them is abstract,
    a subclass missing any of them cannot be instantiated.
    """

    # ``name`` takes ``cls``, so it is declared as a classmethod.
    # ``abstractmethod`` must be the innermost decorator when combined with
    # ``classmethod`` (see the ``abc`` module documentation).
    @classmethod
    @abstractmethod
    def name(cls):
        """Return name of the class.

        The base implementation returns the ``NAME`` class attribute.
        Subclasses must override this method and may delegate to the base
        implementation via ``super().name()`` (this assumes the subclass
        defines ``NAME``).
        """
        return cls.NAME

    @abstractmethod
    def __init__(self, **kwargs):
        """Arguments needed to instantiate the model."""
        pass

    @abstractmethod
    def prepare_model(self, **kwargs):
        """Prepare model for training or inference."""
        pass

    @abstractmethod
    def forward(self, X):
        """Forward propagation pass."""
        pass
class DeepLearningTrainer(ABC):
    """Base class for ml4chem deep learning trainers.

    Provides helpers shared by trainer implementations; currently a
    periodic checkpoint saver.
    """

    def checkpoint_save(self, epoch, model, label=None, checkpoint=None, path=""):
        """Checkpoint saver

        A method that saves the checkpoint of a model during training by
        delegating to ``Potentials.save``.

        Parameters
        ----------
        epoch : int
            Epoch number.
        model : object
            A DeepLearning object.
        label : str, optional
            String with checkpoint label, by default None. The saved label is
            ``f"{label}-checkpoint-{epoch}"`` when given, otherwise
            ``f"checkpoint-{epoch}"``.
        checkpoint : int, optional
            Checkpoint frequency. If set to 100, the model is saved every
            100 epochs (whenever ``epoch % checkpoint == 0``). Use -1 to
            save at every epoch. Default is None, which disables
            checkpointing.
        path : str, optional
            Path to save the checkpoint, by default "".

        Raises
        ------
        ValueError
            If ``checkpoint`` is 0, which is not a valid frequency.
        """
        if checkpoint is None:
            # Checkpointing disabled.
            return

        if checkpoint == 0:
            # Guard against ``epoch % 0``, which would otherwise raise an
            # opaque ZeroDivisionError.
            raise ValueError(
                "checkpoint must be a non-zero integer, -1, or None; got 0."
            )

        # -1 means "save every epoch"; otherwise only save on multiples.
        if checkpoint != -1 and epoch % checkpoint != 0:
            return

        if label is None:
            label = f"checkpoint-{epoch}"
        else:
            label = f"{label}-checkpoint-{epoch}"

        Potentials.save(model=model, label=label, path=path)