This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
8 changes: 5 additions & 3 deletions python/mxnet/gluon/contrib/estimator/estimator.py
@@ -24,7 +24,7 @@
import sys
import warnings

-from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler
+from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler
from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd
from .event_handler import _check_event_handlers
from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref
@@ -307,8 +307,6 @@ def fit_batch(self, train_batch, batch_axis=0):
for l in loss:
l.backward()

-        self.trainer.step(batch_size)
-
return data, label, pred, loss

def fit(self, train_data,
@@ -360,6 +358,7 @@ def fit(self, train_data,

self.max_epoch = epochs
self.max_batch = batches
+        self.batch_axis = batch_axis

# provide default handlers
event_handlers = self._prepare_default_handlers(val_data, event_handlers)
@@ -414,6 +413,9 @@ def _prepare_default_handlers(self, val_data, event_handlers):
# no need to add to default handler check as StoppingHandler does not use metrics
added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch))

+        if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers):
+            added_default_handlers.append(GradientUpdateHandler())

if not any(isinstance(handler, MetricHandler) for handler in event_handlers):
added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics))

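With this change, `fit_batch` no longer calls `trainer.step` itself; the parameter update is delegated to a default `GradientUpdateHandler` registered by `_prepare_default_handlers`, and `fit` records `batch_axis` so that handler can infer the batch size. A minimal sketch of driving the estimator after this PR; the toy network and data are illustrative placeholders, not part of the diff:

    import mxnet as mx
    from mxnet import gluon, init
    from mxnet.gluon.contrib.estimator import Estimator

    # toy regression problem (hypothetical); any Gluon network and DataLoader work
    X = mx.nd.random.uniform(shape=(100, 10))
    y = mx.nd.random.uniform(shape=(100, 1))
    train_loader = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=10)

    net = gluon.nn.Dense(1)
    net.initialize(init.Xavier())
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})

    est = Estimator(net=net, loss=gluon.loss.L2Loss(), trainer=trainer)
    # no GradientUpdateHandler is supplied, so _prepare_default_handlers adds one and
    # trainer.step(batch_size) now runs at batch end instead of inside fit_batch
    est.fit(train_data=train_loader, epochs=2, batch_axis=0)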
100 changes: 69 additions & 31 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -31,7 +31,7 @@

__all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd',
'StoppingHandler', 'MetricHandler', 'ValidationHandler',
-           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler']
+           'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler']


class EventHandler(object):
@@ -130,13 +130,16 @@ class MetricHandler(EpochBegin, BatchEnd):
----------
train_metrics : List of EvalMetrics
Training metrics to be updated at batch end.
+    priority : scalar
+        Priority level of the MetricHandler. Priority levels are sorted in
+        ascending order; the lower the number, the higher the handler's priority.
"""

-    def __init__(self, train_metrics):
+    def __init__(self, train_metrics, priority=-1000):
self.train_metrics = _check_metrics(train_metrics)
# order to be called among all callbacks
# metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority

def epoch_begin(self, estimator, *args, **kwargs):
for metric in self.train_metrics:
@@ -176,14 +179,19 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd):
batch_period : int, default None
How often to run validation at batch end, by default
:py:class:`ValidationHandler` does not validate at batch end.
+    priority : scalar, default -1000
+        Priority level of the ValidationHandler. Priority levels are sorted in
+        ascending order; the lower the number, the higher the handler's priority.
"""

def __init__(self,
val_data,
eval_fn,
val_metrics=None,
epoch_period=1,
-                 batch_period=None):
+                 batch_period=None,
+                 priority=-1000):
self.val_data = val_data
self.eval_fn = eval_fn
self.epoch_period = epoch_period
@@ -193,7 +201,7 @@ def __init__(self,
self.current_epoch = 0
# order to be called among all callbacks
# validation metrics need to be calculated before other callbacks can access them
-        self.priority = -np.Inf
+        self.priority = priority

def train_begin(self, estimator, *args, **kwargs):
# reset epoch and batch counter
@@ -227,37 +235,36 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):

Parameters
----------
-    verbose : int, default LOG_PER_EPOCH
-        Limit the granularity of metrics displayed during training process.
-        verbose=LOG_PER_EPOCH: display metrics every epoch
-        verbose=LOG_PER_BATCH: display metrics every batch
+    log_interval : int or str, default 'epoch'
+        Logging interval during training.
+        log_interval='epoch': display metrics every epoch
+        log_interval=k (an integer): display metrics every k batches
train_metrics : list of EvalMetrics
Training metrics to be logged, logged at batch end, epoch end, train end.
val_metrics : list of EvalMetrics
Validation metrics to be logged, logged at epoch end, train end.
+    priority : scalar, default np.Inf
+        Priority level of the LoggingHandler. Priority levels are sorted in
+        ascending order; the lower the number, the higher the handler's priority.
"""

LOG_PER_EPOCH = 1
LOG_PER_BATCH = 2

-    def __init__(self, verbose=LOG_PER_EPOCH,
+    def __init__(self, log_interval='epoch',
train_metrics=None,
-                 val_metrics=None):
+                 val_metrics=None,
+                 priority=np.Inf):
super(LoggingHandler, self).__init__()
-        if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
-            raise ValueError("verbose level must be either LOG_PER_EPOCH or "
-                             "LOG_PER_BATCH, received %s. "
-                             "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
-                             % verbose)
-        self.verbose = verbose
+        if not isinstance(log_interval, int) and log_interval != 'epoch':
+            raise ValueError("log_interval must be either an integer or string 'epoch'")
self.train_metrics = _check_metrics(train_metrics)
self.val_metrics = _check_metrics(val_metrics)
self.batch_index = 0
self.current_epoch = 0
self.processed_samples = 0
        # the logging handler needs to be called last to make sure all states are updated
        # it will also shut down logging at train end
-        self.priority = np.Inf
+        self.priority = priority
+        self.log_interval = log_interval

def train_begin(self, estimator, *args, **kwargs):
self.train_start = time.time()
@@ -275,6 +282,7 @@ def train_begin(self, estimator, *args, **kwargs):
self.current_epoch = 0
self.batch_index = 0
self.processed_samples = 0
+        self.log_interval_time = 0

def train_end(self, estimator, *args, **kwargs):
train_time = time.time() - self.train_start
@@ -286,31 +294,34 @@ def train_end(self, estimator, *args, **kwargs):
estimator.logger.info(msg.rstrip(', '))

def batch_begin(self, estimator, *args, **kwargs):
-        if self.verbose == self.LOG_PER_BATCH:
+        if isinstance(self.log_interval, int):
self.batch_start = time.time()

def batch_end(self, estimator, *args, **kwargs):
-        if self.verbose == self.LOG_PER_BATCH:
+        if isinstance(self.log_interval, int):
batch_time = time.time() - self.batch_start
msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
self.processed_samples += kwargs['batch'][0].shape[0]
msg += '[Samples %s] ' % (self.processed_samples)
-            msg += 'time/batch: %.3fs ' % batch_time
-            for metric in self.train_metrics:
-                # only log current training loss & metric after each batch
-                name, value = metric.get()
-                msg += '%s: %.4f, ' % (name, value)
-            estimator.logger.info(msg.rstrip(', '))
+            self.log_interval_time += batch_time
+            if self.batch_index % self.log_interval == 0:
+                msg += 'time/interval: %.3fs ' % self.log_interval_time
+                self.log_interval_time = 0
+                for metric in self.train_metrics:
+                    # only log current training loss & metric after each interval
+                    name, value = metric.get()
+                    msg += '%s: %.4f, ' % (name, value)
+                estimator.logger.info(msg.rstrip(', '))
self.batch_index += 1

def epoch_begin(self, estimator, *args, **kwargs):
-        if self.verbose >= self.LOG_PER_EPOCH:
+        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
self.epoch_start = time.time()
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)

def epoch_end(self, estimator, *args, **kwargs):
-        if self.verbose >= self.LOG_PER_EPOCH:
+        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
epoch_time = time.time() - self.epoch_start
msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
for monitor in self.train_metrics + self.val_metrics:
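The `verbose` levels are thus replaced by a `log_interval` that also supports arbitrary batch strides. A hedged usage sketch, reusing `est` and `train_loader` from the estimator example above:

    from mxnet.gluon.contrib.estimator import LoggingHandler

    # log training metrics every 10 batches; 'time/interval' reports the accumulated
    # wall-clock time of the batches since the previous log line
    per_interval = LoggingHandler(log_interval=10, train_metrics=est.train_metrics)
    est.fit(train_data=train_loader, epochs=2, event_handlers=[per_interval])

    # log_interval='epoch' (the default) keeps the old once-per-epoch behaviour
    per_epoch = LoggingHandler(log_interval='epoch', train_metrics=est.train_metrics)
    est.fit(train_data=train_loader, epochs=2, event_handlers=[per_epoch])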
@@ -706,3 +717,30 @@ def train_end(self, estimator, *args, **kwargs):
            estimator.logger.info('[Epoch %d] EarlyStoppingHandler: '
                                  'early stopping due to %s not improving',
                                  self.stopped_epoch, self.monitor.get()[0])

+class GradientUpdateHandler(BatchEnd):
+    """Gradient update handler that applies the gradients to the network weights.
+
+    :py:class:`GradientUpdateHandler` takes a priority level and updates the weight
+    parameters at the end of each batch.
+
+    Parameters
+    ----------
+    priority : scalar, default -2000
+        Priority level of the GradientUpdateHandler. Priority levels are sorted in
+        ascending order; the lower the number, the higher the handler's priority.
+    """
+    def __init__(self, priority=-2000):
+        self.priority = priority

+    def batch_end(self, estimator, *args, **kwargs):
+        loss = kwargs['loss']
+        batch_size = 0
+        if not isinstance(loss, list):
+            loss = [loss]
+        # sum the batch-axis dimension of each loss output to infer the batch size
+        for l in loss:
+            batch_size += l.shape[estimator.batch_axis]
+
+        estimator.trainer.step(batch_size)
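Since handlers run in ascending priority order, the batch-end defaults fire as GradientUpdateHandler (-2000), then MetricHandler (-1000), then LoggingHandler (np.Inf): weights are updated before metrics are computed and logged. Customizing the update step becomes a matter of subclassing. The gradient-clipping variant below is a sketch of one possibility, not part of this PR; it assumes a single context and reuses `est` and `train_loader` from the earlier example:

    from mxnet import gluon
    from mxnet.gluon.contrib.estimator import GradientUpdateHandler

    class ClipGradientUpdateHandler(GradientUpdateHandler):
        """Hypothetical handler that globally clips gradients before stepping."""
        def __init__(self, max_norm=5.0, priority=-2000):
            super(ClipGradientUpdateHandler, self).__init__(priority=priority)
            self.max_norm = max_norm

        def batch_end(self, estimator, *args, **kwargs):
            grads = [p.grad() for p in estimator.net.collect_params().values()
                     if p.grad_req != 'null']
            gluon.utils.clip_global_norm(grads, self.max_norm)
            # delegate batch-size inference and trainer.step to the base handler
            super(ClipGradientUpdateHandler, self).batch_end(estimator, *args, **kwargs)

    # passing any GradientUpdateHandler instance suppresses the default one, because
    # _prepare_default_handlers tests isinstance(handler, GradientUpdateHandler)
    est.fit(train_data=train_loader, epochs=2,
            event_handlers=[ClipGradientUpdateHandler(max_norm=5.0)])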
79 changes: 67 additions & 12 deletions python/mxnet/metric.py
@@ -590,8 +590,9 @@ def update(self, labels, preds):


class _BinaryClassificationMetrics(object):
"""Private container class for classification metric statistics. True/false positive and
true/false negative counts are sufficient statistics for various classification metrics.
"""Private container class for classification metric statistics.

True/false positive and true/false negative counts are sufficient statistics for various classification metrics.
This class provides the machinery to track those statistics across mini-batches of
(label, prediction) pairs.
"""
@@ -1430,6 +1431,10 @@ class PearsonCorrelation(EvalMetric):
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.
+    average : str, default 'macro'
+        Strategy to be used for aggregating across mini-batches.
+        "macro": average the pearsonr scores for each batch.
+        "micro": compute a single pearsonr score across all batches.

Examples
--------
@@ -1438,13 +1443,46 @@ class PearsonCorrelation(EvalMetric):
>>> pr = mx.metric.PearsonCorrelation()
>>> pr.update(labels, predicts)
>>> print pr.get()
-    ('pearson-correlation', 0.42163704544016178)
+    ('pearsonr', 0.42163704544016178)
"""
def __init__(self, name='pearsonr',
-                 output_names=None, label_names=None):
+                 output_names=None, label_names=None, average='macro'):
+        self.average = average
super(PearsonCorrelation, self).__init__(
name, output_names=output_names, label_names=label_names,
has_global_stats=True)
+        if self.average == 'micro':
+            self.reset_micro()

+    def reset_micro(self):
+        self._sse_p = 0
+        self._mean_p = 0
+        self._sse_l = 0
+        self._mean_l = 0
+        self._pred_nums = 0
+        self._label_nums = 0
+        self._conv = 0

+    def reset(self):
+        self.num_inst = 0
+        self.sum_metric = 0.0
+        self.global_num_inst = 0
+        self.global_sum_metric = 0.0
+        if self.average == 'micro':
+            self.reset_micro()

+    def update_variance(self, new_values, *aggregate):
+        # Welford's online algorithm (batched form) for the variance update
+        count, mean, m_2 = aggregate
+        count += len(new_values)
+        delta = new_values - mean
+        mean += numpy.sum(delta / count)
+        delta_2 = new_values - mean
+        m_2 += numpy.sum(delta * delta_2)
+        return count, mean, m_2

+    def update_cov(self, label, pred):
+        self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p))

def update(self, labels, preds):
"""Updates the internal evaluation result.
@@ -1457,17 +1495,34 @@ def update(self, labels, preds):
Predicted values.
"""
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
check_label_shapes(label, pred, False, True)
-            label = label.asnumpy()
-            pred = pred.asnumpy()
-            pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1]
-            self.sum_metric += pearson_corr
-            self.global_sum_metric += pearson_corr
-            self.num_inst += 1
-            self.global_num_inst += 1
+            label = label.asnumpy().ravel().astype(numpy.float64)
+            pred = pred.asnumpy().ravel().astype(numpy.float64)
+            if self.average == 'macro':
+                pearson_corr = numpy.corrcoef(pred, label)[0, 1]
+                self.sum_metric += pearson_corr
+                self.global_sum_metric += pearson_corr
+                self.num_inst += 1
+                self.global_num_inst += 1
+            else:
+                self.global_num_inst += 1
+                self.num_inst += 1
+                self._label_nums, self._mean_l, self._sse_l = \
+                    self.update_variance(label, self._label_nums, self._mean_l, self._sse_l)
+                self.update_cov(label, pred)
+                self._pred_nums, self._mean_p, self._sse_p = \
+                    self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p)

+    def get(self):
+        if self.num_inst == 0:
+            return (self.name, float('nan'))
+        if self.average == 'macro':
+            return (self.name, self.sum_metric / self.num_inst)
+        else:
+            n = self._label_nums
+            pearsonr = self._conv / ((n - 1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1)))
+            return (self.name, pearsonr)

@register
class PCC(EvalMetric):
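To make the difference between the two averaging modes concrete, here is a hedged usage sketch (the data are illustrative; only the call shapes are taken from this diff):

    import mxnet as mx
    from mxnet import metric

    macro = metric.PearsonCorrelation(average='macro')
    micro = metric.PearsonCorrelation(average='micro')

    batches = [
        (mx.nd.array([1.0, 2.0, 3.0]), mx.nd.array([1.1, 1.9, 3.2])),
        (mx.nd.array([4.0, 5.0]), mx.nd.array([3.8, 5.1])),
    ]
    for label, pred in batches:
        macro.update([label], [pred])
        micro.update([label], [pred])

    # 'macro' averages one pearsonr per mini-batch (two scores here); 'micro'
    # streams the Welford statistics and yields a single score over all five
    # samples, so unequal batch sizes no longer skew the aggregate.
    print(macro.get())   # ('pearsonr', <mean of the two per-batch scores>)
    print(micro.get())   # ('pearsonr', <one score over all samples>)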