| author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-07 20:10:54 +0100 |
| commit | ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch) | |
| tree | afee959135416fe92cf6df377e84fb0a9e9714a0 /src/text_recognizer/networks/cnn_transformer.py | |
| parent | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff) | |
Minor updates.
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
| -rw-r--r-- | src/text_recognizer/networks/cnn_transformer.py | 47 |
1 files changed, 37 insertions, 10 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
 """A CNN-Transformer for image to text recognition."""
 from typing import Dict, Optional, Tuple

-from einops import rearrange
+from einops import rearrange, repeat
 import torch
 from torch import nn
 from torch import Tensor

 from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
 from text_recognizer.networks.util import configure_backbone


@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
         expansion_dim: int,
         dropout_rate: float,
         trg_pad_index: int,
+        max_len: int,
         backbone: str,
         backbone_args: Optional[Dict] = None,
         activation: str = "gelu",
     ) -> None:
         super().__init__()
         self.trg_pad_index = trg_pad_index
+        self.vocab_size = vocab_size
         self.backbone = configure_backbone(backbone, backbone_args)
-        self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
-        self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+        self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+        self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+        nn.init.normal_(self.character_embedding.weight, std=0.02)

         self.adaptive_pool = (
             nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
             activation,
         )

-        self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+        self.head = nn.Sequential(
+            # nn.Linear(hidden_dim, hidden_dim * 2),
+            # activation_function(activation),
+            nn.Linear(hidden_dim, vocab_size),
+        )

     def _create_trg_mask(self, trg: Tensor) -> Tensor:
         # Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
         else:
             src = rearrange(src, "b c h w -> b (w h) c")

-        src = self.position_encoding(src)
+        b, t, _ = src.shape
+
+        # Insert sos and eos token.
+        # sos_token = self.character_embedding(
+        #     torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+        # )
+        # eos_token = self.character_embedding(
+        #     torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+        # )
+
+        # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+        # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+        # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+        # src = torch.cat((sos_tokens, src), dim=1)
+        src += self.src_position_embedding[:, :t]

         return src

@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):
         """
         trg = self.character_embedding(trg.long())
-        trg = self.position_encoding(trg)
+        trg = self.trg_position_encoding(trg)
         return trg

-    def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+    def decode_image_features(
+        self, image_features: Tensor, trg: Optional[Tensor] = None
+    ) -> Tensor:
         """Takes images features from the backbone and decodes them with the transformer."""
         trg_mask = self._create_trg_mask(trg)
         trg = self.target_embedding(trg)
-        out = self.transformer(h, trg, trg_mask=trg_mask)
+        out = self.transformer(image_features, trg, trg_mask=trg_mask)
         logits = self.head(out)
         return logits

     def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
         """Forward pass with CNN transfomer."""
-        h = self.extract_image_features(x)
-        logits = self.decode_image_features(h, trg)
+        image_features = self.extract_image_features(x)
+        logits = self.decode_image_features(image_features, trg)
         return logits
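For context, the commit replaces the sinusoidal `PositionalEncoding` on the image features with a learned parameter `self.src_position_embedding` that is sliced to the current sequence length, while the target side keeps a sinusoidal `trg_position_encoding`. Below is a minimal standalone sketch of that learned-embedding pattern; the class and variable names are illustrative only and not part of the repository.

```python
import torch
from torch import nn


class LearnedPositionalEmbedding(nn.Module):
    """Adds a learned positional embedding to a (batch, seq_len, hidden_dim) tensor."""

    def __init__(self, max_len: int, hidden_dim: int) -> None:
        super().__init__()
        # One trainable vector per position, shape (1, max_len, hidden_dim)
        # so it broadcasts over the batch dimension.
        self.embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # seq_len must be <= max_len; only the first t position vectors are used.
        _, t, _ = x.shape
        return x + self.embedding[:, :t]


if __name__ == "__main__":
    pos = LearnedPositionalEmbedding(max_len=128, hidden_dim=256)
    src = torch.randn(4, 97, 256)  # e.g. a flattened CNN feature map
    print(pos(src).shape)  # torch.Size([4, 97, 256])
```

Because the parameter is sliced to `t`, any source sequence shorter than `max_len` reuses the same learned position vectors, which is why the constructor now needs the `max_len` argument.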