From 4e44486aa0e87459bed4b0fe423b16e59c76c1a0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 Oct 2022 03:25:28 +0200 Subject: Move stems to transforms --- text_recognizer/data/transforms/paragraph.py | 66 ++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 text_recognizer/data/transforms/paragraph.py (limited to 'text_recognizer/data/transforms/paragraph.py') diff --git a/text_recognizer/data/transforms/paragraph.py b/text_recognizer/data/transforms/paragraph.py new file mode 100644 index 0000000..39e1e59 --- /dev/null +++ b/text_recognizer/data/transforms/paragraph.py @@ -0,0 +1,66 @@ +"""Iam paragraph stem class.""" +import torchvision.transforms as T + +import text_recognizer.metadata.iam_paragraphs as metadata +from text_recognizer.data.stems.image import ImageStem + + +IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH +IMAGE_SHAPE = metadata.IMAGE_SHAPE + +MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH + + +class ParagraphStem(ImageStem): + """A stem for handling images that contain a paragraph of text.""" + + def __init__( + self, + augment=False, + color_jitter_kwargs=None, + random_affine_kwargs=None, + random_perspective_kwargs=None, + gaussian_blur_kwargs=None, + sharpness_kwargs=None, + ): + super().__init__() + + if not augment: + self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)]) + else: + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 3, + "shear": 6, + "scale": (0.95, 1), + "interpolation": T.InterpolationMode.BILINEAR, + } + if random_perspective_kwargs is None: + random_perspective_kwargs = { + "distortion_scale": 0.2, + "p": 0.5, + "interpolation": T.InterpolationMode.BILINEAR, + } + if gaussian_blur_kwargs is None: + gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} + if sharpness_kwargs is None: + sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} + + self.pil_transform = T.Compose( + [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomCrop( + size=IMAGE_SHAPE, + padding=None, + pad_if_needed=True, + fill=0, + padding_mode="constant", + ), + T.RandomAffine(**random_affine_kwargs), + T.RandomPerspective(**random_perspective_kwargs), + T.GaussianBlur(**gaussian_blur_kwargs), + T.RandomAdjustSharpness(**sharpness_kwargs), + ] + ) -- cgit v1.2.3-70-g09d2