diff --git a/CHANGELOG.md b/CHANGELOG.md index b9e1944cf..5788dd187 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ ### Changed - Speed up `constant * Expr` via C-level API - Speed up `Term.__eq__` via the C-level API +- Speed up `Expr.__add__` and `Expr.__iadd__` via the C-level API ### Removed - Removed outdated warning about Make build system incompatibility - Removed `Term.ptrtuple` to optimize `Term` memory usage diff --git a/src/pyscipopt/expr.pxi b/src/pyscipopt/expr.pxi index b26b87997..092155666 100644 --- a/src/pyscipopt/expr.pxi +++ b/src/pyscipopt/expr.pxi @@ -302,42 +302,33 @@ cdef class Expr(ExprLike): return iter(self.terms) def __add__(self, other): - left = self - right = other - terms = left.terms.copy() - - if isinstance(right, Expr): - # merge the terms by component-wise addition - for v,c in right.terms.items(): - terms[v] = terms.get(v, 0.0) + c - elif _is_number(right): - c = float(right) - terms[CONST] = terms.get(CONST, 0.0) + c - elif isinstance(right, GenExpr): - return buildGenExprObj(left) + right - elif isinstance(right, np.ndarray): - return right + left + if _is_number(other): + terms = self.terms.copy() + terms[CONST] = terms.get(CONST, 0.0) + other + return Expr(terms) + elif isinstance(other, Expr): + return Expr(_to_dict(self, other, copy=True)) + elif isinstance(other, GenExpr): + return buildGenExprObj(self) + other + elif isinstance(other, np.ndarray): + return other + self else: - raise TypeError(f"Unsupported type {type(right)}") - - return Expr(terms) + raise TypeError(f"unsupported type {type(other).__name__!r}") def __iadd__(self, other): - if isinstance(other, Expr): - for v,c in other.terms.items(): - self.terms[v] = self.terms.get(v, 0.0) + c - elif _is_number(other): - c = float(other) - self.terms[CONST] = self.terms.get(CONST, 0.0) + c + if _is_number(other): + self.terms[CONST] = self.terms.get(CONST, 0.0) + other + return self + elif isinstance(other, Expr): + _to_dict(self, other, copy=False) + return self elif isinstance(other, GenExpr): # is no longer in place, might affect performance? # can't do `self = buildGenExprObj(self) + other` since I get # TypeError: Cannot convert pyscipopt.scip.SumExpr to pyscipopt.scip.Expr return buildGenExprObj(self) + other else: - raise TypeError(f"Unsupported type {type(other)}") - - return self + raise TypeError(f"unsupported type {type(other).__name__!r}") def __mul__(self, other): if isinstance(other, np.ndarray): @@ -1031,6 +1022,28 @@ cdef inline object _wrap_ufunc(object x, object ufunc): return res.view(MatrixGenExpr) if isinstance(res, np.ndarray) else res return ufunc(_to_const(x)) +cdef dict _to_dict(Expr expr, Expr other, bool copy = True): + cdef dict children = expr.terms.copy() if copy else expr.terms + cdef Py_ssize_t pos = 0 + cdef PyObject* k_ptr = NULL + cdef PyObject* v_ptr = NULL + cdef PyObject* old_v_ptr = NULL + cdef double other_v + cdef object k_obj + + while PyDict_Next(other.terms, &pos, &k_ptr, &v_ptr): + if (other_v := (v_ptr)) == 0: + continue + + k_obj = k_ptr + old_v_ptr = PyDict_GetItem(children, k_obj) + if old_v_ptr != NULL: + children[k_obj] = (old_v_ptr) + other_v + else: + children[k_obj] = other_v + + return children + def expr_to_nodes(expr): '''transforms tree to an array of nodes. each node is an operator and the position of the diff --git a/tests/test_expr.py b/tests/test_expr.py index 8f5802d63..433120404 100644 --- a/tests/test_expr.py +++ b/tests/test_expr.py @@ -328,3 +328,28 @@ def test_term_eq(): assert t3 != t4 # same length, but different term assert t1 != t3 # different length assert t1 != "not a term" # different type + + +def test_Expr_add_Expr(): + m = Model() + x = m.addVar(name="x") + y = m.addVar(name="y") + + e1 = -x + 1 + e2 = y - 1 + e3 = e1 + e2 + assert str(e1) == "Expr({Term(x): -1.0, Term(): 1.0})" + assert str(e2) == "Expr({Term(y): 1.0, Term(): -1.0})" + assert str(e3) == "Expr({Term(x): -1.0, Term(): 0.0, Term(y): 1.0})" + + +def test_Expr_iadd_Expr(): + m = Model() + x = m.addVar(name="x") + y = m.addVar(name="y") + + e1 = -x + 1 + e2 = y - 1 + e1 += e2 + assert str(e1) == "Expr({Term(x): -1.0, Term(): 0.0, Term(y): 1.0})" + assert str(e2) == "Expr({Term(y): 1.0, Term(): -1.0})"