From 01d6e5fc066969283df99c759609df441151e9c5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 6 Jun 2021 23:19:35 +0200 Subject: Working on fixing decoder transformer --- text_recognizer/networks/transformer/layers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) (limited to 'text_recognizer/networks/transformer/layers.py') diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index b2c703f..a44a525 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,8 +1,6 @@ """Generates the attention layer architecture.""" from functools import partial -from typing import Any, Dict, Optional, Type - -from click.types import Tuple +from typing import Any, Dict, Optional, Tuple, Type from torch import nn, Tensor @@ -30,6 +28,7 @@ class AttentionLayers(nn.Module): pre_norm: bool = True, ) -> None: super().__init__() + self.dim = dim attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) norm_fn = partial(norm_fn, dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) -- cgit v1.2.3-70-g09d2