Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions cobra/evaluation/pigs_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import cobra.utils as utils

def generate_pig_tables(basetable: pd.DataFrame,
id_column_name: str,
target_column_name: str,
preprocessed_predictors: list) -> pd.DataFrame:
preprocessed_predictors: list,
id_column_name: str = None) -> pd.DataFrame:
"""Compute PIG tables for all predictors in preprocessed_predictors.

The output is a DataFrame with columns ``variable``, ``label``,
Expand All @@ -20,35 +20,41 @@ def generate_pig_tables(basetable: pd.DataFrame,
----------
basetable : pd.DataFrame
Basetable to compute PIG tables from.
id_column_name : str
Name of the basetable column containing the IDs of the basetable rows
(e.g. customernumber).
target_column_name : str
Name of the basetable column containing the target values to predict.
preprocessed_predictors: list
List of basetable column names containing preprocessed predictors.

id_column_name : str, default=None
Name of the basetable column containing the IDs of the basetable rows
(e.g. customernumber).
Returns
-------
pd.DataFrame
DataFrame containing a PIG table for all predictors.
"""

#check if there is a id-column and define no_predictor accordingly
if id_column_name == None:
no_predictor = [target_column_name]
else:
no_predictor = [id_column_name, target_column_name]


pigs = [
compute_pig_table(basetable,
column_name,
target_column_name,
id_column_name)
)
for column_name in sorted(preprocessed_predictors)
if column_name not in [id_column_name, target_column_name]
if column_name not in no_predictor
]
output = pd.concat(pigs)
output = pd.concat(pigs, ignore_index=True)
return output


def compute_pig_table(basetable: pd.DataFrame,
predictor_column_name: str,
target_column_name: str,
id_column_name: str) -> pd.DataFrame:
target_column_name: str) -> pd.DataFrame:
"""Compute the PIG table of a given predictor for a given target.

Parameters
Expand All @@ -59,8 +65,6 @@ def compute_pig_table(basetable: pd.DataFrame,
Predictor name of which to compute the pig table.
target_column_name : str
Name of the target variable.
id_column_name : str
Name of the id column (used to count population size).

Returns
-------
Expand All @@ -72,12 +76,18 @@ def compute_pig_table(basetable: pd.DataFrame,
# group by the binned variable, compute the incidence
# (= mean of the target for the given bin) and compute the bin size
# (e.g. COUNT(id_column_name)). After that, rename the columns

res = (basetable.groupby(predictor_column_name)
.agg({target_column_name: "mean", id_column_name: "size"})
.agg(
avg_target = (target_column_name, "mean"),
pop_size = (target_column_name, "size")
)
.reset_index()
.rename(columns={predictor_column_name: "label",
target_column_name: "avg_target",
id_column_name: "pop_size"}))
.rename(
columns={predictor_column_name: "label"}
)
)


# add the column name to a variable column
# add the average incidence
Expand Down
56 changes: 56 additions & 0 deletions tests/preprocessing/test_pig_tables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pytest

import pandas as pd
from cobra.evaluation.pigs_tables import generate_pig_tables

from typing import Optional


class TestPigTablesGeneration:
@pytest.mark.parametrize(
"id_col_name", [None, "col_id"]
) # test None as this is the default value in generate pig tabels
def test_col_id(self, id_col_name: Optional[str]):

# input
data = pd.DataFrame(
{
"col_id": [0, 1, 3, 4, 6],
"survived": [0, 1, 1, 0, 0],
"pclass": [3, 1, 1, 3, 1],
"sex": ["male", "female", "female", "male", "male"],
"age": [22.0, 38.0, 35.0, 35.0, 54.0],
}
)
target = "survived"
prep_col = ["pclass", "sex", "age"]

# output
out = generate_pig_tables(
basetable=data,
target_column_name=target,
preprocessed_predictors=prep_col,
id_column_name=id_col_name,
)

# expected
expected = pd.DataFrame(
{
"variable": [
"age",
"age",
"age",
"age",
"pclass",
"pclass",
"sex",
"sex",
],
"label": [22.0, 35.0, 38.0, 54.0, 1, 3, "female", "male"],
"pop_size": [0.2, 0.4, 0.2, 0.2, 0.6, 0.4, 0.4, 0.6],
"global_avg_target": [0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4],
"avg_target": [0.0, 0.5, 1.0, 0.0, 0.6666666666666666, 0.0, 1.0, 0.0],
}
)

pd.testing.assert_frame_equal(out, expected)