From 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 22 Apr 2021 08:15:58 +0200 Subject: Fixed training script, able to train vqvae --- text_recognizer/data/iam_extended_paragraphs.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'text_recognizer/data/iam_extended_paragraphs.py') diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index d2529b4..2380660 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -10,18 +10,27 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs class IAMExtendedParagraphs(BaseDataModule): def __init__( self, - batch_size: int = 128, + batch_size: int = 16, num_workers: int = 0, train_fraction: float = 0.8, augment: bool = True, + word_pieces: bool = False, ) -> None: super().__init__(batch_size, num_workers) self.iam_paragraphs = IAMParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, + word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - batch_size, num_workers, train_fraction, augment, + batch_size, + num_workers, + train_fraction, + augment, + word_pieces, ) self.dims = self.iam_paragraphs.dims -- cgit v1.2.3-70-g09d2