Skip to content
Snippets Groups Projects
Commit 9a9d4100 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

renaming of distance training class

parent f0a65512
No related branches found
No related tags found
No related merge requests found
import os
from typing import Dict
from logging import Logger
from typing import Optional
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger
from mu_map.training.lib import TrainingParams, AbstractTraining
from mu_map.training.loss import WeightedLoss
class Training(AbstractTraining):
class DistanceTraining(AbstractTraining):
def __init__(
self,
epochs: int,
......@@ -20,7 +19,7 @@ class Training(AbstractTraining):
snapshot_epoch: int,
params: TrainingParams,
loss_func: WeightedLoss,
logger,
logger: Optional[Logger] = None,
):
super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger)
self.training_params.append(params)
......@@ -42,6 +41,7 @@ class Training(AbstractTraining):
if __name__ == "__main__":
import argparse
import os
import random
import sys
......@@ -74,7 +74,7 @@ if __name__ == "__main__":
parser.add_argument(
"--dataset_dir",
type=str,
default="data/initial/",
default="data/second/",
help="the directory where the dataset for training is found",
)
parser.add_argument(
......@@ -240,7 +240,7 @@ if __name__ == "__main__":
criterion = WeightedLoss.from_str(args.loss_func)
training = Training(
training = DistanceTraining(
epochs=args.epochs,
dataset=dataset,
batch_size=args.batch_size,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment