From 4a54d7e690897dd6e6c719fb908fd371a44c2952 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 24 Jan 2021 22:14:17 +0100 Subject: Many updates, cool stuff on the way. --- src/text_recognizer/models/transformer_model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'src/text_recognizer/models/transformer_model.py') diff --git a/src/text_recognizer/models/transformer_model.py b/src/text_recognizer/models/transformer_model.py index 12e497f..3f63053 100644 --- a/src/text_recognizer/models/transformer_model.py +++ b/src/text_recognizer/models/transformer_model.py @@ -6,9 +6,9 @@ import torch from torch import nn from torch import Tensor from torch.utils.data import Dataset -from torchvision.transforms import ToTensor from text_recognizer.datasets import EmnistMapper +import text_recognizer.datasets.transforms as transforms from text_recognizer.models.base import Model from text_recognizer.networks import greedy_decoder @@ -60,13 +60,19 @@ class TransformerModel(Model): eos_token=self.eos_token, lower=self.lower, ) - self.tensor_transform = ToTensor() - + self.tensor_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize(mean=[0.912], std=[0.168])] + ) self.softmax = nn.Softmax(dim=2) @torch.no_grad() def _generate_sentence(self, image: Tensor) -> Tuple[List, float]: src = self.network.extract_image_features(image) + + # Added for vqvae transformer. + if isinstance(src, Tuple): + src = src[0] + memory = self.network.encoder(src) confidence_of_predictions = [] -- cgit v1.2.3-70-g09d2