From d691b548cd0b6fc4ea184d64261f633789fee021 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 7 Jan 2021 23:35:42 +0100 Subject: working on vq-vae --- src/text_recognizer/networks/vqvae/encoder.py | 64 +++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 src/text_recognizer/networks/vqvae/encoder.py (limited to 'src/text_recognizer/networks/vqvae/encoder.py') diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py new file mode 100644 index 0000000..60c4c43 --- /dev/null +++ b/src/text_recognizer/networks/vqvae/encoder.py @@ -0,0 +1,64 @@ +"""CNN encoder for the VQ-VAE.""" + +from typing import List, Optional, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer + + +class _ResidualBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> None: + super().__init__() + self.block = [ + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + activation, + nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), + ] + + if dropout is not None: + self.block.append(dropout) + + self.block = nn.Sequential(*self.block) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) + + +class Encoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + beta: float = 0.25, + activation: str = "elu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + pass + # if dropout_rate: + # if activation == "selu": + # dropout = nn.AlphaDropout(p=dropout_rate) + # else: + # dropout = nn.Dropout(p=dropout_rate) + # else: + # dropout = None + + def _build_encoder(self) -> nn.Sequential: + # TODO: Continue to implement encoder. + pass -- cgit v1.2.3-70-g09d2