From ff9a21d333f11a42e67c1963ed67de9c0fda87c9 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 7 Jan 2021 20:10:54 +0100 Subject: Minor updates. --- src/text_recognizer/networks/metrics.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) (limited to 'src/text_recognizer/networks/metrics.py') diff --git a/src/text_recognizer/networks/metrics.py b/src/text_recognizer/networks/metrics.py index af9adb5..ffad792 100644 --- a/src/text_recognizer/networks/metrics.py +++ b/src/text_recognizer/networks/metrics.py @@ -6,28 +6,13 @@ from torch import Tensor from text_recognizer.networks import greedy_decoder -def accuracy_ignore_pad( - output: Tensor, - target: Tensor, - pad_index: int = 79, - eos_index: int = 81, - seq_len: int = 97, -) -> float: - """Sets all predictions after eos to pad.""" - start_indices = torch.nonzero(target == eos_index, as_tuple=False).squeeze(1) - end_indices = torch.arange(seq_len, target.shape[0] + 1, seq_len) - for start, stop in zip(start_indices, end_indices): - output[start + 1 : stop] = pad_index - - return accuracy(output, target) - - -def accuracy(outputs: Tensor, labels: Tensor,) -> float: +def accuracy(outputs: Tensor, labels: Tensor, pad_index: int = 53) -> float: """Computes the accuracy. Args: outputs (Tensor): The output from the network. labels (Tensor): Ground truth labels. + pad_index (int): Padding index. Returns: float: The accuracy for the batch. @@ -36,6 +21,12 @@ def accuracy(outputs: Tensor, labels: Tensor,) -> float: _, predicted = torch.max(outputs, dim=-1) + # Mask out the pad tokens + mask = labels != pad_index + + predicted *= mask + labels *= mask + acc = (predicted == labels).sum().float() / labels.shape[0] acc = acc.item() return acc -- cgit v1.2.3-70-g09d2