Commit 2e31fc27, authored 2 years ago by Tamino Huxohl

    write doc for training lib

Parent: bc02d82e
No related branches, tags, or merge requests found.

1 changed file: mu_map/training/lib.py (82 additions, 4 deletions, +82 −4)
"""
Module functioning as a library for training related code.
"""
from
dataclasses
import
dataclass
from
logging
import
Logger
import
os
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
List
,
Optional
import
sys
import
torch
...
...
@@ -13,6 +16,12 @@ from mu_map.logging import get_logger
 @dataclass
 class TrainingParams:
+    """
+    Dataclass to bundle parameters related to the optimization of
+    a single model. This includes a name, the model itself and an
+    optimizer. Optionally, a learning rate scheduler can be added.
+    """
+
     name: str
     model: torch.nn.Module
     optimizer: torch.optim.Optimizer
 ...
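For illustration, a TrainingParams bundle for a single model might be assembled as follows. This is a minimal sketch: the network and learning rate are placeholders, and the optional learning rate scheduler field mentioned in the docstring (elided above) is assumed to default to None.

import torch

from mu_map.training.lib import TrainingParams

# Placeholder model; any torch.nn.Module works here.
model = torch.nn.Conv2d(1, 8, kernel_size=3, padding=1)

params = TrainingParams(
    name="Model",  # also used to label snapshot files (see store_snapshot below)
    model=model,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
)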
@@ -20,6 +29,26 @@ class TrainingParams:
 class AbstractTraining:
+    """
+    Abstract implementation of a training.
+
+    An implementation needs to overwrite the methods `_train_batch` and
+    `_eval_batch`. In addition, training parameters for all models need to
+    be added to the `self.training_params` list, as this list is used to put
+    models into the according mode and to step their optimizers and learning
+    rate schedulers. This abstract class implements a common training
+    procedure so that implementations can focus on the computations per
+    batch instead of on iterating over the dataset, storing snapshots, etc.
+
+    :param epochs: the number of epochs to train
+    :param dataset: the dataset to use for training
+    :param batch_size: the batch size used for training
+    :param device: the device on which to perform computations (cpu or cuda)
+    :param snapshot_dir: the directory where snapshots are stored
+    :param snapshot_epoch: at each of these epochs a snapshot is stored
+    :param logger: optional logger to print results
+    """
+
     def __init__(
         self,
         epochs: int,
         ...
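As a hedged sketch of the contract described in this docstring, a concrete subclass might look roughly like the following. The MSE loss, the model name, and the pass-through of constructor arguments are assumptions for illustration, not part of this commit.

import torch

from mu_map.training.lib import AbstractTraining, TrainingParams


class ExampleTraining(AbstractTraining):
    """Hypothetical subclass: supervised training of a single model."""

    def __init__(self, model: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)  # epochs, dataset, batch_size, device, ...
        self.model = model.to(self.device)
        self.criterion = torch.nn.MSELoss()
        # Register the model so that AbstractTraining switches its mode and
        # steps its optimizer.
        self.training_params.append(
            TrainingParams(
                name="Model",
                model=self.model,
                optimizer=torch.optim.Adam(self.model.parameters()),
            )
        )

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # Gradients are enabled and optimizers zeroed by the base class.
        loss = self.criterion(self.model(inputs), targets)
        loss.backward()  # optimizer.step() happens in _after_train_batch
        return loss.item()

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # Gradients are disabled by the base class during evaluation.
        return self.criterion(self.model(inputs), targets).item()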
@@ -42,7 +71,7 @@ class AbstractTraining:
             logger if logger is not None else get_logger(name=self.__class__.__name__)
         )

-        self.training_params = []
+        self.training_params: List[TrainingParams] = []

         self.data_loaders = dict(
             [
                 (
         ...
@@ -60,6 +89,16 @@ class AbstractTraining:
         )

     def run(self) -> float:
+        """
+        Implementation of a training run. For each epoch:
+          1. Train the model.
+          2. Evaluate the model on the validation split.
+          3. If applicable, store a snapshot.
+        The validation loss is also tracked so that a snapshot achieving a
+        minimal loss is kept.
+        """
         loss_val_min = sys.maxsize
         for epoch in range(1, self.epochs + 1):
             str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
             ...
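Putting it together, a hypothetical driver for the ExampleTraining sketch above could look like the following; the dataset, directory, and interval values are placeholders, and the keyword arguments simply mirror the documented constructor parameters. As a side note, the str_epoch f-string right-aligns the epoch number to the width of the final epoch, e.g. epoch 7 of 100 is rendered as "  7".

training = ExampleTraining(
    model=model,       # placeholder torch.nn.Module from the sketch above
    epochs=100,
    dataset=dataset,   # placeholder dataset yielding (input, target) pairs
    batch_size=32,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    snapshot_dir="snapshots",
    snapshot_epoch=10,
    logger=None,       # falls back to get_logger(name=<class name>)
)
loss_val = training.run()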
@@ -97,11 +136,19 @@ class AbstractTraining:
         for param in self.training_params:
             param.optimizer.step()

-    def _train_epoch(self):
+    def _train_epoch(self) -> float:
+        """
+        Implementation of the training in a single epoch.
+
+        :return: a number representing the training loss
+        """
+        # activate gradients
         torch.set_grad_enabled(True)

+        # set models into training mode
         for param in self.training_params:
             param.model.train()

+        # iterate over all batches in the training dataset
         loss = 0.0
         data_loader = self.data_loaders["train"]
         for i, (inputs, targets) in enumerate(data_loader):
         ...
@@ -110,22 +157,33 @@ class AbstractTraining:
                 end="\r",
             )

+            # move data to the according device
             inputs = inputs.to(self.device)
             targets = targets.to(self.device)

+            # zero grad optimizers
             for param in self.training_params:
                 param.optimizer.zero_grad()

             loss = loss + self._train_batch(inputs, targets)

+            # step optimizers
             self._after_train_batch()
         return loss / len(data_loader)

-    def _eval_epoch(self):
+    def _eval_epoch(self) -> float:
+        """
+        Implementation of the evaluation in a single epoch.
+
+        :return: a number representing the validation loss
+        """
+        # deactivate gradients
         torch.set_grad_enabled(False)

+        # set models into evaluation mode
         for param in self.training_params:
             param.model.eval()

+        # iterate over all batches in the validation dataset
         loss = 0.0
         data_loader = self.data_loaders["validation"]
         for i, (inputs, targets) in enumerate(data_loader):
         ...
@@ -134,6 +192,7 @@ class AbstractTraining:
                 end="\r",
             )

+            # move data to the according device
             inputs = inputs.to(self.device)
             targets = targets.to(self.device)
             ...
@@ -141,6 +200,11 @@ class AbstractTraining:
         return loss / len(data_loader)

     def store_snapshot(self, prefix: str):
+        """
+        Store snapshots of all models.
+
+        :param prefix: prefix for all stored snapshot files
+        """
         for param in self.training_params:
             snapshot_file = os.path.join(
                 self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
             ...
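Since store_snapshot writes plain state dicts to files named {prefix}_{name.lower()}.pth, restoring one might look like this sketch; the file name is a placeholder following that pattern, and the map_location choice is an assumption.

import torch

# Hypothetical restore of a snapshot written by store_snapshot for a
# TrainingParams registered under name="Model".
snapshot_file = "snapshots/100_model.pth"  # placeholder prefix "100"
state_dict = torch.load(snapshot_file, map_location="cpu")
model.load_state_dict(state_dict)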
@@ -149,7 +213,21 @@ class AbstractTraining:
             torch.save(param.model.state_dict(), snapshot_file)

     def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
+        """
+        Implementation of training a single batch.
+
+        :param inputs: batch of input data
+        :param targets: batch of target data
+        :return: a number representing the loss
+        """
         return 0

     def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
+        """
+        Implementation of evaluating a single batch.
+
+        :param inputs: batch of input data
+        :param targets: batch of target data
+        :return: a number representing the loss
+        """
         return 0