summary | refs | log | tree | commit | diff
path: root/src/text_recognizer/networks/cnn_transformer.py
diff options
context:
space:
mode:
author: aktersnurra <gustaf.rydholm@gmail.com> 2021-01-07 20:10:54 +0100
committer: aktersnurra <gustaf.rydholm@gmail.com> 2021-01-07 20:10:54 +0100
commit: ff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch)
tree: afee959135416fe92cf6df377e84fb0a9e9714a0 /src/text_recognizer/networks/cnn_transformer.py
parent: 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff)
Minor updates.
Diffstat (limited to 'src/text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- src/text_recognizer/networks/cnn_transformer.py 47
1 files changed, 37 insertions, 10 deletions
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index b2b74b3..caa73e3 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -1,12 +1,13 @@
"""A CNN-Transformer for image to text recognition."""
from typing import Dict, Optional, Tuple
-from einops import rearrange
+from einops import rearrange, repeat
import torch
from torch import nn
from torch import Tensor
from text_recognizer.networks.transformer import PositionalEncoding, Transformer
+from text_recognizer.networks.util import activation_function
from text_recognizer.networks.util import configure_backbone
@@ -24,15 +25,21 @@ class CNNTransformer(nn.Module):
expansion_dim: int,
dropout_rate: float,
trg_pad_index: int,
+ max_len: int,
backbone: str,
backbone_args: Optional[Dict] = None,
activation: str = "gelu",
) -> None:
super().__init__()
self.trg_pad_index = trg_pad_index
+ self.vocab_size = vocab_size
self.backbone = configure_backbone(backbone, backbone_args)
- self.character_embedding = nn.Embedding(vocab_size, hidden_dim)
- self.position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+ self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+
+ self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+ nn.init.normal_(self.character_embedding.weight, std=0.02)
self.adaptive_pool = (
nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
@@ -48,7 +55,11 @@ class CNNTransformer(nn.Module):
activation,
)
- self.head = nn.Sequential(nn.Linear(hidden_dim, vocab_size),)
+ self.head = nn.Sequential(
+ # nn.Linear(hidden_dim, hidden_dim * 2),
+ # activation_function(activation),
+ nn.Linear(hidden_dim, vocab_size),
+ )
def _create_trg_mask(self, trg: Tensor) -> Tensor:
# Move this outside the transformer.
@@ -96,7 +107,21 @@ class CNNTransformer(nn.Module):
else:
src = rearrange(src, "b c h w -> b (w h) c")
- src = self.position_encoding(src)
+ b, t, _ = src.shape
+
+ # Insert sos and eos token.
+ # sos_token = self.character_embedding(
+ # torch.Tensor([self.vocab_size - 2]).long().to(src.device)
+ # )
+ # eos_token = self.character_embedding(
+ # torch.Tensor([self.vocab_size - 1]).long().to(src.device)
+ # )
+
+ # sos_tokens = repeat(sos_token, "() h -> b h", b=b).unsqueeze(1)
+ # eos_tokens = repeat(eos_token, "() h -> b h", b=b).unsqueeze(1)
+ # src = torch.cat((sos_tokens, src, eos_tokens), dim=1)
+ # src = torch.cat((sos_tokens, src), dim=1)
+ src += self.src_position_embedding[:, :t]
return src
@@ -111,20 +136,22 @@ class CNNTransformer(nn.Module):
"""
trg = self.character_embedding(trg.long())
- trg = self.position_encoding(trg)
+ trg = self.trg_position_encoding(trg)
return trg
- def decode_image_features(self, h: Tensor, trg: Optional[Tensor] = None) -> Tensor:
+ def decode_image_features(
+ self, image_features: Tensor, trg: Optional[Tensor] = None
+ ) -> Tensor:
"""Takes images features from the backbone and decodes them with the transformer."""
trg_mask = self._create_trg_mask(trg)
trg = self.target_embedding(trg)
- out = self.transformer(h, trg, trg_mask=trg_mask)
+ out = self.transformer(image_features, trg, trg_mask=trg_mask)
logits = self.head(out)
return logits
def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
"""Forward pass with CNN transfomer."""
- h = self.extract_image_features(x)
- logits = self.decode_image_features(h, trg)
+ image_features = self.extract_image_features(x)
+ logits = self.decode_image_features(image_features, trg)
return logits