From 3ab82ad36bce6fa698a13a029a0694b75a5947b7 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 6 Aug 2021 02:42:45 +0200
Subject: Fix VQVAE into en/decoder, bug in wandb artifact code uploading

---
 text_recognizer/networks/vqvae/attention.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

(limited to 'text_recognizer/networks/vqvae/attention.py')

diff --git a/text_recognizer/networks/vqvae/attention.py b/text_recognizer/networks/vqvae/attention.py
index 5a6b3ce..78a2cc9 100644
--- a/text_recognizer/networks/vqvae/attention.py
+++ b/text_recognizer/networks/vqvae/attention.py
@@ -7,7 +7,7 @@ import torch.nn.functional as F
 from text_recognizer.networks.vqvae.norm import Normalize
 
 
-@attr.s
+@attr.s(eq=False)
 class Attention(nn.Module):
     """Convolutional attention."""
 
@@ -63,11 +63,12 @@ class Attention(nn.Module):
         B, C, H, W = q.shape
         q = q.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
         k = k.reshape(B, C, H * W)  # [B, C, HW]
-        energy = torch.bmm(q, k) * (C ** -0.5)
+        energy = torch.bmm(q, k) * (int(C) ** -0.5)
         attention = F.softmax(energy, dim=2)
 
         # Compute attention to which values
-        v = v.reshape(B, C, H * W).permute(0, 2, 1)  # [B, HW, C]
+        v = v.reshape(B, C, H * W)
+        attention = attention.permute(0, 2, 1)  # [B, HW, HW]
         out = torch.bmm(v, attention)
         out = out.reshape(B, C, H, W)
         out = self.proj(out)
-- 
cgit v1.2.3-70-g09d2
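
Note: as a sanity check on the shapes involved, the snippet below replays the
attention math in the form this patch leaves it in. The sizes and the random
q/k/v tensors are stand-ins; in the actual Attention module they come from conv
projections of the normalized input, and the result is further passed through
self.proj, none of which is reproduced here.

    import torch
    import torch.nn.functional as F

    B, C, H, W = 2, 8, 4, 4
    q = torch.randn(B, C, H, W)
    k = torch.randn(B, C, H, W)
    v = torch.randn(B, C, H, W)

    q = q.reshape(B, C, H * W).permute(0, 2, 1)   # [B, HW, C]
    k = k.reshape(B, C, H * W)                    # [B, C, HW]
    energy = torch.bmm(q, k) * (int(C) ** -0.5)   # [B, HW, HW]
    attention = F.softmax(energy, dim=2)          # normalize over key positions

    # The fix: keep v as [B, C, HW] and transpose the attention map instead,
    # so each output position is a weighted sum over the value positions.
    v = v.reshape(B, C, H * W)                    # [B, C, HW]
    attention = attention.permute(0, 2, 1)        # [B, HW, HW], transposed
    out = torch.bmm(v, attention)                 # [B, C, HW]
    out = out.reshape(B, C, H, W)

    # After the transpose the weights sum to 1 over dim=1.
    assert torch.allclose(attention.sum(dim=1), torch.ones(B, H * W))
    print(out.shape)  # torch.Size([2, 8, 4, 4])

Before this change, v was flattened to [B, HW, C], which does not batch-multiply
against the [B, HW, HW] attention map; transposing the attention map and leaving
v as [B, C, HW] yields the intended per-position weighted sum of values.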