From b44de0e11281c723ec426f8bec8ca0897ecfe3ff Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 21 Nov 2021 21:34:53 +0100 Subject: Remove VQVAE stuff, did not work... --- text_recognizer/models/vqvae.py | 45 ----------------------------------------- 1 file changed, 45 deletions(-) delete mode 100644 text_recognizer/models/vqvae.py (limited to 'text_recognizer/models/vqvae.py') diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py deleted file mode 100644 index 4898852..0000000 --- a/text_recognizer/models/vqvae.py +++ /dev/null @@ -1,45 +0,0 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Tuple - -import attr -from torch import Tensor - -from text_recognizer.models.base import BaseLitModel - - -@attr.s(auto_attribs=True, eq=False) -class VQVAELitModel(BaseLitModel): - """A PyTorch Lightning model for transformer networks.""" - - commitment: float = attr.ib(default=0.25) - - def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - loss = loss + self.commitment * commitment_loss - self.log("train/commitment_loss", commitment_loss) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - self.log("val/commitment_loss", commitment_loss) - self.log("val/loss", loss, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - loss = loss + self.commitment * commitment_loss - self.log("test/commitment_loss", commitment_loss) - self.log("test/loss", loss) -- cgit v1.2.3-70-g09d2