From 49ca6ade1a19f7f9c702171537fe4be0dfcda66d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 25 Aug 2023 23:19:14 +0200
Subject: Rename and add flash attention

---
 text_recognizer/network/transformer/encoder.py | 46 ++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)
 create mode 100644 text_recognizer/network/transformer/encoder.py

diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
new file mode 100644
index 0000000..ea4b0b3
--- /dev/null
+++ b/text_recognizer/network/transformer/encoder.py
@@ -0,0 +1,46 @@
+"""Transformer encoder module."""
+from torch import Tensor, nn
+
+from text_recognizer.network.transformer.attention import Attention
+from text_recognizer.network.transformer.ff import FeedForward
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        inner_dim: int,
+        heads: int,
+        dim_head: int,
+        depth: int,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)  # final norm, applied once after all blocks
+        self.layers = nn.ModuleList(
+            [
+                nn.ModuleList(
+                    [
+                        Attention(
+                            dim,
+                            heads,
+                            False,  # presumably a `causal` flag: encoder attention is bidirectional
+                            dim_head,
+                            dropout_rate,
+                        ),
+                        FeedForward(dim, inner_dim, dropout_rate),
+                    ]
+                )
+                for _ in range(depth)
+            ]
+        )
+
+    def forward(
+        self,
+        x: Tensor,
+    ) -> Tensor:
+        """Applies the encoder blocks to the input sequence."""
+        for self_attn, ff in self.layers:
+            x = x + self_attn(x)  # residual self-attention
+            x = x + ff(x)  # residual feed-forward
+        return self.norm(x)
--
cgit v1.2.3-70-g09d2
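
The commit subject mentions flash attention, but attention.py itself is not part of this diff. The sketch below is a hypothetical reconstruction of that module, not the repository's code: the constructor signature is inferred from the positional call Attention(dim, heads, False, dim_head, dropout_rate) in the encoder (assuming the third argument is a causal flag), and the flash path is reached through torch.nn.functional.scaled_dot_product_attention, which in PyTorch 2.x dispatches to a FlashAttention kernel when hardware and dtypes allow. The internal pre-LayerNorm is also an assumption, made so the bare residuals in Encoder.forward stay well-scaled.

import torch.nn.functional as F
from torch import Tensor, nn


class Attention(nn.Module):
    """Multi-head self-attention routed through PyTorch's fused SDPA kernel."""

    def __init__(
        self,
        dim: int,
        heads: int,
        causal: bool,
        dim_head: int,
        dropout_rate: float = 0.0,
    ) -> None:
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.causal = causal
        self.dropout_rate = dropout_rate
        self.norm = nn.LayerNorm(dim)  # pre-norm (assumed), matching the bare residuals in Encoder
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        b, n, _ = x.shape
        q, k, v = self.to_qkv(self.norm(x)).chunk(3, dim=-1)
        # Split heads: (b, n, heads * dim_head) -> (b, heads, n, dim_head).
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        # Dispatches to a FlashAttention kernel when available (PyTorch >= 2.0).
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.dropout_rate if self.training else 0.0,
            is_causal=self.causal,
        )
        # Merge heads back: (b, heads, n, dim_head) -> (b, n, heads * dim_head).
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)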
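
FeedForward (imported from ff.py) is likewise outside this diff. A minimal sketch under the same assumptions: the (dim, inner_dim, dropout_rate) signature is taken from the encoder's call site, while the GELU activation and pre-norm placement are guesses.

from torch import Tensor, nn


class FeedForward(nn.Module):
    """Position-wise MLP with an assumed pre-norm and GELU activation."""

    def __init__(self, dim: int, inner_dim: int, dropout_rate: float = 0.0) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),  # pre-norm (assumed), as in the Attention sketch
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(inner_dim, dim),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)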
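
Finally, a usage sketch for the new Encoder class itself; every hyperparameter below is illustrative rather than a value from the repository.

import torch

from text_recognizer.network.transformer.encoder import Encoder

encoder = Encoder(
    dim=256,          # model width
    inner_dim=1024,   # feed-forward hidden width
    heads=8,          # attention heads
    dim_head=32,      # width per head
    depth=4,          # number of (attention, feed-forward) blocks
    dropout_rate=0.1,
)

x = torch.randn(2, 100, 256)  # (batch, sequence length, dim)
out = encoder(x)
assert out.shape == (2, 100, 256)  # each residual block preserves the shape

Because every block is residual and the only free-standing LayerNorm sits after the loop, the output always matches the input shape, which makes the encoder easy to stack in front of a decoder.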