From 684da19a2ca83ee61011c37e36fa71b9eeb5ca6a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 11 Sep 2023 22:12:25 +0200 Subject: Update encoder/decoder attention and forward pass --- text_recognizer/network/transformer/encoder.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) (limited to 'text_recognizer/network/transformer/encoder.py') diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py index 328a40c..1728c61 100644 --- a/text_recognizer/network/transformer/encoder.py +++ b/text_recognizer/network/transformer/encoder.py @@ -2,16 +2,15 @@ from torch import Tensor, nn from .attention import Attention -from .ff import FeedForward class Encoder(nn.Module): def __init__( self, dim: int, - inner_dim: int, heads: int, dim_head: int, + ff_mult: int, depth: int, dropout_rate: float = 0.0, ) -> None: @@ -19,17 +18,15 @@ class Encoder(nn.Module): self.norm = nn.LayerNorm(dim) self.layers = nn.ModuleList( [ - nn.ModuleList( - [ - Attention( - dim, - heads, - False, - dim_head, - dropout_rate, - ), - FeedForward(dim, inner_dim, dropout_rate), - ] + Attention( + dim=dim, + heads=heads, + causal=False, + dim_head=dim_head, + ff_mult=ff_mult, + dropout_rate=dropout_rate, + use_flash=True, + norm_context=False, ) for _ in range(depth) ] @@ -40,7 +37,6 @@ class Encoder(nn.Module): x: Tensor, ) -> Tensor: """Applies decoder block on input signals.""" - for self_attn, ff in self.layers: + for self_attn in self.layers: x = x + self_attn(x) - x = x + ff(x) return self.norm(x) -- cgit v1.2.3-70-g09d2