From 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 7 Dec 2020 22:54:04 +0100 Subject: Segmentation working! --- src/text_recognizer/networks/residual_network.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) (limited to 'src/text_recognizer/networks/residual_network.py') diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 6405192..e397224 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -7,7 +7,6 @@ import torch from torch import nn from torch import Tensor -from text_recognizer.networks.stn import SpatialTransformerNetwork from text_recognizer.networks.util import activation_function @@ -209,12 +208,10 @@ class ResidualNetworkEncoder(nn.Module): activation: str = "relu", block: Type[nn.Module] = BasicBlock, levels: int = 1, - stn: bool = False, *args, **kwargs ) -> None: super().__init__() - self.stn = SpatialTransformerNetwork() if stn else None self.block_sizes = ( block_sizes if isinstance(block_sizes, list) else [block_sizes] * levels ) @@ -231,7 +228,7 @@ class ResidualNetworkEncoder(nn.Module): ), nn.BatchNorm2d(self.block_sizes[0]), activation_function(self.activation), - nn.MaxPool2d(kernel_size=2, stride=2, padding=1), + # nn.MaxPool2d(kernel_size=2, stride=2, padding=1), ) self.blocks = self._configure_blocks(block) @@ -275,8 +272,6 @@ class ResidualNetworkEncoder(nn.Module): # If batch dimenstion is missing, it needs to be added. if len(x.shape) == 3: x = x.unsqueeze(0) - if self.stn is not None: - x = self.stn(x) x = self.gate(x) x = self.blocks(x) return x -- cgit v1.2.3-70-g09d2