diff --git a/cobra/evaluation/pigs_tables.py b/cobra/evaluation/pigs_tables.py index 4c58eaa..6cca2d0 100644 --- a/cobra/evaluation/pigs_tables.py +++ b/cobra/evaluation/pigs_tables.py @@ -8,9 +8,9 @@ import cobra.utils as utils def generate_pig_tables(basetable: pd.DataFrame, - id_column_name: str, target_column_name: str, - preprocessed_predictors: list) -> pd.DataFrame: + preprocessed_predictors: list, + id_column_name: str = None) -> pd.DataFrame: """Compute PIG tables for all predictors in preprocessed_predictors. The output is a DataFrame with columns ``variable``, ``label``, @@ -20,35 +20,41 @@ def generate_pig_tables(basetable: pd.DataFrame, ---------- basetable : pd.DataFrame Basetable to compute PIG tables from. - id_column_name : str - Name of the basetable column containing the IDs of the basetable rows - (e.g. customernumber). target_column_name : str Name of the basetable column containing the target values to predict. preprocessed_predictors: list List of basetable column names containing preprocessed predictors. - + id_column_name : str, default=None + Name of the basetable column containing the IDs of the basetable rows + (e.g. customernumber). Returns ------- pd.DataFrame DataFrame containing a PIG table for all predictors. 
""" + + #check if there is a id-column and define no_predictor accordingly + if id_column_name == None: + no_predictor = [target_column_name] + else: + no_predictor = [id_column_name, target_column_name] + + pigs = [ compute_pig_table(basetable, column_name, target_column_name, - id_column_name) + ) for column_name in sorted(preprocessed_predictors) - if column_name not in [id_column_name, target_column_name] + if column_name not in no_predictor ] - output = pd.concat(pigs) + output = pd.concat(pigs, ignore_index=True) return output def compute_pig_table(basetable: pd.DataFrame, predictor_column_name: str, - target_column_name: str, - id_column_name: str) -> pd.DataFrame: + target_column_name: str) -> pd.DataFrame: """Compute the PIG table of a given predictor for a given target. Parameters @@ -59,8 +65,6 @@ def compute_pig_table(basetable: pd.DataFrame, Predictor name of which to compute the pig table. target_column_name : str Name of the target variable. - id_column_name : str - Name of the id column (used to count population size). Returns ------- @@ -72,12 +76,18 @@ def compute_pig_table(basetable: pd.DataFrame, # group by the binned variable, compute the incidence # (= mean of the target for the given bin) and compute the bin size # (e.g. COUNT(id_column_name)). 
After that, rename the columns + res = (basetable.groupby(predictor_column_name) - .agg({target_column_name: "mean", id_column_name: "size"}) + .agg( + avg_target = (target_column_name, "mean"), + pop_size = (target_column_name, "size") + ) .reset_index() - .rename(columns={predictor_column_name: "label", - target_column_name: "avg_target", - id_column_name: "pop_size"})) + .rename( + columns={predictor_column_name: "label"} + ) + ) + # add the column name to a variable column # add the average incidence diff --git a/tests/preprocessing/test_pig_tables.py b/tests/preprocessing/test_pig_tables.py new file mode 100644 index 0000000..3b1e6a7 --- /dev/null +++ b/tests/preprocessing/test_pig_tables.py @@ -0,0 +1,56 @@ +import pytest + +import pandas as pd +from cobra.evaluation.pigs_tables import generate_pig_tables + +from typing import Optional + + +class TestPigTablesGeneration: + @pytest.mark.parametrize( + "id_col_name", [None, "col_id"] + ) # test None as this is the default value in generate_pig_tables + def test_col_id(self, id_col_name: Optional[str]): + + # input + data = pd.DataFrame( + { + "col_id": [0, 1, 3, 4, 6], + "survived": [0, 1, 1, 0, 0], + "pclass": [3, 1, 1, 3, 1], + "sex": ["male", "female", "female", "male", "male"], + "age": [22.0, 38.0, 35.0, 35.0, 54.0], + } + ) + target = "survived" + prep_col = ["pclass", "sex", "age"] + + # output + out = generate_pig_tables( + basetable=data, + target_column_name=target, + preprocessed_predictors=prep_col, + id_column_name=id_col_name, + ) + + # expected + expected = pd.DataFrame( + { + "variable": [ + "age", + "age", + "age", + "age", + "pclass", + "pclass", + "sex", + "sex", + ], + "label": [22.0, 35.0, 38.0, 54.0, 1, 3, "female", "male"], + "pop_size": [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6], + "global_avg_target": [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4], + "avg_target": [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0], + } + ) + + pd.testing.assert_frame_equal(out, expected)