From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- text_recognizer/data/base_data_module.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'text_recognizer/data/base_data_module.py') diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 408ae36..fd914b6 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,11 +1,12 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Dict, Tuple import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.base_dataset import BaseDataset @@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule): def __attrs_pre_init__(self) -> None: super().__init__() + mapping: AbstractMapping = attr.ib() batch_size: int = attr.ib(default=16) num_workers: int = attr.ib(default=0) + pin_memory: bool = attr.ib(default=True) # Placeholders data_train: BaseDataset = attr.ib(init=False, default=None) @@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule): data_test: BaseDataset = attr.ib(init=False, default=None) dims: Tuple[int, ...] = attr.ib(init=False, default=None) output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) - mapping: Any = attr.ib(init=False, default=None) - inverse_mapping: Dict[str, int] = attr.ib(init=False) @classmethod def data_dirname(cls) -> Path: @@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule): return { "input_dim": self.dims, "output_dims": self.output_dims, - "mapping": self.mapping, } def prepare_data(self) -> None: @@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule): shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) def val_dataloader(self) -> DataLoader: @@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule): shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) def test_dataloader(self) -> DataLoader: @@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule): shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) -- cgit v1.2.3-70-g09d2