From 46a1472d33d3a4180798492e819f2ec02bc3b1a3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 28 Mar 2021 22:02:24 +0200 Subject: Add refactor of iam lines --- text_recognizer/data/base_dataset.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'text_recognizer/data/base_dataset.py') diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index a9e9c24..d00daaf 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -71,3 +71,16 @@ def convert_strings_to_labels( for j, token in enumerate(tokens): labels[i, j] = mapping[token] return labels + + +def split_dataset( + dataset: BaseDataset, fraction: float, seed: int +) -> Tuple[BaseDataset, BaseDataset]: + """Split dataset into two parts with fraction * size and (1 - fraction) * size.""" + if fraction >= 1.0: + raise ValueError("Fraction cannot be larger greater or equal to 1.0.") + split_1 = int(fraction * len(dataset)) + split_2 = len(dataset) - split_1 + return torch.utils.data.random_split( + dataset, [split_1, split_2], generator=torch.Generator().manual_seed(seed) + ) -- cgit v1.2.3-70-g09d2