Skip to content
Snippets Groups Projects
Commit 77851a54 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update comment style in training lib

parent 6029718e
No related branches found
No related tags found
No related merge requests found
......@@ -41,13 +41,25 @@ class AbstractTraining:
implementations can focus on the computations per batch and not iterating over
the dataset, storing snapshots, etc.
:param epochs: the number of epochs to train
:param dataset: the dataset to use for training
:param batch_size: the batch size used for training
:param device: the device on which to perform computations (cpu or cuda)
:param snapshot_dir: the directory where snapshots are stored
:param snapshot_epoch: at each of these epochs a snapshot is stored
:param logger: optional logger to print results
Parameters
----------
epochs: int
the number of epochs to train
dataset: MuMapDataset
the dataset to use for training
batch_size: int
the batch size used for training
device: torch.device
the device on which to perform computations (cpu or cuda)
snapshot_dir: str
the directory where snapshots are stored
snapshot_epoch: int
at each of these epochs a snapshot is stored
early_stopping: int, optional
if defined, training is stopped if the validation loss did not improve
for this many epochs
logger: Logger, optional
optional logger to print results
"""
def __init__(
......@@ -56,9 +68,9 @@ class AbstractTraining:
dataset: MuMapDataset,
batch_size: int,
device: torch.device,
early_stopping: Optional[int],
snapshot_dir: str,
snapshot_epoch: int,
early_stopping: Optional[int],
logger: Optional[Logger],
):
self.epochs = epochs
......@@ -215,7 +227,10 @@ class AbstractTraining:
"""
Store snapshots of all models.
:param prefix: prefix for all stored snapshot files
Parameters
----------
prefix: str
prefix for all stored snapshot files
"""
for param in self.training_params:
snapshot_file = os.path.join(
......@@ -255,9 +270,17 @@ class AbstractTraining:
"""
Implementation of training a single batch.
:param inputs: batch of input data
:param targets: batch of target data
:return: a number representing the loss
Parameters
----------
inputs: torch.Tensor
batch of input data
targets: torch.Tensor
batch of target data
Returns
-------
float
a number representing the loss
"""
return 0
......@@ -265,8 +288,16 @@ class AbstractTraining:
"""
Implementation of evaluating a single batch.
:param inputs: batch of input data
:param targets: batch of target data
:return: a number representing the loss
Parameters
----------
inputs: torch.Tensor
batch of input data
targets: torch.Tensor
batch of target data
Returns
-------
float
a number representing the loss
"""
return 0
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment