Tamino Huxohl / mu-map / Commits / 77851a54

Commit 77851a54, authored 2 years ago by Tamino Huxohl
parent 6029718e

update comment style in training lib
Showing 1 changed file: mu_map/training/lib.py, with 46 additions and 15 deletions.
@@ -41,13 +41,25 @@ class AbstractTraining:
     implementations can focus on the computations per batch and not iterating over
     the dataset, storing snapshots, etc.

-    :param epochs: the number of epochs to train
-    :param dataset: the dataset to use for training
-    :param batch_size: the batch size used for training
-    :param device: the device on which to perform computations (cpu or cuda)
-    :param snapshot_dir: the directory where snapshots are stored
-    :param snapshot_epoch: at each of these epochs a snapshot is stored
-    :param logger: optional logger to print results
+    Parameters
+    ----------
+    epochs: int
+        the number of epochs to train
+    dataset: MuMapDataset
+        the dataset to use for training
+    batch_size: int
+        the batch size used for training
+    device: torch.device
+        the device on which to perform computations (cpu or cuda)
+    snapshot_dir: str
+        the directory where snapshots are stored
+    snapshot_epoch: int
+        at each of these epochs a snapshot is stored
+    early_stopping: int, optional
+        if defined, training is stopped if the validation loss did not improve
+        for this many epochs
+    logger: Logger, optional
+        optional logger to print results
     """

     def __init__(
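The `early_stopping` parameter documented above (stop when the validation loss has not improved for a given number of epochs) can be sketched as the following loop. This is an illustration only, assuming a hypothetical `run_epochs` helper and a `validate` callback; the class's actual training loop is not shown in this diff:

```python
from typing import Callable, Optional


def run_epochs(
    epochs: int,
    early_stopping: Optional[int],
    validate: Callable[[int], float],
) -> int:
    """Run up to `epochs` epochs; stop early when the validation loss
    has not improved for `early_stopping` consecutive epochs.
    Returns the number of the last epoch that was run."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch in range(1, epochs + 1):
        loss = validate(epoch)
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # early_stopping is optional: None disables the check entirely
        if early_stopping is not None and epochs_without_improvement >= early_stopping:
            return epoch
    return epochs
```

With `early_stopping=None` the loop always runs the full number of epochs, which matches the "if defined" wording in the docstring.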
@@ -56,9 +68,9 @@ class AbstractTraining:
         dataset: MuMapDataset,
         batch_size: int,
         device: torch.device,
-        early_stopping: Optional[int],
         snapshot_dir: str,
         snapshot_epoch: int,
+        early_stopping: Optional[int],
         logger: Optional[Logger],
     ):
         self.epochs = epochs
@@ -215,7 +227,10 @@ class AbstractTraining:
         """
         Store snapshots of all models.

-        :param prefix: prefix for all stored snapshot files
+        Parameters
+        ----------
+        prefix: str
+            prefix for all stored snapshot files
         """
         for param in self.training_params:
             snapshot_file = os.path.join(
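The snapshot loop above builds a file name per model with `os.path.join`; the rest of the line is truncated in this view. A minimal sketch of the pattern, where the `<prefix>_<name>.pth` naming and the `store_snapshot` helper are assumptions rather than the repository's actual convention:

```python
import os


def store_snapshot(snapshot_dir: str, prefix: str, name: str, state: bytes) -> str:
    """Write one snapshot file whose name is derived from a prefix.
    The `<prefix>_<name>.pth` naming is hypothetical; a real implementation
    would pass a model state dict to torch.save instead of raw bytes."""
    snapshot_file = os.path.join(snapshot_dir, f"{prefix}_{name}.pth")
    with open(snapshot_file, "wb") as f:
        f.write(state)
    return snapshot_file
```

Using an epoch number as the prefix (e.g. `"010"`) lets all models saved at the same epoch sort together on disk.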
@@ -255,9 +270,17 @@ class AbstractTraining:
         """
         Implementation of training a single batch.

-        :param inputs: batch of input data
-        :param targets: batch of target data
-        :return: a number representing the loss
+        Parameters
+        ----------
+        inputs: torch.Tensor
+            batch of input data
+        targets: torch.Tensor
+            batch of target data
+
+        Returns
+        -------
+        float
+            a number representing the loss
         """
         return 0
@@ -265,8 +288,16 @@ class AbstractTraining:
         """
         Implementation of evaluating a single batch.

-        :param inputs: batch of input data
-        :param targets: batch of target data
-        :return: a number representing the loss
+        Parameters
+        ----------
+        inputs: torch.Tensor
+            batch of input data
+        targets: torch.Tensor
+            batch of target data
+
+        Returns
+        -------
+        float
+            a number representing the loss
         """
         return 0
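Both batch methods are default implementations returning `0`, intended to be overridden by subclasses that compute an actual per-batch loss. A sketch of that pattern, using hypothetical method and class names (the real method names are not visible in this diff) and plain lists in place of `torch.Tensor` so the example stays self-contained:

```python
class AbstractTraining:
    """Stand-in for the abstract class in the diff: batch methods return a loss."""

    def train_batch(self, inputs, targets) -> float:
        return 0

    def eval_batch(self, inputs, targets) -> float:
        return 0


class MSETraining(AbstractTraining):
    """Hypothetical subclass computing a mean squared error per batch."""

    def train_batch(self, inputs, targets) -> float:
        # a real implementation would run the model, backpropagate, and
        # return loss.item(); here we just compute MSE over plain lists
        return sum((i - t) ** 2 for i, t in zip(inputs, targets)) / len(inputs)
```

Because the base class iterates over the dataset and stores snapshots itself, a subclass only has to supply these per-batch computations, which is exactly the division of labor the class docstring describes.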