diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx index 5a8220db0c..1d77553fec 100644 --- a/dpctl/tensor/_usmarray.pyx +++ b/dpctl/tensor/_usmarray.pyx @@ -586,10 +586,8 @@ cdef class usm_ndarray: return _flags.Flags(self, self.flags_) cdef _set_writable_flag(self, int flag): - cdef int arr_fl = self.flags_ - arr_fl ^= (arr_fl & USM_ARRAY_WRITABLE) # unset WRITABLE flag - arr_fl |= (USM_ARRAY_WRITABLE if flag else 0) - self.flags_ = arr_fl + cdef int mask = (USM_ARRAY_WRITABLE if flag else 0) + self.flags_ = _copy_writable(self.flags_, mask) @property def usm_type(self): diff --git a/dpctl/tests/elementwise/test_elementwise_classes.py b/dpctl/tests/elementwise/test_elementwise_classes.py index c8680078b7..9823c30a66 100644 --- a/dpctl/tests/elementwise/test_elementwise_classes.py +++ b/dpctl/tests/elementwise/test_elementwise_classes.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest + import dpctl.tensor as dpt from dpctl.tests.helper import get_queue_or_skip @@ -49,6 +51,15 @@ def test_unary_class_str_repr(): assert kl_n in r +def test_unary_read_only_out(): + get_queue_or_skip() + x = dpt.arange(32, dtype=dpt.int32) + r = dpt.empty_like(x) + r.flags["W"] = False + with pytest.raises(ValueError): + unary_fn(x, out=r) + + def test_binary_class_getters(): fn = binary_fn.get_implementation_function() assert callable(fn) @@ -105,3 +116,13 @@ def test_binary_class_nout(): nout = binary_fn.nout assert isinstance(nout, int) assert nout == 1 + + +def test_biary_read_only_out(): + get_queue_or_skip() + x1 = dpt.ones(32, dtype=dpt.float32) + x2 = dpt.ones_like(x1) + r = dpt.empty_like(x1) + r.flags["W"] = False + with pytest.raises(ValueError): + binary_fn(x1, x2, out=r) diff --git a/dpctl/tests/test_tensor_clip.py b/dpctl/tests/test_tensor_clip.py index 11c93ecf1f..3aaf40a3f8 100644 --- a/dpctl/tests/test_tensor_clip.py +++ b/dpctl/tests/test_tensor_clip.py @@ -748,3 +748,22 @@ def test_clip_compute_follows_data(): with pytest.raises(ExecutionPlacementError): dpt.clip(x, out=res) + + +def test_clip_readonly_out(): + get_queue_or_skip() + x = dpt.arange(32, dtype=dpt.int32) + r = dpt.empty_like(x) + r.flags["W"] = False + + with pytest.raises(ValueError): + dpt.clip(x, min=0, max=10, out=r) + + with pytest.raises(ValueError): + dpt.clip(x, max=10, out=r) + + with pytest.raises(ValueError): + dpt.clip(x, min=0, out=r) + + with pytest.raises(ValueError): + dpt.clip(x, out=r) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index c0195a02f2..e51c0a2ac7 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -332,6 +332,16 @@ def test_matmul_out(): assert np.allclose(ref, dpt.asnumpy(res)) +def test_matmul_readonly_out(): + get_queue_or_skip() + m = dpt.ones((10, 10), dtype=dpt.int32) + r = dpt.empty_like(m) + r.flags["W"] = False + + with pytest.raises(ValueError): + dpt.matmul(m, m, out=r) + + def test_matmul_dtype(): get_queue_or_skip()