diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-13 18:12:35 +0200 |
| commit | 0901bb8172fe56caa3eba9e4bf96ae0b164f9292 (patch) | |
| tree | ad1b5964af91a5982fed59715f058586cd28f60d /text_recognizer/networks/quantizer/kmeans.py | |
| parent | 7be90f5f101d7ace7ff07180950dac4c11086ec1 (diff) | |
Remove quantizer
Diffstat (limited to 'text_recognizer/networks/quantizer/kmeans.py')
| -rw-r--r-- | text_recognizer/networks/quantizer/kmeans.py | 32 |
1 files changed, 0 insertions, 32 deletions
diff --git a/text_recognizer/networks/quantizer/kmeans.py b/text_recognizer/networks/quantizer/kmeans.py deleted file mode 100644 index a34c381..0000000 --- a/text_recognizer/networks/quantizer/kmeans.py +++ /dev/null @@ -1,32 +0,0 @@ -"""K-means clustering for embeddings.""" -from typing import Tuple - -from einops import repeat -import torch -from torch import Tensor - -from text_recognizer.networks.quantizer.utils import norm, sample_vectors - - -def kmeans( - samples: Tensor, num_clusters: int, num_iters: int = 10 -) -> Tuple[Tensor, Tensor]: - """Compute k-means clusters.""" - D = samples.shape[-1] - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - dists = samples @ means.t() - buckets = dists.max(dim=-1).indices - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, D).type_as(samples) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=D), samples) - new_means /= bins_min_clamped[..., None] - new_means = norm(new_means) - means = torch.where(zero_mask[..., None], means, new_means) - - return means, bins |