summaryrefslogtreecommitdiff
path: root/src/text_recognizer/models/line_ctc_model.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
commit4d7713746eb936832e84852e90292936b933e87d (patch)
tree2b2519d1d2ce53d4e1390590f52018d55dadbc7c /src/text_recognizer/models/line_ctc_model.py
parent1b3b8073a19f939d18a0bb85247eb0d99284f7cc (diff)
Transformer added, many other changes.
Diffstat (limited to 'src/text_recognizer/models/line_ctc_model.py')
-rw-r--r--src/text_recognizer/models/line_ctc_model.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py
index 16eaed3..cdc2d8b 100644
--- a/src/text_recognizer/models/line_ctc_model.py
+++ b/src/text_recognizer/models/line_ctc_model.py
@@ -51,7 +51,7 @@ class LineCTCModel(Model):
self._mapper = EmnistMapper()
self.tensor_transform = ToTensor()
- def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor:
+ def criterion(self, output: Tensor, targets: Tensor) -> Tensor:
"""Computes the CTC loss.
Args:
@@ -82,11 +82,13 @@ class LineCTCModel(Model):
torch.Tensor(target_lengths).type(dtype=torch.long).to(self.device)
)
- return self.criterion(output, targets, input_lengths, target_lengths)
+ return self._criterion(output, targets, input_lengths, target_lengths)
@torch.no_grad()
def predict_on_image(self, image: Union[np.ndarray, Tensor]) -> Tuple[str, float]:
"""Predict on a single input."""
+ self.eval()
+
if image.dtype == np.uint8:
# Converts an image with range [0, 255] to a PyTorch Tensor with range [0, 1].
image = self.tensor_transform(image)
@@ -110,6 +112,6 @@ class LineCTCModel(Model):
log_probs, _ = log_probs.max(dim=2)
predicted_characters = "".join(raw_pred[0])
- confidence_of_prediction = torch.exp(log_probs.sum()).item()
+ confidence_of_prediction = torch.exp(-log_probs.sum()).item()
return predicted_characters, confidence_of_prediction