Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions cobra/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,20 @@ class Evaluator():
probability_cutoff : float
probability cut off to convert probability scores to a binary score
roc_curve : dict
map containing true-positive-rate, false-positve-rate at various
map containing true-positive-rate, false-positive-rate at various
thresholds (also incl.)
n_bins : int, optional
defines the number of bins used to calculate the lift curve for
(by default 10, so deciles)
"""

def __init__(self, probability_cutoff: float=None,
lift_at: float=0.05):
lift_at: float=0.05,
n_bins: int = 10):

self.lift_at = lift_at
self.probability_cutoff = probability_cutoff
self.n_bins = n_bins

# Placeholder to store fitted output
self.scalar_metrics = None
Expand Down Expand Up @@ -85,7 +90,7 @@ def fit(self, y_true: np.ndarray, y_pred: np.ndarray):

self.roc_curve = {"fpr": fpr, "tpr": tpr, "thresholds": thresholds}
self.confusion_matrix = confusion_matrix(y_true, y_pred_b)
self.lift_curve = Evaluator._compute_lift_per_decile(y_true, y_pred)
self.lift_curve = Evaluator._compute_lift_per_bin(y_true, y_pred, self.n_bins)
self.cumulative_gains = Evaluator._compute_cumulative_gains(y_true,
y_pred)

Expand Down Expand Up @@ -199,8 +204,7 @@ def plot_confusion_matrix(self, path: str=None, dim: tuple=(12, 8),

plt.show()

def plot_cumulative_response_curve(self, path: str=None,
dim: tuple=(12, 8)):
def plot_cumulative_response_curve(self, path: str=None, dim: tuple=(12, 8)):
"""Plot cumulative response curve

Parameters
Expand Down Expand Up @@ -430,17 +434,21 @@ def _compute_cumulative_gains(y_true: np.ndarray,
return percentages, gains

@staticmethod
def _compute_lift_per_decile(y_true: np.ndarray,
y_pred: np.ndarray) -> tuple:
"""Compute lift of the model per decile, returns x-labels, lifts and
the target incidence to create cummulative response curves
def _compute_lift_per_bin(y_true: np.ndarray,
y_pred: np.ndarray,
n_bins: int = 10) -> tuple:
"""Compute lift of the model for a given number of bins, returns x-labels,
lifts and the target incidence to create cumulative response curves

Parameters
----------
y_true : np.ndarray
True binary target data labels
y_pred : np.ndarray
Target scores of the model
n_bins : int, optional
defines the number of bins used to calculate the lift curve for
(by default 10, so deciles)

Returns
-------
Expand All @@ -451,7 +459,7 @@ def _compute_lift_per_decile(y_true: np.ndarray,
lifts = [Evaluator._compute_lift(y_true=y_true,
y_pred=y_pred,
lift_at=perc_lift)
for perc_lift in np.arange(0.1, 1.1, 0.1)]
for perc_lift in np.linspace(1/n_bins, 1, num=n_bins, endpoint=True)]

x_labels = [len(lifts)-x for x in np.arange(0, len(lifts), 1)]

Expand Down
23 changes: 23 additions & 0 deletions tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pytest
import pandas as pd
import numpy as np
from cobra.evaluation import plot_incidence
from cobra.evaluation import Evaluator


def mock_data():
Expand All @@ -11,6 +13,13 @@ def mock_data():
'incidence': [0.047, 0.0434, 0.054, 0.069]}
return pd.DataFrame(d)

def mock_preds(n, seed = 505):
np.random.seed(seed)

y_true = np.random.uniform(size=n)
y_pred = np.random.uniform(size=n)

return y_true, y_pred

class TestEvaluation:

Expand All @@ -19,3 +28,17 @@ def test_plot_incidence(self):
column_order = ['1st-4th', '5th-6th', '7th-8th']
with pytest.raises(Exception):
plot_incidence(data, 'education', column_order)

def test_lift_curve_n_bins(self):
n_bins_test = [5, 10, 15, 35]

y_true, y_pred = mock_preds(50)

n_bins_out = []
for n_bins in n_bins_test:
e = Evaluator(n_bins = n_bins)
out = Evaluator._compute_lift_per_bin(y_true, y_pred, e.n_bins)
lifts = out[1]
n_bins_out.append(len(lifts))

assert n_bins_test == n_bins_out