From 240f5e9f20032e82515fa66ce784619527d1041e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 8 Aug 2021 19:59:55 +0200 Subject: Add VQGAN and loss function --- text_recognizer/networks/vqvae/encoder.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'text_recognizer/networks/vqvae/encoder.py') diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index ad8f950..4a5c976 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,7 +11,14 @@ from text_recognizer.networks.vqvae.residual import Residual class Encoder(nn.Module): """A CNN encoder network.""" - def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None: + def __init__( + self, + in_channels: int, + hidden_dim: int, + channels_multipliers: List[int], + dropout_rate: float, + activation: str = "mish", + ) -> None: super().__init__() self.in_channels = in_channels self.hidden_dim = hidden_dim @@ -33,7 +40,7 @@ class Encoder(nn.Module): ] num_blocks = len(self.channels_multipliers) - channels_multipliers = (1, ) + self.channels_multipliers + channels_multipliers = (1,) + self.channels_multipliers activation_fn = activation_function(self.activation) for i in range(num_blocks): -- cgit v1.2.3-70-g09d2