"""Normalization layers for transformers.

Copied from lucidrains:
  https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py

"""
import torch
from torch import Tensor, nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root mean square layer normalization with a per-head gain.

    Computes ``x / RMS(x) * gamma`` over the last dimension, where
    ``RMS(x) = sqrt(mean(x ** 2))``. Since ``RMS(x) = ||x||_2 / sqrt(dim)``,
    this is implemented as ``F.normalize(x, dim=-1) * sqrt(dim) * gamma``.

    Args:
        heads: Number of attention heads; one learnable gain per head.
        dim: Size of the normalized (last) dimension.
    """

    def __init__(self, heads: int, dim: int) -> None:
        super().__init__()
        # RMSNorm rescales the unit-normalized vector back up by sqrt(dim).
        # NOTE(review): the original had ``dim**-0.5``, which shrinks the
        # activations by a factor of ``dim`` instead of RMS-normalizing them;
        # the upstream x-transformers source uses ``dim ** 0.5``.
        self.scale = dim**0.5
        # Per-head gain, broadcast over the sequence dimension.
        # assumes x is (..., heads, seq, dim) — TODO confirm against callers
        self.gamma = nn.Parameter(torch.ones(heads, 1, dim))

    def forward(self, x: Tensor) -> Tensor:
        """Apply RMS normalization over the last dimension of ``x``."""
        return F.normalize(x, dim=-1) * self.scale * self.gamma