Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ def transpose(self, *axes):
return self

axes_len = len(axes)
if axes_len == 1 and isinstance(axes[0], tuple):
if axes_len == 1 and isinstance(axes[0], (tuple, list)):
axes = axes[0]

res = self.__new__(dpnp_array)
Expand Down
2 changes: 1 addition & 1 deletion dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1909,7 +1909,7 @@ def transpose(a, axes=None):
if isinstance(a, dpnp_array):
array = a
elif isinstance(a, dpt.usm_ndarray):
array = dpnp_array._create_from_usm_ndarray(a.get_array())
array = dpnp_array._create_from_usm_ndarray(a)
else:
raise TypeError(
f"An array must be any of supported type, but got {type(a)}"
Expand Down
24 changes: 22 additions & 2 deletions tests/test_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_unique(array):


class TestTranspose:
@pytest.mark.parametrize("axes", [(0, 1), (1, 0)])
@pytest.mark.parametrize("axes", [(0, 1), (1, 0), [0, 1]])
def test_2d_with_axes(self, axes):
na = numpy.array([[1, 2], [3, 4]])
da = dpnp.array(na)
Expand All @@ -124,7 +124,22 @@ def test_2d_with_axes(self, axes):
result = dpnp.transpose(da, axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize("axes", [(1, 0, 2), ((1, 0, 2),)])
# ndarray
expected = na.transpose(axes)
result = da.transpose(axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize(
"axes",
[
(1, 0, 2),
[1, 0, 2],
((1, 0, 2),),
([1, 0, 2],),
[(1, 0, 2)],
[[1, 0, 2]],
],
)
def test_3d_with_packed_axes(self, axes):
na = numpy.ones((1, 2, 3))
da = dpnp.array(na)
Expand All @@ -133,6 +148,11 @@ def test_3d_with_packed_axes(self, axes):
result = da.transpose(*axes)
assert_array_equal(expected, result)

# ndarray
expected = na.transpose(*axes)
result = da.transpose(*axes)
assert_array_equal(expected, result)

@pytest.mark.parametrize("shape", [(10,), (2, 4), (5, 3, 7), (3, 8, 4, 1)])
def test_none_axes(self, shape):
na = numpy.ones(shape)
Expand Down
22 changes: 20 additions & 2 deletions tests/third_party/cupy/manipulation_tests/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,20 +64,38 @@ def test_moveaxis_invalid2_2(self):
with pytest.raises(numpy.AxisError):
xp.moveaxis(a, [0, -4], [1, 2])

def test_moveaxis_invalid2_3(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(numpy.AxisError):
xp.moveaxis(a, -4, 0)

# len(source) != len(destination)
def test_moveaxis_invalid3(self):
def test_moveaxis_invalid3_1(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1, 2], [1, 2])

def test_moveaxis_invalid3_2(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, 0, [1, 2])

# len(source) != len(destination)
def test_moveaxis_invalid4(self):
def test_moveaxis_invalid4_1(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1], [1, 2, 0])

def test_moveaxis_invalid4_2(self):
for xp in (numpy, cupy):
a = testing.shaped_arange((2, 3, 4), xp)
with pytest.raises(ValueError):
xp.moveaxis(a, [0, 1], 1)

# Use the same axis twice
def test_moveaxis_invalid5_1(self):
for xp in (numpy, cupy):
Expand Down