diff options
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-02-24 22:00:29 +0100 |
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-02-24 22:00:29 +0100 |
| commit | 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 (patch) | |
| tree | 91dab598a94911e6147b996237e786dd47f11f2f /src/text_recognizer/networks/cnn_transformer.py | |
| parent | 4a54d7e690897dd6e6c719fb908fd371a44c2952 (diff) | |
updates
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
| -rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py index 7133c26..a2d7926 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/src/text_recognizer/networks/cnn_transformer.py @@ -112,11 +112,11 @@ class CNNTransformer(nn.Module): if self.max_pool is not None: src = self.max_pool(src) - if self.adaptive_pool is not None: + if self.adaptive_pool is not None and len(src.shape) == 4: src = rearrange(src, "b c h w -> b w c h") src = self.adaptive_pool(src) src = src.squeeze(3) - else: + elif len(src.shape) == 4: src = rearrange(src, "b c h w -> b (h w) c") b, t, _ = src.shape |