Skip to content

Commit ef78d29

Browse files
API: Model.getSolVal supports MatrixExpr (#1183)
* Support MatrixExpr in solution retrieval Allow MatrixExpr to be queried from Solution and Model methods: add MatrixExpr to type annotations for Solution.__getitem__ and Model.getSolVal, update return types to include np.ndarray, and improve the runtime type check in __getitem__ with a clear TypeError. Simplify best-solution value retrieval by delegating to Solution.__getitem__ (removing the manual loop over MatrixExpr), and update docstrings to reflect the new accepted types and return values. * Simplify stage check and inline best sol pointer Refactor model method to directly check SCIP stage (raise a Warning when in INIT or FREE) instead of using an intermediate boolean, and inline the SCIP_SOL* variable declaration when obtaining the best solution. Improves readability and removes an unnecessary temporary variable without changing behavior. * Add return type annotation to getVal Annotate Model.getVal in src/pyscipopt/scip.pxi with an explicit return type Union[float, np.ndarray] and remove an extra space in the parameter list. This clarifies that getVal may return either a scalar float or a NumPy array and improves static typing and IDE/type-checker support. * Add test for Model.getSolVal and getVal Add test_getSolVal to cover issue #1136. The test creates a Model with a binary scalar and a binary matrix variable, sets an objective, optimizes and obtains the best solution. It asserts that getSolVal(sol, var) matches getVal(var) for both scalar and matrix variables (values expected to be zeros), and verifies that passing a non-variable to getVal or getSolVal raises TypeError. * Changelog: Model.getSolVal supports MatrixExpr Add a CHANGELOG entry noting that Model.getSolVal now accepts MatrixExpr. This documents the new support for MatrixExpr inputs in the Model.getSolVal API. * Use numpy array comparisons in tests Replace usages of built-in all(...) with np.array_equal(...) to correctly compare array-like results from m.getSolVal and m.getVal. Import numpy as np and construct the expected value as np.array([0, 0]). Also remove an unused itertools import and reorder pyscipopt imports for clarity. * Refine type hints: use np alias and add overloads Improve typing in src/pyscipopt/scip.pyi: import Union and overload, alias numpy as np, and switch class bases from numpy.ndarray to np.ndarray. Add overloads for Model.getSolVal and Model.getVal to distinguish scalar returns (Expr/GenExpr -> float) from matrix returns (MatrixExpr -> np.ndarray), improving type accuracy for consumers and IDE/type-checkers. * Reorder imports in tests/test_model.py Move 'from helpers.utils import random_mip_1' below the pyscipopt imports to group third-party and local imports (style-only change; no functional impact). * Apply suggestion from @Joao-Dionisio --------- Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com>
1 parent a03b00e commit ef78d29

File tree

4 files changed

+67
-37
lines changed

4 files changed

+67
-37
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Wrapped isObjIntegral() and test
99
- Added structured_optimization_trace recipe for structured optimization progress tracking
1010
- Added methods: getPrimalDualIntegral()
11+
- getSolVal() supports MatrixExpr now
1112
### Fixed
1213
- getBestSol() now returns None for infeasible problems instead of a Solution with NULL pointer
1314
- all fundamental callbacks now raise an error if not implemented

src/pyscipopt/scip.pxi

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,7 +1098,16 @@ cdef class Solution:
10981098
sol.scip = scip
10991099
return sol
11001100

1101-
def __getitem__(self, expr: Union[Expr, MatrixExpr]):
1101+
def __getitem__(
1102+
self,
1103+
expr: Union[Expr, GenExpr, MatrixExpr],
1104+
) -> Union[float, np.ndarray]:
1105+
if not isinstance(expr, (Expr, GenExpr, MatrixExpr)):
1106+
raise TypeError(
1107+
"Argument 'expr' has incorrect type, expected 'Expr', 'GenExpr', or "
1108+
f"'MatrixExpr', got {type(expr).__name__!r}"
1109+
)
1110+
11021111
self._checkStage("SCIPgetSolVal")
11031112
return expr._evaluate(self)
11041113

@@ -10968,75 +10977,63 @@ cdef class Model:
1096810977
def getSolVal(
1096910978
self,
1097010979
Solution sol,
10971-
expr: Union[Expr, GenExpr],
10980+
expr: Union[Expr, GenExpr, MatrixExpr],
1097210981
) -> Union[float, np.ndarray]:
1097310982
"""
10974-
Retrieve value of given variable or expression in the given solution or in
10975-
the LP/pseudo solution if sol == None
10983+
Retrieve value of given variable or expression in the given solution.
1097610984

1097710985
Parameters
1097810986
----------
1097910987
sol : Solution
10980-
expr : Expr
10981-
polynomial expression to query the value of
10988+
Solution to query the value from. If None, the current LP/pseudo solution is
10989+
used.
10990+
10991+
expr : Expr, GenExpr, MatrixExpr
10992+
Expression to query the value of.
1098210993

1098310994
Returns
1098410995
-------
10985-
float
10996+
float or np.ndarray
1098610997

1098710998
Notes
1098810999
-----
1098911000
A variable is also an expression.
1099011001

1099111002
"""
10992-
if not isinstance(expr, (Expr, GenExpr)):
10993-
raise TypeError(
10994-
"Argument 'expr' has incorrect type (expected 'Expr' or 'GenExpr', "
10995-
f"got {type(expr)})"
10996-
)
1099711003
# no need to create a NULL solution wrapper in case we have a variable
1099811004
return (sol or Solution.create(self._scip, NULL))[expr]
1099911005

11000-
def getVal(self, expr: Union[Expr, GenExpr, MatrixExpr] ):
11006+
def getVal(self, expr: Union[Expr, GenExpr, MatrixExpr]) -> Union[float, np.ndarray]:
1100111007
"""
1100211008
Retrieve the value of the given variable or expression in the best known solution.
1100311009
Can only be called after solving is completed.
1100411010

1100511011
Parameters
1100611012
----------
1100711013
expr : Expr, GenExpr or MatrixExpr
11014+
Expression to query the value of.
1100811015

1100911016
Returns
1101011017
-------
11011-
float
11018+
float or np.ndarray
1101211019

1101311020
Notes
1101411021
-----
1101511022
A variable is also an expression.
1101611023

1101711024
"""
11018-
cdef SCIP_SOL* current_best_sol
11019-
11020-
stage_check = SCIPgetStage(self._scip) not in [SCIP_STAGE_INIT, SCIP_STAGE_FREE]
11021-
if not stage_check:
11025+
if SCIPgetStage(self._scip) in {SCIP_STAGE_INIT, SCIP_STAGE_FREE}:
1102211026
raise Warning("Method cannot be called in stage ", self.getStage())
1102311027

1102411028
# Ensure _bestSol is up-to-date (cheap pointer comparison)
11025-
current_best_sol = SCIPgetBestSol(self._scip)
11029+
cdef SCIP_SOL* current_best_sol = SCIPgetBestSol(self._scip)
1102611030
if self._bestSol is None or self._bestSol.sol != current_best_sol:
1102711031
self._bestSol = Solution.create(self._scip, current_best_sol)
1102811032

1102911033
if self._bestSol.sol == NULL and SCIPgetStage(self._scip) != SCIP_STAGE_SOLVING:
1103011034
raise Warning("No solution available")
1103111035

11032-
if isinstance(expr, MatrixExpr):
11033-
result = np.empty(expr.shape, dtype=float)
11034-
for idx in np.ndindex(result.shape):
11035-
result[idx] = self.getSolVal(self._bestSol, expr[idx])
11036-
else:
11037-
result = self.getSolVal(self._bestSol, expr)
11038-
11039-
return result
11036+
return self._bestSol[expr]
1104011037

1104111038
def hasPrimalRay(self):
1104211039
"""

src/pyscipopt/scip.pyi

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import ClassVar
1+
from typing import ClassVar, Union, overload
22

3-
import numpy
3+
import numpy as np
44
from _typeshed import Incomplete
55
from typing_extensions import disjoint_base
66

@@ -496,7 +496,7 @@ class LP:
496496
def solve(self, dual: Incomplete = ...) -> Incomplete: ...
497497
def writeLP(self, filename: Incomplete) -> Incomplete: ...
498498

499-
class MatrixConstraint(numpy.ndarray):
499+
class MatrixConstraint(np.ndarray):
500500
def getConshdlrName(self) -> Incomplete: ...
501501
def isActive(self) -> Incomplete: ...
502502
def isChecked(self) -> Incomplete: ...
@@ -512,7 +512,7 @@ class MatrixConstraint(numpy.ndarray):
512512
def isSeparated(self) -> Incomplete: ...
513513
def isStickingAtNode(self) -> Incomplete: ...
514514

515-
class MatrixExpr(numpy.ndarray):
515+
class MatrixExpr(np.ndarray):
516516
def _evaluate(self, sol: Incomplete) -> Incomplete: ...
517517
def __array_ufunc__(
518518
self,
@@ -522,7 +522,7 @@ class MatrixExpr(numpy.ndarray):
522522
**kwargs: Incomplete,
523523
) -> Incomplete: ...
524524

525-
class MatrixExprCons(numpy.ndarray):
525+
class MatrixExprCons(np.ndarray):
526526
def __array_ufunc__(
527527
self,
528528
ufunc: Incomplete,
@@ -1215,7 +1215,10 @@ class Model:
12151215
self, sol: Incomplete, original: Incomplete = ...
12161216
) -> Incomplete: ...
12171217
def getSolTime(self, sol: Incomplete) -> Incomplete: ...
1218-
def getSolVal(self, sol: Incomplete, expr: Incomplete) -> Incomplete: ...
1218+
@overload
1219+
def getSolVal(self, sol: Solution, expr: Union[Expr, GenExpr]) -> float: ...
1220+
@overload
1221+
def getSolVal(self, sol: Solution, expr: MatrixExpr) -> np.ndarray: ...
12191222
def getSols(self) -> Incomplete: ...
12201223
def getSolvingTime(self) -> Incomplete: ...
12211224
def getStage(self) -> Incomplete: ...
@@ -1227,7 +1230,10 @@ class Model:
12271230
def getTransformedCons(self, cons: Incomplete) -> Incomplete: ...
12281231
def getTransformedVar(self, var: Incomplete) -> Incomplete: ...
12291232
def getTreesizeEstimation(self) -> Incomplete: ...
1230-
def getVal(self, expr: Incomplete) -> Incomplete: ...
1233+
@overload
1234+
def getVal(self, expr: Union[Expr, GenExpr]) -> float: ...
1235+
@overload
1236+
def getVal(self, expr: MatrixExpr) -> np.ndarray: ...
12311237
def getValsLinear(self, cons: Incomplete) -> Incomplete: ...
12321238
def getVarDict(self, transformed: Incomplete = ...) -> Incomplete: ...
12331239
def getVarLbDive(self, var: Incomplete) -> Incomplete: ...

tests/test_model.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
import pytest
2-
import os
31
import itertools
2+
import os
3+
4+
import numpy as np
5+
import pytest
46

5-
from pyscipopt import Model, SCIP_STAGE, SCIP_PARAMSETTING, SCIP_BRANCHDIR, quicksum
7+
from pyscipopt import SCIP_BRANCHDIR, SCIP_PARAMSETTING, SCIP_STAGE, Model, quicksum
68
from helpers.utils import random_mip_1
79

10+
811
def test_model():
912
# create solver instance
1013
s = Model()
@@ -616,3 +619,26 @@ def create_model_and_get_objects():
616619

617620
assert repr(x) == ""
618621
assert repr(c) == ""
622+
623+
624+
def test_getSolVal():
625+
# fix #1136
626+
627+
m = Model()
628+
x = m.addVar(vtype="B")
629+
y = m.addMatrixVar(2, vtype="B")
630+
631+
m.setObjective(x + y.sum())
632+
m.optimize()
633+
sol = m.getBestSol()
634+
635+
assert m.getSolVal(sol, x) == m.getVal(x)
636+
assert m.getVal(x) == 0
637+
638+
assert np.array_equal(m.getSolVal(sol, y), m.getVal(y))
639+
assert np.array_equal(m.getVal(y), np.array([0, 0]))
640+
641+
with pytest.raises(TypeError):
642+
m.getVal("not_a_var")
643+
with pytest.raises(TypeError):
644+
m.getSolVal(sol, "not_a_var")

0 commit comments

Comments
 (0)