From e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 Mon Sep 17 00:00:00 2001
From: aktersnurra
Date: Tue, 8 Sep 2020 23:14:23 +0200
Subject: IAM datasets implemented.

---
 src/training/trainer/callbacks/lr_schedulers.py | 52 +++++++++++++++++++++++++
 1 file changed, 52 insertions(+)

(limited to 'src/training/trainer/callbacks/lr_schedulers.py')

diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index ba2226a..bb41d2d 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,6 +1,7 @@
 """Callbacks for learning rate schedulers."""
 from typing import Callable, Dict, List, Optional, Type
 
+from torch.optim.swa_utils import update_bn
 from training.trainer.callbacks import Callback
 from text_recognizer.models import Model
 
@@ -95,3 +96,54 @@ class OneCycleLR(Callback):
     def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
         """Takes a step at the end of every training batch."""
         self.lr_scheduler.step()
+
+
+class CosineAnnealingLR(Callback):
+    """Callback for Cosine Annealing."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every epoch."""
+        self.lr_scheduler.step()
+
+
+class SWA(Callback):
+    """Stochastic Weight Averaging callback."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.swa_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.swa_start = self.model.swa_start
+        self.swa_scheduler = self.model.lr_scheduler
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
+        """Takes a step at the end of every epoch."""
+        if epoch > self.swa_start:
+            self.model.swa_network.update_parameters(self.model.network)
+            self.swa_scheduler.step()
+        else:
+            self.lr_scheduler.step()
+
+    def on_fit_end(self) -> None:
+        """Update batch norm statistics for the swa model at the end of training."""
+        if self.model.swa_network:
+            update_bn(
+                self.model.val_dataloader(),
+                self.model.swa_network,
+                device=self.model.device,
+            )
--
cgit v1.2.3-70-g09d2
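
The SWA callback in this patch relies on model-side attributes (swa_network, swa_start, the scheduler, val_dataloader and device) that are defined elsewhere in the repository. Below is a minimal standalone sketch of how that plumbing could look with torch.optim.swa_utils; the network, optimizer, dataloader and hyperparameter values are illustrative assumptions, not part of this commit.

# Sketch of the SWA pieces the callback expects (assumed names and values, not from this commit).
import torch
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

network = torch.nn.Sequential(torch.nn.Linear(784, 80), torch.nn.ReLU(), torch.nn.Linear(80, 10))
optimizer = torch.optim.SGD(network.parameters(), lr=0.1)

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
swa_scheduler = SWALR(optimizer, swa_lr=0.05)  # scheduler used once averaging starts
swa_network = AveragedModel(network)           # keeps the running weight average
swa_start = 75                                 # epoch at which the callback switches schedulers

loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(256, 784), torch.randint(0, 10, (256,))),
    batch_size=32,
)

for epoch in range(100):
    for x, y in loader:
        optimizer.zero_grad()
        torch.nn.functional.cross_entropy(network(x), y).backward()
        optimizer.step()
    # Mirrors SWA.on_epoch_end in the patch.
    if epoch > swa_start:
        swa_network.update_parameters(network)
        swa_scheduler.step()
    else:
        lr_scheduler.step()

# Mirrors SWA.on_fit_end: recompute batch-norm statistics for the averaged network.
update_bn(loader, swa_network)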