summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/losses.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
commit4d7713746eb936832e84852e90292936b933e87d (patch)
tree2b2519d1d2ce53d4e1390590f52018d55dadbc7c /src/text_recognizer/networks/losses.py
parent1b3b8073a19f939d18a0bb85247eb0d99284f7cc (diff)
Transfomer added, many other changes.
Diffstat (limited to 'src/text_recognizer/networks/losses.py')
-rw-r--r--src/text_recognizer/networks/losses.py31
1 files changed, 0 insertions, 31 deletions
diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py
deleted file mode 100644
index 73e0641..0000000
--- a/src/text_recognizer/networks/losses.py
+++ /dev/null
@@ -1,31 +0,0 @@
-"""Implementations of custom loss functions."""
-from pytorch_metric_learning import distances, losses, miners, reducers
-from torch import nn
-from torch import Tensor
-
-
-class EmbeddingLoss:
- """Metric loss for training encoders to produce information-rich latent embeddings."""
-
- def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
- self.distance = distances.CosineSimilarity()
- self.reducer = reducers.ThresholdReducer(low=0)
- self.loss_fn = losses.TripletMarginLoss(
- margin=margin, distance=self.distance, reducer=self.reducer
- )
- self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance)
-
- def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
- """Computes the metric loss for the embeddings based on their labels.
-
- Args:
- embeddings (Tensor): The laten vectors encoded by the network.
- labels (Tensor): Labels of the embeddings.
-
- Returns:
- Tensor: The metric loss for the embeddings.
-
- """
- hard_pairs = self.miner(embeddings, labels)
- loss = self.loss_fn(embeddings, labels, hard_pairs)
- return loss