diff --git a/python/mxnet/gluon/contrib/estimator/estimator.py b/python/mxnet/gluon/contrib/estimator/estimator.py index 54a0b165016e..ab7018f58e1f 100644 --- a/python/mxnet/gluon/contrib/estimator/estimator.py +++ b/python/mxnet/gluon/contrib/estimator/estimator.py @@ -24,7 +24,7 @@ import sys import warnings -from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler +from .event_handler import MetricHandler, ValidationHandler, LoggingHandler, StoppingHandler, GradientUpdateHandler from .event_handler import TrainBegin, EpochBegin, BatchBegin, BatchEnd, EpochEnd, TrainEnd from .event_handler import _check_event_handlers from .utils import _check_metrics, _suggest_metric_for_loss, _check_handler_metric_ref @@ -307,8 +307,6 @@ def fit_batch(self, train_batch, batch_axis=0): for l in loss: l.backward() - self.trainer.step(batch_size) - return data, label, pred, loss def fit(self, train_data, @@ -360,6 +358,7 @@ def fit(self, train_data, self.max_epoch = epochs self.max_batch = batches + self.batch_axis = batch_axis # provide default handlers event_handlers = self._prepare_default_handlers(val_data, event_handlers) @@ -414,6 +413,9 @@ def _prepare_default_handlers(self, val_data, event_handlers): # no need to add to default handler check as StoppingHandler does not use metrics added_default_handlers.append(StoppingHandler(self.max_epoch, self.max_batch)) + if not any(isinstance(handler, GradientUpdateHandler) for handler in event_handlers): + added_default_handlers.append(GradientUpdateHandler()) + if not any(isinstance(handler, MetricHandler) for handler in event_handlers): added_default_handlers.append(MetricHandler(train_metrics=self.train_metrics)) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 3cdc407407c1..64777608bef0 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -31,7 +31,7 @@ __all__ = ['TrainBegin', 'TrainEnd', 'EpochBegin', 'EpochEnd', 'BatchBegin', 'BatchEnd', 'StoppingHandler', 'MetricHandler', 'ValidationHandler', - 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler'] + 'LoggingHandler', 'CheckpointHandler', 'EarlyStoppingHandler', 'GradientUpdateHandler'] class EventHandler(object): @@ -130,13 +130,16 @@ class MetricHandler(EpochBegin, BatchEnd): ---------- train_metrics : List of EvalMetrics Training metrics to be updated at batch end. + priority : scalar + Priority level of the MetricHandler. Priority levels are sorted in ascending + order: the lower the number, the higher the handler's priority. """ - def __init__(self, train_metrics): + def __init__(self, train_metrics, priority=-1000): self.train_metrics = _check_metrics(train_metrics) # order to be called among all callbacks # metrics need to be calculated before other callbacks can access them - self.priority = -np.Inf + self.priority = priority def epoch_begin(self, estimator, *args, **kwargs): for metric in self.train_metrics: @@ -176,6 +179,10 @@ class ValidationHandler(TrainBegin, BatchEnd, EpochEnd): batch_period : int, default None How often to run validation at batch end, by default :py:class:`ValidationHandler` does not validate at batch end. + priority : scalar, default -1000 + Priority level of the ValidationHandler. Priority levels are sorted in + ascending order: the lower the number, the higher the + handler's priority.
""" def __init__(self, @@ -183,7 +190,8 @@ def __init__(self, eval_fn, val_metrics=None, epoch_period=1, - batch_period=None): + batch_period=None, + priority=-1000): self.val_data = val_data self.eval_fn = eval_fn self.epoch_period = epoch_period @@ -193,7 +201,7 @@ def __init__(self, self.current_epoch = 0 # order to be called among all callbacks # validation metrics need to be calculated before other callbacks can access them - self.priority = -np.Inf + self.priority = priority def train_begin(self, estimator, *args, **kwargs): # reset epoch and batch counter @@ -227,29 +235,27 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- - verbose : int, default LOG_PER_EPOCH - Limit the granularity of metrics displayed during training process. - verbose=LOG_PER_EPOCH: display metrics every epoch - verbose=LOG_PER_BATCH: display metrics every batch + log_interval: int or str, default 'epoch' + Logging interval during training. + log_interval='epoch': display metrics every epoch + log_interval=integer k: display metrics every interval of k batches train_metrics : list of EvalMetrics Training metrics to be logged, logged at batch end, epoch end, train end. val_metrics : list of EvalMetrics Validation metrics to be logged, logged at epoch end, train end. + priority : scalar, default np.Inf + Priority level of the LoggingHandler. Priority level is sorted in + ascending order. The lower the number is, the higher priority level the + handler is. """ - LOG_PER_EPOCH = 1 - LOG_PER_BATCH = 2 - - def __init__(self, verbose=LOG_PER_EPOCH, + def __init__(self, log_interval='epoch', train_metrics=None, - val_metrics=None): + val_metrics=None, + priority=np.Inf): super(LoggingHandler, self).__init__() - if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: - raise ValueError("verbose level must be either LOG_PER_EPOCH or " - "LOG_PER_BATCH, received %s. 
" - "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" - % verbose) - self.verbose = verbose + if not isinstance(log_interval, int) and log_interval != 'epoch': + raise ValueError("log_interval must be either an integer or string 'epoch'") self.train_metrics = _check_metrics(train_metrics) self.val_metrics = _check_metrics(val_metrics) self.batch_index = 0 @@ -257,7 +263,8 @@ def __init__(self, verbose=LOG_PER_EPOCH, self.processed_samples = 0 # logging handler need to be called at last to make sure all states are updated # it will also shut down logging at train end - self.priority = np.Inf + self.priority = priority + self.log_interval = log_interval def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() @@ -275,6 +282,7 @@ def train_begin(self, estimator, *args, **kwargs): self.current_epoch = 0 self.batch_index = 0 self.processed_samples = 0 + self.log_interval_time = 0 def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start @@ -286,31 +294,34 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_BATCH: + if isinstance(self.log_interval, int): self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_BATCH: + if isinstance(self.log_interval, int): batch_time = time.time() - self.batch_start msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) - msg += 'time/batch: %.3fs ' % batch_time - for metric in self.train_metrics: - # only log current training loss & metric after each batch - name, value = metric.get() - msg += '%s: %.4f, ' % (name, value) - estimator.logger.info(msg.rstrip(', ')) + self.log_interval_time += batch_time + if self.batch_index % self.log_interval == 0: + msg += 'time/interval: %.3fs ' % self.log_interval_time + self.log_interval_time = 0 + for metric in self.train_metrics: + # only log current training loss & metric after each interval + name, value = metric.get() + msg += '%s: %.4f, ' % (name, value) + estimator.logger.info(msg.rstrip(', ')) self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_EPOCH: + if isinstance(self.log_interval, int) or self.log_interval == 'epoch': self.epoch_start = time.time() estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f", self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_EPOCH: + if isinstance(self.log_interval, int) or self.log_interval == 'epoch': epoch_time = time.time() - self.epoch_start msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: @@ -706,3 +717,30 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info('[Epoch %d] EarlyStoppingHanlder: ' 'early stopping due to %s not improving', self.stopped_epoch, self.monitor.get()[0]) + +class GradientUpdateHandler(BatchEnd): + """Gradient Update Handler that apply gradients on network weights + + :py:class:`GradientUpdateHandler` takes the priority level. It updates weight parameters + at the end of each batch + + Parameters + ---------- + priority : scalar, default -2000 + priority level of the gradient update handler. Priority level is sorted in ascending + order. 
The lower the number is, the higher priority level the handler is. + ---------- + """ + def __init__(self, priority=-2000): + self.priority = priority + + def batch_end(self, estimator, *args, **kwargs): + loss = kwargs['loss'] + batch_size = 0 + if not isinstance(loss, list): + loss = [loss] + if isinstance(loss, list): + for l in loss: + batch_size += l.shape[estimator.batch_axis] + + estimator.trainer.step(batch_size) diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py index 6e2d66cb9d15..d1074c923337 100644 --- a/python/mxnet/metric.py +++ b/python/mxnet/metric.py @@ -590,8 +590,9 @@ def update(self, labels, preds): class _BinaryClassificationMetrics(object): - """Private container class for classification metric statistics. True/false positive and - true/false negative counts are sufficient statistics for various classification metrics. + """Private container class for classification metric statistics. + + True/false positive and true/false negative counts are sufficient statistics for various classification metrics. This class provides the machinery to track those statistics across mini-batches of (label, prediction) pairs. """ @@ -1430,6 +1431,10 @@ class PearsonCorrelation(EvalMetric): label_names : list of str, or None Name of labels that should be used when updating with update_dict. By default include all labels. + average : str, default 'macro' + Strategy to be used for aggregating across mini-batches. + "macro": average the pearsonr scores for each batch. + "micro": compute a single pearsonr score across all batches. Examples -------- @@ -1438,13 +1443,46 @@ class PearsonCorrelation(EvalMetric): >>> pr = mx.metric.PearsonCorrelation() >>> pr.update(labels, predicts) >>> print pr.get() - ('pearson-correlation', 0.42163704544016178) + ('pearsonr', 0.42163704544016178) """ def __init__(self, name='pearsonr', - output_names=None, label_names=None): + output_names=None, label_names=None, average='macro'): + self.average = average super(PearsonCorrelation, self).__init__( name, output_names=output_names, label_names=label_names, has_global_stats=True) + if self.average == 'micro': + self.reset_micro() + + def reset_micro(self): + self._sse_p = 0 + self._mean_p = 0 + self._sse_l = 0 + self._mean_l = 0 + self._pred_nums = 0 + self._label_nums = 0 + self._conv = 0 + + def reset(self): + self.num_inst = 0 + self.sum_metric = 0.0 + self.global_num_inst = 0 + self.global_sum_metric = 0.0 + if self.average == 'micro': + self.reset_micro() + + def update_variance(self, new_values, *aggregate): + #Welford's online algorithm for variance update + count, mean, m_2 = aggregate + count += len(new_values) + delta = new_values - mean + mean += numpy.sum(delta / count) + delta_2 = new_values - mean + m_2 += numpy.sum(delta * delta_2) + return count, mean, m_2 + + def update_cov(self, label, pred): + self._conv = self._conv + numpy.sum((label - self._mean_l) * (pred - self._mean_p)) def update(self, labels, preds): """Updates the internal evaluation result. @@ -1457,17 +1495,34 @@ def update(self, labels, preds): Predicted values. 
""" labels, preds = check_label_shapes(labels, preds, True) - for label, pred in zip(labels, preds): check_label_shapes(label, pred, False, True) - label = label.asnumpy() - pred = pred.asnumpy() - pearson_corr = numpy.corrcoef(pred.ravel(), label.ravel())[0, 1] - self.sum_metric += pearson_corr - self.global_sum_metric += pearson_corr - self.num_inst += 1 - self.global_num_inst += 1 + label = label.asnumpy().ravel().astype(numpy.float64) + pred = pred.asnumpy().ravel().astype(numpy.float64) + if self.average == 'macro': + pearson_corr = numpy.corrcoef(pred, label)[0, 1] + self.sum_metric += pearson_corr + self.global_sum_metric += pearson_corr + self.num_inst += 1 + self.global_num_inst += 1 + else: + self.global_num_inst += 1 + self.num_inst += 1 + self._label_nums, self._mean_l, self._sse_l = \ + self.update_variance(label, self._label_nums, self._mean_l, self._sse_l) + self.update_cov(label, pred) + self._pred_nums, self._mean_p, self._sse_p = \ + self.update_variance(pred, self._pred_nums, self._mean_p, self._sse_p) + def get(self): + if self.num_inst == 0: + return (self.name, float('nan')) + if self.average == 'macro': + return (self.name, self.sum_metric / self.num_inst) + else: + n = self._label_nums + pearsonr = self._conv / ((n-1) * numpy.sqrt(self._sse_p / (n - 1)) * numpy.sqrt(self._sse_l / (n - 1))) + return (self.name, pearsonr) @register class PCC(EvalMetric): diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index a7ad8e6c6c98..bc1bbe4fd7e4 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -905,6 +905,17 @@ def _basic_indexing_contiguous_flat_begin_end(slc_key, shape): return flat_begin, flat_end + 1 # pylint: enable=invalid-name + @staticmethod + def _drop_int_axes(indexed_shape, int_axes): + """drop the axis of indexed_shape corresponding to int axes""" + bcast_shape = [] + for i, size in enumerate(indexed_shape): + if i not in int_axes: + bcast_shape.append(size) + if not bcast_shape: + bcast_shape = [1] + return tuple(bcast_shape) + def _set_nd_basic_indexing(self, key, value): """This function indexes ``self`` with a tuple of ``slice`` objects only.""" for idx in key: @@ -946,14 +957,10 @@ def _set_nd_basic_indexing(self, key, value): if type(value) == self.__class__: # pylint: disable=unidiomatic-typecheck if value.handle is not self.handle: # Need to do this before `broadcast_to`. 
- tmp_shape = _shape_for_bcast( - value.shape, target_ndim=self.ndim, new_axes=int_axes - ) - value = value.reshape(tmp_shape) - - if value.shape != self.shape: - value = value.broadcast_to(self.shape) - value.copyto(self) + bcast_shape = self._drop_int_axes(indexed_shape, int_axes) + value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes) + value_nd = value_nd.reshape(indexed_shape) + value_nd.copyto(self) elif isinstance(value, numeric_types): self._full(value) @@ -969,9 +976,10 @@ def _set_nd_basic_indexing(self, key, value): else: # Other array-like - value_nd = self._prepare_value_nd( - value, bcast_shape=self.shape - ) + # drop the axis of indexed_shape corresponding to int axes + bcast_shape = self._drop_int_axes(indexed_shape, int_axes) + value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes) + value_nd = value_nd.reshape(indexed_shape) value_nd.copyto(self) elif isinstance(value, numeric_types): @@ -979,16 +987,8 @@ def _set_nd_basic_indexing(self, key, value): else: # drop the axis of indexed_shape corresponding to int axes - bcast_shape = [] - for i, size in enumerate(indexed_shape): - if i not in int_axes: - bcast_shape.append(size) - if bcast_shape == []: - bcast_shape = [1] - bcast_shape = tuple(bcast_shape) - value_nd = self._prepare_value_nd( - value, bcast_shape=bcast_shape, squeeze_axes=new_axes - ) + bcast_shape = self._drop_int_axes(indexed_shape, int_axes) + value_nd = self._prepare_value_nd(value, bcast_shape=bcast_shape, squeeze_axes=new_axes) value_nd = value_nd.reshape(indexed_shape) self.slice_assign(value_nd, begin, end, step) diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 9375bed5a79b..ba3c33476dac 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -243,8 +243,14 @@ NDArray NDArray::MKLDNNDataReshape(const mxnet::TShape &shape) const { NDArray NDArray::Reshape(const mxnet::TShape &shape) const { CHECK(!is_none()) << "NDArray is not initialized"; - CHECK_GE(shape_.Size(), shape.Size()) - << "NDArray.Reshape: target shape size is larger current shape"; + if (Imperative::Get()->is_np_shape()) { + CHECK_EQ(shape_.Size(), shape.Size()) + << "NDArray.Reshape: target shape must have the same size as " + << "current shape."; + } else { + CHECK_GE(shape_.Size(), shape.Size()) + << "NDArray.Reshape: target shape size is larger than the current shape"; + } NDArray ret = this->Detach(); // If the shape doesn't change, we can just return it now. 
if (ret.shape_ == shape) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index a9828f40436d..42cc99a46dca 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -119,16 +119,22 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs) { const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace"; + if (req[0] == kNullOp) return; + CHECK(req[0] == kWriteTo || req[0] == kAddTo) + << "Transpose only supports kWriteTo, kNullOp and kAddTo"; + mxnet::TShape axes; if (ndim_is_known(param.axes)) { - mxnet::TShape axes = common::CanonicalizeAxes(param.axes); - TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); + axes = common::CanonicalizeAxes(param.axes); } else { - mxnet::TShape axes(inputs[0].ndim(), -1); + axes = mxnet::TShape(inputs[0].ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { axes[i] = axes.ndim() - 1 - i; } - TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); + } + if (req[0] == kAddTo) { + TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes); + } else { + TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes); } } diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 3967cde91d2a..41d8d02c870a 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -24,6 +24,7 @@ */ #include <vector> +#include <set> #include "./np_matrix_op-inl.h" #include "../nn/concat-inl.h" @@ -65,8 +66,13 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs, mxnet::TShape ret(ndim, -1); if (ndim_is_known(param.axes)) { - CHECK_EQ(ndim, param.axes.ndim()); + CHECK_EQ(ndim, param.axes.ndim()) + << "The number of axes does not match the dimension of the tensor. axes = " + << param.axes << ", input tensor shape = " << shp; mxnet::TShape axes = common::CanonicalizeAxes(param.axes); + std::set<dim_t> axes_set(axes.begin(), axes.end()); + CHECK_EQ(axes_set.size(), axes.ndim()) << "Repeated axis in transpose. 
param.axes = "" << param.axes; if (ndim_is_known(shp)) { for (int i = 0; i < ndim; ++i) { ret[i] = shp[axes[i]]; @@ -115,9 +121,9 @@ NNVM_REGISTER_OP(_np_transpose) } std::ostringstream os; os << axes; - return MakeNonlossGradNode("transpose", n, ograds, {}, {{"axes", os.str()}}); + return MakeNonlossGradNode("_np_transpose", n, ograds, {}, {{"axes", os.str()}}); } else { - return MakeNonlossGradNode("transpose", n, ograds, {}, + return MakeNonlossGradNode("_np_transpose", n, ograds, {}, std::unordered_map<std::string, std::string>()); } }) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 0fee2a26c0ed..4bd059ae81df 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -269,8 +269,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> { * \param out output tensor * \param row shape of dim 0 of input * \param col shape of dim 1 of input + * \tparam DType Data type + * \tparam is_addto Whether to add the transposed result to the output (out += transpose(in)) */ -template<typename DType> +template<typename DType, bool is_addto = false> MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index_t col) { // ensure cache line hits and prevent cache miss for any configuration // L1 cache size to be utilized = 32kb = 2^15 @@ -282,7 +284,7 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index // Block-size - 2^5 v 2^5 (32 v 32) with potential 4 pragma for loop unrolled // blocksize * blocksize * num_threads = cache_size / dtype_size // Instead of explicit unroll, let compiler figure out optimal unroll factor - index_t blocksize = 32; + const index_t blocksize = 32; // collapse 2 parallelizes 2 for loops // inner 2 for loops aren't parallelized to prevent cache miss @@ -299,14 +301,25 @@ MSHADOW_XINLINE void Transpose2D(const DType *in, DType *out, index_t row, index // transpose the block for (index_t a = j; (a < blocksize + j) && (a < col); ++a) { for (index_t b = i; (b < blocksize + i) && (b < row); ++b) { - out[a * row + b] = in[b * col + a]; + if (!is_addto) { + out[a * row + b] = in[b * col + a]; + } else { + out[a * row + b] += in[b * col + a]; + } } } } } } -template<typename xpu> +inline bool IsIdentityTranspose(const TShape& axes) { + for (dim_t i = 0; i < axes.ndim(); i++) { + if (axes[i] != i) return false; + } + return true; +} + +template<typename xpu, bool is_addto = false> void TransposeImpl(RunContext ctx, const TBlob& src, const TBlob& ret, @@ -323,62 +336,79 @@ void TransposeImpl(RunContext ctx, // Example: (0, 2, 3, 1) or (0, 3, 1, 2), but not (0, 2, 1, 3). 
if (isPseudo2DTranspose(axes)) { MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { - transpose_pseudo2D<DType>(ret, src, axes, s); + transpose_pseudo2D<DType, is_addto>(ret, src, axes, s); }); return; } #endif + // Special handle the identity case + if (IsIdentityTranspose(axes)) { + MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { + Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(src.Size()), s); + Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(ret.Size()), s); + if (!is_addto) { + // Use memcpy to accelerate the speed + Copy(out, in, s); + } else { + mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, kAddTo>, xpu>::Launch( + s, ret.Size(), out.dptr_, in.dptr_); + } + }); + return; + } + // Handle the general transpose case MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, { switch (axes.ndim()) { - case 0: { - Tensor<xpu, 1, DType> in = src.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s); - Tensor<xpu, 1, DType> out = ret.get_with_shape<xpu, 1, DType>(mshadow::Shape1(1), s); - Copy(out, in, s); - break; - } - case 1: { - Tensor<xpu, 1, DType> in = src.get<xpu, 1, DType>(s); - Tensor<xpu, 1, DType> out = ret.get<xpu, 1, DType>(s); - Copy(out, in, s); - break; - } case 2: { - mshadow::Tensor<xpu, 2, DType> in = src.FlatTo2D<xpu, DType>(s); - mshadow::Tensor<xpu, 2, DType> out = ret.FlatTo2D<xpu, DType>(s); - - if (axes[0] == 1 && axes[1] == 0) { - if (ctx.get_ctx().dev_mask() == cpu::kDevMask) { - Transpose2D<DType>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); - } else { - out = in.T(); - } + Tensor<xpu, 2, DType> in = src.get<xpu, 2, DType>(s); + Tensor<xpu, 2, DType> out = ret.get<xpu, 2, DType>(s); + if (ctx.get_ctx().dev_mask() == cpu::kDevMask) { + Transpose2D<DType, is_addto>(in.dptr_, out.dptr_, in.shape_[0], in.shape_[1]); } else { - Copy(out, in, s); + LOG(FATAL) << "Not Implemented. We should never reach here because the 2D case " + "in GPU has been covered by transpose_pseudo2D." + " Report an issue in Github."; } break; } case 3: { Tensor<xpu, 3, DType> in = src.get<xpu, 3, DType>(s); Tensor<xpu, 3, DType> out = ret.get<xpu, 3, DType>(s); - out = transpose(in, axes.get<3>()); + if (!is_addto) { + out = transpose(in, axes.get<3>()); + } else { + out += transpose(in, axes.get<3>()); + } break; } case 4: { Tensor<xpu, 4, DType> in = src.get<xpu, 4, DType>(s); Tensor<xpu, 4, DType> out = ret.get<xpu, 4, DType>(s); - out = transpose(in, axes.get<4>()); + if (!is_addto) { + out = transpose(in, axes.get<4>()); + } else { + out += transpose(in, axes.get<4>()); + } break; } case 5: { Tensor<xpu, 5, DType> in = src.get<xpu, 5, DType>(s); Tensor<xpu, 5, DType> out = ret.get<xpu, 5, DType>(s); - out = transpose(in, axes.get<5>()); + if (!is_addto) { + out = transpose(in, axes.get<5>()); + } else { + out += transpose(in, axes.get<5>()); + } break; } case 6: { Tensor<xpu, 6, DType> in = src.get<xpu, 6, DType>(s); Tensor<xpu, 6, DType> out = ret.get<xpu, 6, DType>(s); - out = transpose(in, axes.get<6>()); + if (!is_addto) { + out = transpose(in, axes.get<6>()); + } else { + out += transpose(in, axes.get<6>()); + } break; } default: @@ -399,15 +429,21 @@ void Transpose(const nnvm::NodeAttrs& attrs, return; } const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo"; + CHECK(req[0] == kWriteTo || req[0] == kAddTo) + << "Transpose only supports kNullOp, kWriteTo and kAddTo"; + mxnet::TShape axes; if (param.axes.ndim() == 0) { - mxnet::TShape axes(inputs[0].ndim(), -1); + axes = mxnet::TShape(inputs[0].ndim(), -1); for (int i = 0; i < axes.ndim(); ++i) { axes[i] = axes.ndim() - 1 - i; } - TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes); } else { - TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], param.axes); + axes = common::CanonicalizeAxes(param.axes); + } + if (req[0] == kAddTo) { + TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes); + } else { + TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes); } } diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index b09b332765f8..1e69f72615fe 100644 --- 
a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -283,11 +283,12 @@ static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs, return; } const TransposeParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(req[0], kWriteTo) << "Transpose does not support kWriteInplace and kAddTo"; + CHECK(req[0] == kWriteTo || req[0] == kAddTo) << + "Transpose only supports kNullOp, kWriteTo and kAddTo"; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); - if (SupportMKLDNNTranspose(param, inputs[0])) { + if (SupportMKLDNNTranspose(param, inputs[0]) && req[0] == kWriteTo) { MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]); return; } diff --git a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh index 5b7cf04daef4..b3ca9fbfa0c9 100644 --- a/src/operator/tensor/pseudo2DTranspose_op-inl.cuh +++ b/src/operator/tensor/pseudo2DTranspose_op-inl.cuh @@ -39,22 +39,31 @@ namespace mxnet { namespace op { namespace cuda { - -template +/*! + * \brief The `transpose_pseudo2D` based on chosen vectorized types. It transposes an array of + * shape (k, m, n) to (k, n, m) + * \param out Pointer to output memory. + * \param inp Pointer to input memory. + * \param m First of tensor dimensions. + * \param n Second of tensor dimensions. + * \param nIterY The number of iterations in the y-dim of the thread to cover all rows. (1-->m) + * \param nIterZ The number of iterations in the z-dim of the thread to cover all rows. (1-->k) + * \tparam DType Data type + * \tparam CType The type to load the data. + * \tparam is_addto Whether to perform out += transpose(data) or out = transpose(data) + */ +template __global__ void transpose_pseudo2D(DType* out, DType* inp, const index_t m, const index_t n, const index_t nIterY, const index_t nIterZ) { - const index_t TSR = sizeof(CType)/sizeof(DType); // TypeSizeRatio + // Calculate the TypeSizeRatio + const index_t TSR = sizeof(CType) / sizeof(DType) > 0 ? 
sizeof(CType) / sizeof(DType) : 1; const index_t chunked_n = n/TSR; const index_t chunked_m = m/TSR; - union transp_t { - CType valChunk; - DType values[TSR]; - }; - - __shared__ DType d_shm[1024*TSR*TSR]; - CType* c_shm = reinterpret_cast(d_shm); + extern __shared__ char buf[]; + DType* d_shm = reinterpret_cast(buf); + CType* c_shm = reinterpret_cast(buf); CType* cInp = reinterpret_cast(inp); CType* cOut = reinterpret_cast(out); @@ -78,23 +87,34 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp, } __syncthreads(); - // read from shared to registers - transp_t tmp[TSR]; + // read from shared to local registers + CType tmp[TSR]; #pragma unroll for (index_t i = 0; i < TSR; i++) { + DType* tmp_dptr = reinterpret_cast(&tmp[i]); #pragma unroll for (int j = 0; j < TSR; j++) { index_t shmIdx = (TSR*threadIdx.y + j)*blockDim.x*TSR + TSR*threadIdx.x + i; - tmp[i].values[j] = d_shm[shmIdx]; + tmp_dptr[j] = d_shm[shmIdx]; } } __syncthreads(); // write back to global output - offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m + blockIdx_y*blockDim.y; + offset = blockIdx_z*m*chunked_n + blockIdx.x*blockDim.x*TSR*chunked_m + + blockIdx_y*blockDim.y; #pragma unroll for (index_t i = 0; i < TSR; i++) { - cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i].valChunk; + if (is_addto) { + DType* tmp_dptr = reinterpret_cast(&tmp[i]); + #pragma unroll + for (int j = 0; j < TSR; j++) { + out[TSR * (offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y) + j] + += tmp_dptr[j]; + } + } else { + cOut[offset + (TSR*threadIdx.x + i)*chunked_m + threadIdx.y] = tmp[i]; + } } } } @@ -107,7 +127,6 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp, /*! * \brief Calls proper version of kernel `transpose_pseudo2D` * basing on chosen type sizes. - * \param dTypeSize Size of data type. * \param cTypeSize Size of type that should be use to copy. * \param grid Grid dimensions for the kernel. * \param block Block dimensions for the kernel. @@ -116,92 +135,39 @@ __global__ void transpose_pseudo2D(DType* out, DType* inp, * \param inp Pointer to input memory. * \param m First of tensor dimensions. * \param n Second of tensor dimensions. + * \tparam DType Data type + * \tparam is_addto Whether to trigger add the transpose result to the output tensor. 
*/ -inline void call_transpose_pseudo2D(index_t dTypeSize, index_t cTypeSize, - dim3 grid, dim3 block, cudaStream_t stream, - void* out, void* inp, const index_t m, const index_t n, - const index_t nIterY, const index_t nIterZ) { - switch (dTypeSize) { - case (1): { - uint8_t* d_outPtr = reinterpret_cast(out); - uint8_t* d_inpPtr = reinterpret_cast(inp); - switch (cTypeSize) { - case (1): - cuda::transpose_pseudo2D<<>> +template +inline void call_transpose_pseudo2D(index_t cTypeSize, + dim3 grid, dim3 block, cudaStream_t stream, + DType* d_outPtr, DType* d_inpPtr, + const index_t m, const index_t n, + const index_t nIterY, const index_t nIterZ) { + const int nshared = 1024 * cTypeSize / sizeof(DType) * cTypeSize; + switch (cTypeSize) { + case (1): + cuda::transpose_pseudo2D<<>> (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); break; - case (2): - cuda::transpose_pseudo2D<<>> + case (2): + cuda::transpose_pseudo2D<<>> (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); break; - case (4): - cuda::transpose_pseudo2D<<>> + case (4): + cuda::transpose_pseudo2D<<>> (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); break; - case (8): - // case guarded against in function getBestCopyTypeSize - LOG(FATAL) << "cuda::transpose_pseudo2D would take too much shared memory"; - default: - LOG(FATAL) << "Unsupported type combination"; - } - break; - } - case (2): { - uint16_t* d_outPtr = reinterpret_cast(out); - uint16_t* d_inpPtr = reinterpret_cast(inp); - switch (cTypeSize) { - case (2): - cuda::transpose_pseudo2D<<>> + case (8): + cuda::transpose_pseudo2D<<>> (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); break; - case (4): - cuda::transpose_pseudo2D<<>> - (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); - break; - case (8): - cuda::transpose_pseudo2D<<>> - (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); - break; - default: - LOG(FATAL) << "Unsupported type combination"; - } - break; - } - case (4): { - uint32_t* d_outPtr = reinterpret_cast(out); - uint32_t* d_inpPtr = reinterpret_cast(inp); - switch (cTypeSize) { - case (4): - cuda::transpose_pseudo2D<<>> - (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); - break; - case (8): - cuda::transpose_pseudo2D<<>> - (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); - break; - default: - LOG(FATAL) << "Unsupported type combination"; - } - break; - } - case (8): { - uint64_t* d_outPtr = reinterpret_cast(out); - uint64_t* d_inpPtr = reinterpret_cast(inp); - switch (cTypeSize) { - case (8): - cuda::transpose_pseudo2D<<>> - (d_outPtr, d_inpPtr, m, n, nIterY, nIterZ); - break; - default: - LOG(FATAL) << "Unsupported type combination"; - } - break; - } - default: - LOG(FATAL) << "Unsupported type combination"; + default: + LOG(FATAL) << "Unsupported type combination. " << "Copy type size = " << cTypeSize; } auto cuErr = cudaPeekAtLastError(); - CHECK_EQ(cuErr, cudaSuccess) << "Transpose kernel failure: " << cudaGetErrorString(cuErr) << ". " + CHECK_EQ(cuErr, cudaSuccess) << "TransposePseudo2D kernel failure: " + << cudaGetErrorString(cuErr) << ". " << "block: (" << block.x << "," << block.y << "," << block.z << ")" << " grid: (" << grid.x << "," << grid.y << "," << grid.z << ")"; } @@ -225,7 +191,6 @@ inline bool isPseudo2DTranspose(const TShape& params) { return n_swpDims == 2; } - struct pseudo2DSizes { index_t leadDimS; index_t M; @@ -306,15 +271,14 @@ inline std::pair calculateKernelParams(pseudo2DSizes sizes, const in * \param outBlob Tensor blob to store result. * \param inpBlob Tensor blob with input data. * \param params Parameters (axes) of the transpose. 
+ * \param is_addto Whether to add the transpose result to the outBlob * \param s Pointer to GPU stream. */ -template +template void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob, const TShape& params, mshadow::Stream* s) { const TShape& shape = inpBlob.shape_; CHECK_EQ(shape.ndim(), params.ndim()); - auto ndim = params.ndim(); - auto sizes = getPackedTransposeDimensions(shape, params); index_t cTypeSize = getBestCopyTypeSize(sizeof(DType), sizes.M, sizes.N); @@ -337,8 +301,10 @@ void transpose_pseudo2D(const TBlob& outBlob, const TBlob& inpBlob, } cudaStream_t stream = mshadow::Stream::GetStream(s); - call_transpose_pseudo2D(sizeof(DType), cTypeSize, grid, block, stream, - outBlob.dptr_, inpBlob.dptr_, sizes.M, sizes.N, nIterY, nIterZ); + call_transpose_pseudo2D + (cTypeSize, grid, block, stream, + outBlob.dptr(), inpBlob.dptr(), + sizes.M, sizes.N, nIterY, nIterZ); } } // namespace op diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index cf913a6161c0..21f949a0bba6 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -367,6 +367,7 @@ def test_default_handlers(): val_metrics = est.val_metrics early_stopping = EarlyStoppingHandler(monitor=val_metrics[0]) handlers = est._prepare_default_handlers(val_data=None, event_handlers=[early_stopping]) - assert len(handlers) == 4 - assert isinstance(handlers[0], MetricHandler) - assert isinstance(handlers[3], LoggingHandler) + assert len(handlers) == 5 + assert isinstance(handlers[0], GradientUpdateHandler) + assert isinstance(handlers[1], MetricHandler) + assert isinstance(handlers[4], LoggingHandler) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 17c75813d516..658fb88f47e5 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -17,13 +17,19 @@ import os import logging +import sys +import re import mxnet as mx from common import TemporaryDirectory from mxnet import nd from mxnet.gluon import nn, loss from mxnet.gluon.contrib.estimator import estimator, event_handler - +from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler +try: + from StringIO import StringIO +except ImportError: + from io import StringIO def _get_test_network(net=nn.Sequential()): net.add(nn.Dense(128, activation='relu', flatten=False), @@ -32,9 +38,9 @@ def _get_test_network(net=nn.Sequential()): return net -def _get_test_data(): - data = nd.ones((32, 100)) - label = nd.zeros((32, 1)) +def _get_test_data(in_size=32): + data = nd.ones((in_size, 100)) + label = nd.zeros((in_size, 1)) data_arr = mx.gluon.data.dataset.ArrayDataset(data, label) return mx.gluon.data.DataLoader(data_arr, batch_size=8) @@ -200,3 +206,61 @@ def epoch_end(self, estimator, *args, **kwargs): est.fit(test_data, event_handlers=[custom_handler], epochs=10) assert custom_handler.num_batch == 5 * 4 assert custom_handler.num_epoch == 5 + +def test_logging_interval(): + ''' test different options for logging handler ''' + ''' test case #1: log interval is 1 ''' + batch_size = 8 + data_size = 100 + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + log_interval = 1 + net = _get_test_network() + dataloader = _get_test_data(in_size=data_size) + num_epochs = 1 + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) + est = 
estimator.Estimator(net=net, + loss=ce_loss, + metrics=acc) + + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) + ''' test case #2: log interval is 5 ''' + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + acc = mx.metric.Accuracy() + log_interval = 5 + logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) + est = estimator.Estimator(net=net, + loss=ce_loss, + metrics=acc) + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) + diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py index 0ae8aeaa697f..a1e5128d8ac6 100644 --- a/tests/python/unittest/test_metric.py +++ b/tests/python/unittest/test_metric.py @@ -17,6 +17,7 @@ import mxnet as mx import numpy as np +import scipy import json import math from common import with_seed @@ -263,13 +264,40 @@ def test_perplexity(): assert perplexity == perplexity_expected def test_pearsonr(): - pred = mx.nd.array([[0.7, 0.3], [0.1, 0.9], [1., 0]]) - label = mx.nd.array([[0, 1], [1, 0], [1, 0]]) - pearsonr_expected = np.corrcoef(pred.asnumpy().ravel(), label.asnumpy().ravel())[0, 1] - metric = mx.metric.create('pearsonr') - metric.update([label], [pred]) - _, pearsonr = metric.get() - assert pearsonr == pearsonr_expected + pred1 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]]) + label1 = mx.nd.array([[1, 0], [0, 1], [0, 1]]) + pearsonr_expected_np = np.corrcoef(pred1.asnumpy().ravel(), label1.asnumpy().ravel())[0, 1] + pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred1.asnumpy().ravel(), label1.asnumpy().ravel()) + macro_pr = mx.metric.create('pearsonr', average='macro') + micro_pr = mx.metric.create('pearsonr', average='micro') + + assert np.isnan(macro_pr.get()[1]) + assert np.isnan(micro_pr.get()[1]) + + macro_pr.update([label1], [pred1]) + micro_pr.update([label1], [pred1]) + + np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_np) + np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy) + np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np) + np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy) + + pred2 = mx.nd.array([[1, 2], [3, 2], [4, 6]]) + label2 = mx.nd.array([[1, 0], [0, 1], [0, 1]]) + # Note that pred12 = pred1 + pred2; label12 = label1 + label2 + pred12 = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6],[1, 2], [3, 2], [4, 6]]) + label12 = mx.nd.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1], [0, 1]]) + + pearsonr_expected_np = np.corrcoef(pred12.asnumpy().ravel(), label12.asnumpy().ravel())[0, 1] + pearsonr_expected_scipy, _ = scipy.stats.pearsonr(pred12.asnumpy().ravel(), label12.asnumpy().ravel()) + + macro_pr.reset() + micro_pr.update([label2], [pred2]) + macro_pr.update([label12], [pred12]) + np.testing.assert_almost_equal(macro_pr.get()[1], 
pearsonr_expected_np) + np.testing.assert_almost_equal(macro_pr.get()[1], pearsonr_expected_scipy) + np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_np) + np.testing.assert_almost_equal(micro_pr.get()[1], pearsonr_expected_scipy) def cm_batch(cm): # generate a batch yielding a given confusion matrix diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 4c6d9f7b8ef2..d097799da286 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -154,6 +154,22 @@ def test_ndarray_setitem(): assert x.shape == trivial_shape assert same(x.asnumpy(), x_np) + # test https://github.com/apache/incubator-mxnet/issues/16647 + dst = mx.nd.zeros((1, 3, 1)) # destination array + src = [1, 2, 3] + dst[0, :len(src), 0] = src + assert same(dst.asnumpy(), np.array([1, 2, 3], dtype=dst.dtype).reshape(dst.shape)) + + dst = mx.nd.zeros((1, 3, 1)) # destination array + src = [1, 2, 3] + dst[0, :len(src), 0] = mx.nd.array(src) + assert same(dst.asnumpy(), np.array([1, 2, 3], dtype=dst.dtype).reshape(dst.shape)) + + dst = mx.nd.zeros((1, 3, 1)) # destination array + src = [1, 2] + dst[0, :len(src), 0] = src + assert same(dst.asnumpy(), np.array([1, 2, 0], dtype=dst.dtype).reshape(dst.shape)) + @with_seed() def test_ndarray_elementwise(): diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 9f4e62cac50c..0bd620ee7f78 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -696,26 +696,37 @@ def _is_basic_index(index): np_indexed_array = _np.random.randint(low=-10000, high=0, size=indexed_array_shape) # test value is a native numpy array without broadcast assert_same(np_array, np_index, mx_array, index, np_indexed_array) + # test value is a list without broadcast + assert_same(np_array, np_index, mx_array, index, np_indexed_array.tolist()) # test value is a mxnet numpy array without broadcast assert_same(np_array, np_index, mx_array, index, np.array(np_indexed_array)) # test value is an numeric_type assert_same(np_array, np_index, mx_array, index, _np.random.randint(low=-10000, high=0)) - if len(indexed_array_shape) > 1: - np_value = _np.random.randint(low=-10000, high=0, size=(indexed_array_shape[-1],)) - # test mxnet ndarray with broadcast - assert_same(np_array, np_index, mx_array, index, np.array(np_value)) - # test native numpy array with broadcast - assert_same(np_array, np_index, mx_array, index, np_value) - - # test value shape are expanded to be longer than index array's shape - # this is currently only supported in basic indexing - if _is_basic_index(index): - expanded_value_shape = (1, 1, 1) + np_value.shape - assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape))) - assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape)) - # test list with broadcast - assert_same(np_array, np_index, mx_array, index, - [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1]) + + np_value = _np.random.randint(low=-10000, high=0, + size=(indexed_array_shape[-1],) if len(indexed_array_shape) > 0 else ()) + # test mxnet ndarray with broadcast + assert_same(np_array, np_index, mx_array, index, np.array(np_value)) + # test native numpy array with broadcast + assert_same(np_array, np_index, mx_array, index, np_value) + # test python list with broadcast + assert_same(np_array, np_index, mx_array, index, np_value.tolist()) + + # test value 
shape are expanded to be longer than index array's shape + # this is currently only supported in basic indexing + if _is_basic_index(index): + expanded_value_shape = (1, 1) + np_value.shape + assert_same(np_array, np_index, mx_array, index, np.array(np_value.reshape(expanded_value_shape))) + assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape)) + if len(expanded_value_shape) <= np_array[index].ndim: + # NumPy does not allow value.ndim > np_array[index].ndim when value is a python list. + # It may be a bug of NumPy. + assert_same(np_array, np_index, mx_array, index, np_value.reshape(expanded_value_shape).tolist()) + + # test list with broadcast + assert_same(np_array, np_index, mx_array, index, + [_np.random.randint(low=-10000, high=0)] * indexed_array_shape[-1] if len(indexed_array_shape) > 0 + else _np.random.randint(low=-10000, high=0)) def test_getitem_autograd(np_array, index): """ @@ -905,6 +916,9 @@ def test_setitem_autograd(np_array, index): range(4), range(3, 0, -1), (range(4,), [1]), + (1, 1, slice(None), 1), + (1, 1, slice(None, 3), 1), + (1, 1, slice(None, 8, 3), 1), ] for index in index_list: test_getitem(np_array, index) @@ -925,8 +939,8 @@ def test_setitem_autograd(np_array, index): # test zero-size tensors get and setitem shapes_indices = [ - ((0), [slice(None, None, None)]), - ((3, 0), [2, (slice(None, None, None)), (slice(None, None, None), None)]), + ((0), [slice(None, None, None)]), + ((3, 0), [2, (slice(None, None, None)), (slice(None, None, None), None)]), ] for shape, indices in shapes_indices: np_array = _np.zeros(shape) @@ -1198,11 +1212,14 @@ def test_np_ndarray_pickle(): a = np.random.uniform(size=(4, 5)) a_copy = a.copy() import pickle - with open("np_ndarray_pickle_test_file", 'wb') as f: - pickle.dump(a_copy, f) - with open("np_ndarray_pickle_test_file", 'rb') as f: - a_load = pickle.load(f) - same(a.asnumpy(), a_load.asnumpy()) + + with TemporaryDirectory() as work_dir: + fname = os.path.join(work_dir, 'np_ndarray_pickle_test_file') + with open(fname, 'wb') as f: + pickle.dump(a_copy, f) + with open(fname, 'rb') as f: + a_load = pickle.load(f) + same(a.asnumpy(), a_load.asnumpy()) if __name__ == '__main__': diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 5ec9944b33a4..1ff1b6139cce 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1302,7 +1302,9 @@ def np_transpose_grad(out_shape, dtype, axes=None): if axes is None or axes == (): return _np.transpose(ograd, axes) np_axes = _np.array(list(axes)) - return _np.transpose(ograd, tuple(list(_np.argsort(np_axes)))) + transpose_axes = _np.zeros_like(np_axes) + transpose_axes[np_axes] = _np.arange(len(np_axes)) + return _np.transpose(ograd, tuple(list(transpose_axes))) class TestTranspose(HybridBlock): def __init__(self, axes=None): @@ -1311,45 +1313,57 @@ def __init__(self, axes=None): def hybrid_forward(self, F, a): return F.np.transpose(a, self.axes) + test_workloads = [[(), [(), None]], + [(2,), [(0,), None]], + [(0, 2), [(0, 1), (1, 0)]], + [(5, 10), [(0, 1), (1, 0), None]], + [(8, 2, 3), [(2, 0, 1), (0, 2, 1), (0, 1, 2), (2, 1, 0), (-1, 1, 0), None]], + [(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]], + [(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]], + [(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]], + [(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]]] for hybridize in [True, False]: - for dtype in 
[_np.int32, _np.float32]: - for ndim in range(7): - shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True) - axeses = [None] - if ndim == 0: - axeses += [()] - else: - axes = [i for i in range(ndim)] - axeses.append(tuple(axes)) - random.shuffle(axes) - axeses.append(tuple(axes)) - axeses.append([i - len(axes) for i in axes]) - for axes in axeses: - test_trans = TestTranspose(axes) - if hybridize: - test_trans.hybridize() - x = rand_ndarray(shape).as_np_ndarray() - x = x.astype(dtype) - x.attach_grad() - np_out = _np.transpose(x.asnumpy(), axes) - with mx.autograd.record(): - mx_out = test_trans(x) - assert mx_out.shape == np_out.shape - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) - mx_out.backward() - np_backward = np_transpose_grad(np_out.shape, dtype, axes) - assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False) - - mx_out = x.transpose(axes) - np_out = x.asnumpy().transpose(axes) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + for dtype in [_np.float32, _np.float16, _np.int32]: + for data_shape, axes_workload in test_workloads: + for axes in axes_workload: + for grad_req in ['write', 'add']: + test_trans = TestTranspose(axes) + if hybridize: + test_trans.hybridize() + x = np.random.normal(0, 1, data_shape).astype(dtype) + x = x.astype(dtype) + x.attach_grad(grad_req=grad_req) + if grad_req == 'add': + x.grad[()] = np.random.normal(0, 1, x.grad.shape).astype(x.grad.dtype) + x_grad_np = x.grad.asnumpy() + np_out = _np.transpose(x.asnumpy(), axes) + with mx.autograd.record(): + mx_out = test_trans(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + mx_out.backward() + np_backward = np_transpose_grad(np_out.shape, dtype, axes) + if grad_req == 'add': + assert_almost_equal(x.grad.asnumpy(), np_backward + x_grad_np, + rtol=1e-3, atol=1e-5, use_broadcast=False) + else: + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5, use_broadcast=False) - if isinstance(axes, (list, tuple)): - mx_out = x.transpose(*axes) - np_out = x.asnumpy().transpose(*axes) + mx_out = x.transpose(axes) + np_out = x.asnumpy().transpose(axes) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + if isinstance(axes, (list, tuple)): + mx_out = x.transpose(*axes) + np_out = x.asnumpy().transpose(*axes) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False) + # Test for error raising + dat = np.random.normal(0, 1, (3, 4, 5), dtype=np.float32) + assert_raises(MXNetError, lambda: dat.transpose((0, 0, 1))) + assert_raises(MXNetError, lambda: dat.transpose((0, 1, 3))) + + @with_seed() @use_np