Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 25 additions & 72 deletions cobra/preprocessing/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import inspect
from datetime import datetime
import time
import logging
from random import shuffle

# third party imports
import pandas as pd
Expand All @@ -25,7 +27,6 @@
from cobra.preprocessing import TargetEncoder
from cobra.preprocessing import CategoricalDataProcessor

import logging
log = logging.getLogger(__name__)


Expand Down Expand Up @@ -338,103 +339,55 @@ def fit_transform(self, train_data: pd.DataFrame, continuous_vars: list,

@staticmethod
def train_selection_validation_split(data: pd.DataFrame,
target_column_name: str,
train_prop: float = 0.6,
selection_prop: float = 0.2,
validation_prop: float = 0.2,
stratify_split=True) -> pd.DataFrame:
"""Split dataset into train-selection-validation datasets and merge
them into one big DataFrame with an additional column "split"
indicating to which dataset the corresponding row belongs to.
validation_prop: float = 0.2)-> pd.DataFrame:
"""Adds `split` column with train/selection/validation values
to the dataset.

Parameters
----------
data : pd.DataFrame
Input dataset to split into train-selection and validation sets
target_column_name : str
Name of the target column
train_prop : float, optional
Percentage data to put in train set
selection_prop : float, optional
Percentage data to put in selection set
validation_prop : float, optional
Percentage data to put in validation set
stratify_split : bool, optional
Whether or not to stratify the train-test split

Returns
-------
pd.DataFrame
DataFrame with additional split column
"""

if train_prop + selection_prop + validation_prop != 1.0:
raise ValueError("The sum of train_prop, selection_prop and "
"validation_prop cannot differ from 1.0")

if train_prop == 0.0:
raise ValueError("train_prop cannot be zero!")

if selection_prop == 0.0:
raise ValueError("selection_prop cannot be zero!")

column_names = list(data.columns)

predictors = [col for col in column_names if col != target_column_name]

# for the first split, take sum of selection & validation pct as
# test pct
test_prop = selection_prop + validation_prop
# To further split our test set into selection + validation set,
# we have to modify validation pct because we only have test_prop of
# the data available anymore for further splitting!
validation_prop_modif = validation_prop / test_prop

X = data[predictors]
y = data[target_column_name]

stratify = None
if stratify_split:
stratify = y

X_train, X_test, y_train, y_test = train_test_split(
X, y,
test_size=test_prop,
random_state=42,
stratify=stratify
)

df_train = pd.DataFrame(X_train, columns=predictors)
df_train[target_column_name] = y_train
df_train["split"] = "train"

# If there is no validation percentage, return train-selection sets
# only
if validation_prop == 0.0:
df_selection = pd.DataFrame(X_test, columns=predictors)
df_selection[target_column_name] = y_test
df_selection["split"] = "selection"

return (pd.concat([df_train, df_selection])
.reset_index(drop=True))

if stratify_split:
stratify = y_test

X_sel, X_val, y_sel, y_val = train_test_split(
X_test, y_test,
test_size=validation_prop_modif,
random_state=42,
stratify=stratify
)

df_selection = pd.DataFrame(X_sel, columns=predictors)
df_selection[target_column_name] = y_sel
df_selection["split"] = "selection"

df_validation = pd.DataFrame(X_val, columns=predictors)
df_validation[target_column_name] = y_val
df_validation["split"] = "validation"

return (pd.concat([df_train, df_selection, df_validation])
.reset_index(drop=True))
nrows = data.shape[0]
size_train = int(train_prop * nrows)
size_select = int(selection_prop * nrows)
size_valid = int(validation_prop * nrows)
correction = nrows - (size_train+size_select+size_valid)

split = ['train'] * size_train \
+ ['train'] * correction \
+ ['selection'] * size_select \
+ ['validation'] * size_valid

shuffle(split)

data['split'] = split

return data


def serialize_pipeline(self) -> dict:
"""Serialize the preprocessing pipeline by writing all its required
Expand Down
3 changes: 1 addition & 2 deletions docs/source/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ This will be taken care of by the ``PreProcessor`` class, which has a scikit-lea
# containing each of those values
basetable = preprocessor.train_selection_validation_split(
basetable,
target_column_name=target_column_name,
train_prop=0.6, selection_prop=0.2,
validation_prop=0.2)

Expand Down Expand Up @@ -222,4 +221,4 @@ Additionally, we can also compute the output needed to plot the so-called Predic
target_column_name=target_column_name,
preprocessed_predictors=predictor_list)
# Plot PIGs
plot_incidence(pig_tables, 'predictor_name', predictor_order)
plot_incidence(pig_tables, 'predictor_name', predictor_order)
13 changes: 3 additions & 10 deletions tests/preprocessing/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,13 @@ def test_train_selection_validation_split(self, train_prop: float,
data = pd.DataFrame(X, columns=[f"c{i+1}" for i in range(10)])
data.loc[:, "target"] = np.array([0] * 7 + [1] * 3)

# No stratified split here because sample size is to low to make
# it work. This feature is already well-tested in scikit-learn and
# needs no further testing here
actual = PreProcessor.train_selection_validation_split(data,
"target",
train_prop,
selection_prop,
validation_prop,
False)
validation_prop)

# check for the output schema
expected_schema = list(data.columns) + ["split"]
assert list(actual.columns) == expected_schema
assert list(actual.columns) == list(data.columns)

# check that total size of input & output is the same!
assert len(actual.index) == len(data.index)
Expand Down Expand Up @@ -79,10 +73,9 @@ def _test_train_selection_validation_split_error(self,
selection_prop: float,
error_msg: str):
df = pd.DataFrame()
cname = ""
with pytest.raises(ValueError, match=error_msg):
(PreProcessor
.train_selection_validation_split(df, cname,
.train_selection_validation_split(df,
train_prop=train_prop,
selection_prop=selection_prop,
validation_prop=0.1))
Expand Down