From b3fbfd72a8f647161685b28d20b4b61519d8a643 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Mon, 15 Apr 2024 21:49:51 +0200
Subject: Update transformer model

---
 text_recognizer/network/transformer/encoder.py | 4 ++++
 1 file changed, 4 insertions(+)

(limited to 'text_recognizer/network/transformer/encoder.py')

diff --git a/text_recognizer/network/transformer/encoder.py b/text_recognizer/network/transformer/encoder.py
index 1728c61..ce30372 100644
--- a/text_recognizer/network/transformer/encoder.py
+++ b/text_recognizer/network/transformer/encoder.py
@@ -13,6 +13,8 @@ class Encoder(nn.Module):
         ff_mult: int,
         depth: int,
         dropout_rate: float = 0.0,
+        use_rotary_emb: bool = False,
+        one_kv_head: bool = False,
     ) -> None:
         super().__init__()
         self.norm = nn.LayerNorm(dim)
@@ -27,6 +29,8 @@ class Encoder(nn.Module):
                     dropout_rate=dropout_rate,
                     use_flash=True,
                     norm_context=False,
+                    use_rotary_emb=use_rotary_emb,
+                    one_kv_head=one_kv_head,
                 )
                 for _ in range(depth)
             ]
--
cgit v1.2.3-70-g09d2