diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index 7fc67b1650d896212f1dccac89dac18cb5c00c8e..61d1a960449a35720b12cfddaa5fc0f4b4720d00 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -41,13 +41,25 @@ class AbstractTraining: implementations can focus on the computations per batch and not iterating over the dataset, storing snapshots, etc. - :param epochs: the number of epochs to train - :param dataset: the dataset to use for training - :param batch_size: the batch size used for training - :param device: the device on which to perform computations (cpu or cuda) - :param snapshot_dir: the directory where snapshots are stored - :param snapshot_epoch: at each of these epochs a snapshot is stored - :param logger: optional logger to print results + Parameters + ---------- + epochs: int + the number of epochs to train + dataset: MuMapDataset + the dataset to use for training + batch_size: int + the batch size used for training + device: torch.device + the device on which to perform computations (cpu or cuda) + snapshot_dir: str + the directory where snapshots are stored + snapshot_epoch: int + at each of these epochs a snapshot is stored + early_stopping: int, optional + if defined, training is stopped if the validation loss did not improve + for this many epochs + logger: Logger, optional + optional logger to print results """ def __init__( @@ -56,9 +68,9 @@ class AbstractTraining: dataset: MuMapDataset, batch_size: int, device: torch.device, - early_stopping: Optional[int], snapshot_dir: str, snapshot_epoch: int, + early_stopping: Optional[int], logger: Optional[Logger], ): self.epochs = epochs @@ -215,7 +227,10 @@ class AbstractTraining: """ Store snapshots of all models. - :param prefix: prefix for all stored snapshot files + Parameters + ---------- + prefix: str + prefix for all stored snapshot files """ for param in self.training_params: snapshot_file = os.path.join( @@ -255,9 +270,17 @@ class AbstractTraining: """ Implementation of training a single batch. - :param inputs: batch of input data - :param targets: batch of target data - :return: a number representing the loss + Parameters + ---------- + inputs: torch.Tensor + batch of input data + targets: torch.Tensor + batch of target data + + Returns + ------- + float + a number representing the loss """ return 0 @@ -265,8 +288,16 @@ class AbstractTraining: """ Implementation of evaluating a single batch. - :param inputs: batch of input data - :param targets: batch of target data - :return: a number representing the loss + Parameters + ---------- + inputs: torch.Tensor + batch of input data + targets: torch.Tensor + batch of target data + + Returns + ------- + float + a number representing the loss """ return 0