From 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 24 Feb 2021 22:00:29 +0100 Subject: updates --- src/text_recognizer/networks/cnn_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/text_recognizer/networks/cnn_transformer.py') diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 7133c26..a2d7926 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -112,11 +112,11 @@ class CNNTransformer(nn.Module): if self.max_pool is not None: src = self.max_pool(src) - if self.adaptive_pool is not None: + if self.adaptive_pool is not None and len(src.shape) == 4: src = rearrange(src, "b c h w -> b w c h") src = self.adaptive_pool(src) src = src.squeeze(3) - else: + elif len(src.shape) == 4: src = rearrange(src, "b c h w -> b (h w) c") b, t, _ = src.shape -- cgit v1.2.3-70-g09d2