diff --git a/cobra/preprocessing/preprocessor.py b/cobra/preprocessing/preprocessor.py index 7e23d67..3cd94c8 100644 --- a/cobra/preprocessing/preprocessor.py +++ b/cobra/preprocessing/preprocessor.py @@ -14,6 +14,7 @@ import inspect from datetime import datetime import time +import math import logging from random import shuffle @@ -361,9 +362,9 @@ def train_selection_validation_split(data: pd.DataFrame, pd.DataFrame DataFrame with additional split column """ - if train_prop + selection_prop + validation_prop != 1.0: + if not math.isclose(train_prop + selection_prop + validation_prop, 1.0): raise ValueError("The sum of train_prop, selection_prop and " - "validation_prop cannot differ from 1.0") + "validation_prop must be 1.0.") if train_prop == 0.0: raise ValueError("train_prop cannot be zero!") diff --git a/tests/preprocessing/test_preprocessor.py b/tests/preprocessing/test_preprocessor.py index c2e4fcd..80f6d73 100644 --- a/tests/preprocessing/test_preprocessor.py +++ b/tests/preprocessing/test_preprocessor.py @@ -16,13 +16,19 @@ def does_not_raise(): class TestPreProcessor: - @pytest.mark.parametrize(("train_prop, selection_prop, " - "validation_prop, expected_sizes"), + @pytest.mark.parametrize("train_prop, selection_prop, validation_prop, " + "expected_sizes", [(0.6, 0.2, 0.2, {"train": 6, "selection": 2, "validation": 2}), (0.7, 0.3, 0.0, {"train": 7, - "selection": 3})]) + "selection": 3}), + # Error "The sum of train_prop, selection_prop and + # validation_prop must be 1.0." should not be + # raised: + (0.7, 0.2, 0.1, {"train": 7, + "selection": 2, + "validation": 1})]) def test_train_selection_validation_split(self, train_prop: float, selection_prop: float, validation_prop: float, @@ -50,7 +56,7 @@ def test_train_selection_validation_split(self, train_prop: float, def test_train_selection_validation_split_error_wrong_prop(self): error_msg = ("The sum of train_prop, selection_prop and " - "validation_prop cannot differ from 1.0") + "validation_prop must be 1.0.") train_prop = 0.7 selection_prop = 0.3