Skip to content

Commit 41a77e4

Browse files
committed
Refactoring
1 parent 00b434f commit 41a77e4

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

python_bindings/bindings.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,19 @@ inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows,
9191
}
9292

9393

94-
inline std::vector<size_t> get_input_ids_and_check_shapes(const py::object& ids_, size_t rows) {
94+
inline std::vector<size_t> get_input_ids_and_check_shapes(const py::object& ids_, size_t feature_rows) {
9595
std::vector<size_t> ids;
9696
if (!ids_.is_none()) {
9797
py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_);
9898
auto ids_numpy = items.request();
9999
// check shapes
100-
bool valid = false;
101-
if ((ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) || (ids_numpy.ndim == 0 && rows == 1)) {
102-
valid = true;
100+
if (!((ids_numpy.ndim == 1 && ids_numpy.shape[0] == feature_rows) ||
101+
(ids_numpy.ndim == 0 && feature_rows == 1))) {
102+
char msg[1024];
103+
sprintf(msg, "the input label shape %d does not match the input data vector shape %d",
104+
ids_numpy.ndim, feature_rows);
105+
throw std::runtime_error(msg);
103106
}
104-
if (!valid) throw std::runtime_error("wrong dimensionality of the labels");
105107
// extract data
106108
if (ids_numpy.ndim == 1) {
107109
std::vector<size_t> ids1(ids_numpy.shape[0]);

0 commit comments

Comments
 (0)