Diffstat (limited to 'text_recognizer/networks/transformer/attention.py')
 text_recognizer/networks/transformer/attention.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 3df5333..fca260d 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -1,12 +1,10 @@
"""Implementes the attention module for the transformer."""
 from typing import Optional, Tuple
-from einops import rearrange
 import torch
-from torch import einsum
-from torch import nn
-from torch import Tensor
 import torch.nn.functional as F
+from einops import rearrange
+from torch import Tensor, einsum, nn

 from text_recognizer.networks.transformer.embeddings.rotary import (
     RotaryEmbedding,
 )
@@ -35,7 +33,7 @@ class Attention(nn.Module):
         self.dropout_rate = dropout_rate
         self.rotary_embedding = rotary_embedding

-        self.scale = self.dim ** -0.5
+        self.scale = self.dim**-0.5

         inner_dim = self.num_heads * self.dim_head
         self.to_q = nn.Linear(self.dim, inner_dim, bias=False)
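
For reference, the pieces this diff touches (the einops/einsum imports, scale, and to_q) combine into scaled dot-product attention roughly as in the sketch below. This is a hedged illustration, not the repository's actual forward pass: MiniAttention, to_kv, to_out, and the default head sizes are assumptions introduced here.

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor, einsum, nn


class MiniAttention(nn.Module):
    """Sketch of multi-head scaled dot-product attention (not the repo's class)."""

    def __init__(self, dim: int, num_heads: int = 8, dim_head: int = 64) -> None:
        super().__init__()
        self.num_heads = num_heads
        # Same scale definition (and Black's ** spelling) as in the diff above;
        # note the diff scales by the full dim rather than dim_head.
        self.scale = dim**-0.5
        inner_dim = num_heads * dim_head
        self.to_q = nn.Linear(dim, inner_dim, bias=False)  # as in the diff
        self.to_kv = nn.Linear(dim, 2 * inner_dim, bias=False)  # assumed k/v path
        self.to_out = nn.Linear(inner_dim, dim)  # assumed output projection

    def forward(self, x: Tensor) -> Tensor:
        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim=-1)
        # Split the packed inner dimension into heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self.num_heads) for t in (q, k, v)
        )
        # softmax(q @ k^T * scale) @ v, computed per head with einsum.
        energy = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
        attn = F.softmax(energy, dim=-1)
        out = einsum("b h i j, b h j d -> b h i d", attn, v)
        return self.to_out(rearrange(out, "b h n d -> b n (h d)"))

Usage: MiniAttention(dim=128)(torch.randn(1, 16, 128)) returns a tensor of shape (1, 16, 128).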