From 3ab82ad36bce6fa698a13a029a0694b75a5947b7 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Fri, 6 Aug 2021 02:42:45 +0200 Subject: Split VQVAE into encoder/decoder; fix bug in wandb artifact code uploading --- text_recognizer/networks/conv_transformer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'text_recognizer/networks/conv_transformer.py') diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index f3ba49d..b1a101e 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -4,7 +4,6 @@ from typing import Tuple from torch import nn, Tensor -from text_recognizer.networks.encoders.efficientnet import EfficientNet from text_recognizer.networks.transformer.layers import Decoder from text_recognizer.networks.transformer.positional_encodings import ( PositionalEncoding, @@ -18,15 +17,17 @@ class ConvTransformer(nn.Module): def __init__( self, input_dims: Tuple[int, int, int], + encoder_dim: int, hidden_dim: int, dropout_rate: float, num_classes: int, pad_index: Tensor, - encoder: EfficientNet, + encoder: nn.Module, decoder: Decoder, ) -> None: super().__init__() self.input_dims = input_dims + self.encoder_dim = encoder_dim self.hidden_dim = hidden_dim self.dropout_rate = dropout_rate self.num_classes = num_classes @@ -38,7 +39,7 @@ class ConvTransformer(nn.Module): # positional encoding. self.latent_encoder = nn.Sequential( nn.Conv2d( - in_channels=self.encoder.out_channels, + in_channels=self.encoder_dim, out_channels=self.hidden_dim, kernel_size=1, ), -- cgit v1.2.3-70-g09d2