diff --git a/dpnp/backend/extensions/fft/in_place.cpp b/dpnp/backend/extensions/fft/in_place.cpp index 9256d022efcc..14abf62caa84 100644 --- a/dpnp/backend/extensions/fft/in_place.cpp +++ b/dpnp/backend/extensions/fft/in_place.cpp @@ -47,7 +47,7 @@ std::pair const bool is_forward, const std::vector &depends) { - bool committed = descr.is_committed(); + const bool committed = descr.is_committed(); if (!committed) { throw py::value_error("Descriptor is not committed"); } diff --git a/dpnp/backend/extensions/fft/out_of_place.cpp b/dpnp/backend/extensions/fft/out_of_place.cpp index 71474d058e76..222927322076 100644 --- a/dpnp/backend/extensions/fft/out_of_place.cpp +++ b/dpnp/backend/extensions/fft/out_of_place.cpp @@ -49,7 +49,7 @@ std::pair const bool is_forward, const std::vector &depends) { - bool committed = descr.is_committed(); + const bool committed = descr.is_committed(); if (!committed) { throw py::value_error("Descriptor is not committed"); } @@ -93,8 +93,8 @@ std::pair if (in_nd > 1) { for (int i = 0; i < in_nd - 1; ++i) { if (in_shape[i] != out_shape[i]) { - throw py::value_error("The shape of the input and output " - "arrays must be the same."); + throw py::value_error("The shape of the output array is not " + "correct for the given input array."); } in_size *= in_shape[i]; } @@ -105,8 +105,9 @@ std::pair // r2c FFT N = m / 2 + 1; // integer divide if (n != N) { - throw py::value_error("The shape of the output array is not " - "correct for real to complex transform."); + throw py::value_error( + "The shape of the output array is not correct for the given " + "input array in real to complex FFT transform."); } } else { @@ -114,8 +115,8 @@ std::pair // have the same size as output before calling this function N = m; if (n != N) { - throw py::value_error("The shape of the input array must be " - "the same as the shape of the output array."); + throw py::value_error("The shape of the output array is not " + "correct for the given input array."); } } diff --git a/dpnp/dpnp_utils/dpnp_algo_utils.pxd b/dpnp/dpnp_utils/dpnp_algo_utils.pxd index 23714b5218cc..563ed6a35a81 100644 --- a/dpnp/dpnp_utils/dpnp_algo_utils.pxd +++ b/dpnp/dpnp_utils/dpnp_algo_utils.pxd @@ -37,24 +37,12 @@ cpdef checker_throw_value_error(function_name, param_name, param, expected) """ -cpdef checker_throw_axis_error(function_name, param_name, param, expected) -""" Throw exception AxisError if 'param' is not 'expected' - -""" - - cpdef checker_throw_type_error(function_name, given_type) """ Throw exception TypeError if 'given_type' type is not supported """ -cpdef checker_throw_index_error(function_name, index, size) -""" Throw exception IndexError if 'index' is out of bounds - -""" - - cpdef cpp_bool use_origin_backend(input1=*, size_t compute_size=*) """ This function needs to redirect particular computation cases to original backend @@ -69,17 +57,7 @@ Return: cpdef tuple _object_to_tuple(object obj) -cdef int _normalize_order(order, cpp_bool allow_k=*) except? 0 - -cpdef shape_type_c normalize_axis(object axis, size_t shape_size) -""" -Conversion of the transformation shape axis [-1, 0, 1] into [2, 0, 1] where numbers are `id`s of array shape axis -""" -cpdef long _get_linear_index(key, tuple shape, int ndim) -""" -Compute linear index of an element in memory from array indices -""" cpdef tuple get_axis_offsets(shape) """ diff --git a/dpnp/dpnp_utils/dpnp_algo_utils.pyx b/dpnp/dpnp_utils/dpnp_algo_utils.pyx index b8377368d5fd..11583fef9c5b 100644 --- a/dpnp/dpnp_utils/dpnp_algo_utils.pyx +++ b/dpnp/dpnp_utils/dpnp_algo_utils.pyx @@ -59,8 +59,6 @@ Python import functions """ __all__ = [ "call_origin", - "checker_throw_axis_error", - "checker_throw_index_error", "checker_throw_type_error", "checker_throw_value_error", "create_output_descriptor_py", @@ -68,9 +66,7 @@ __all__ = [ "dpnp_descriptor", "get_axis_offsets", "get_usm_allocations", - "_get_linear_index", "map_dtype_to_device", - "normalize_axis", "_object_to_tuple", "unwrap_array", "use_origin_backend" @@ -308,17 +304,6 @@ def map_dtype_to_device(dtype, device): raise RuntimeError(f"Unrecognized type of input dtype={dtype}") -cpdef checker_throw_axis_error(function_name, param_name, param, expected): - err_msg = f"{ERROR_PREFIX} in function {function_name}()" - err_msg += f" axes '{param_name}' expected `{expected}`, but '{param}' provided" - raise AxisError(err_msg) - - -cpdef checker_throw_index_error(function_name, index, size): - raise IndexError( - f"{ERROR_PREFIX} in function {function_name}() index {index} is out of bounds. dimension size `{size}`") - - cpdef checker_throw_type_error(function_name, given_type): raise TypeError(f"{ERROR_PREFIX} in function {function_name}() type '{given_type}' is not supported") @@ -364,22 +349,6 @@ cpdef tuple get_axis_offsets(shape): return _object_to_tuple(result) -cpdef long _get_linear_index(key, tuple shape, int ndim): - """ - Compute linear index of an element in memory from array indices - """ - - if isinstance(key, tuple): - li = 0 - m = 1 - for i in range(ndim - 1, -1, -1): - li += key[i] * m - m *= shape[i] - else: - li = key - return li - - cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape, DPNPFuncType c_type, dpnp_descriptor requested_out, @@ -412,53 +381,6 @@ cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape, return result_desc -cpdef shape_type_c normalize_axis(object axis_obj, size_t shape_size_inp): - """ - Conversion of the transformation shape axis [-1, 0, 1] into [2, 0, 1] where numbers are `id`s of array shape axis - """ - - cdef shape_type_c axis = _object_to_tuple(axis_obj) # axis_obj might be a scalar - cdef ssize_t shape_size = shape_size_inp # convert type for comparison with axis id - - cdef size_t axis_size = axis.size() - cdef shape_type_c result = shape_type_c(axis_size, 0) - for i in range(axis_size): - if (axis[i] >= shape_size) or (axis[i] < -shape_size): - checker_throw_axis_error("normalize_axis", "axis", axis[i], shape_size - 1) - - if (axis[i] < 0): - result[i] = shape_size + axis[i] - else: - result[i] = axis[i] - - return result - - -@cython.profile(False) -cdef inline int _normalize_order(order, cpp_bool allow_k=True) except? 0: - """ Converts memory order letters to some common view - - """ - - cdef int order_type - order_type = b'C' if len(order) == 0 else ord(order[0]) - - if order_type == b'K' or order_type == b'k': - if not allow_k: - raise ValueError("DPNP _normalize_order(): order \'K\' is not permitted") - order_type = b'K' - elif order_type == b'A' or order_type == b'a': - order_type = b'A' - elif order_type == b'C' or order_type == b'c': - order_type = b'C' - elif order_type == b'F' or order_type == b'f': - order_type = b'F' - else: - raise TypeError("DPNP _normalize_order(): order is not understood") - - return order_type - - @cython.profile(False) cpdef inline tuple _object_to_tuple(object obj): """ Converts Python object into tuple diff --git a/dpnp/fft/dpnp_iface_fft.py b/dpnp/fft/dpnp_iface_fft.py index 2c06bede5d5c..be97dafed170 100644 --- a/dpnp/fft/dpnp_iface_fft.py +++ b/dpnp/fft/dpnp_iface_fft.py @@ -36,18 +36,8 @@ """ -# pylint: disable=invalid-name - -import numpy - import dpnp -# pylint: disable=no-name-in-module -from dpnp.dpnp_utils import ( - call_origin, - checker_throw_axis_error, -) - from .dpnp_utils_fft import ( dpnp_fft, dpnp_fftn, @@ -122,8 +112,8 @@ def fft(a, n=None, axis=-1, norm=None, out=None): the default option ``"backward"``. Default: ``"backward"``. out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional - If provided, the result will be placed in this array. It should be - of the appropriate shape and dtype. + If provided, the result will be placed in this array. It should be of + the appropriate shape (consistent with the choice of `n`) and dtype. Default: ``None``. Returns @@ -209,9 +199,8 @@ def fft2(a, s=None, axes=(-2, -1), norm=None, out=None): the default option ``"backward"``. Default: ``"backward"``. out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional - If provided, the result will be placed in this array. It should be - of the appropriate shape and dtype (and hence is incompatible with - passing in all but the trivial `s`). + If provided, the result will be placed in this array. It should be of + the appropriate shape (consistent with the choice of `s`) and dtype. Default: ``None``. Returns @@ -263,7 +252,9 @@ def fft2(a, s=None, axes=(-2, -1), norm=None, out=None): """ dpnp.check_supported_arrays_type(a) - return dpnp_fftn(a, forward=True, s=s, axes=axes, norm=norm, out=out) + return dpnp_fftn( + a, forward=True, real=False, s=s, axes=axes, norm=norm, out=out + ) def fftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None): @@ -357,25 +348,25 @@ def fftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None): usm_type=usm_type, sycl_queue=sycl_queue, ) - N = (n - 1) // 2 + 1 + m = (n - 1) // 2 + 1 p1 = dpnp.arange( 0, - N, + m, dtype=dpnp.intp, device=device, usm_type=usm_type, sycl_queue=sycl_queue, ) - results[:N] = p1 + results[:m] = p1 p2 = dpnp.arange( - -(n // 2), + m - n, 0, dtype=dpnp.intp, device=device, usm_type=usm_type, sycl_queue=sycl_queue, ) - results[N:] = p2 + results[m:] = p2 return results * val @@ -419,9 +410,8 @@ def fftn(a, s=None, axes=None, norm=None, out=None): the default option ``"backward"``. Default: ``"backward"``. out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional - If provided, the result will be placed in this array. It should be - of the appropriate shape and dtype (and hence is incompatible with - passing in all but the trivial `s`). + If provided, the result will be placed in this array. It should be of + the appropriate shape (consistent with the choice of `s`) and dtype. Default: ``None``. Returns @@ -476,7 +466,9 @@ def fftn(a, s=None, axes=None, norm=None, out=None): """ dpnp.check_supported_arrays_type(a) - return dpnp_fftn(a, forward=True, s=s, axes=axes, norm=norm, out=out) + return dpnp_fftn( + a, forward=True, real=False, s=s, axes=axes, norm=norm, out=out + ) def fftshift(x, axes=None): @@ -679,8 +671,8 @@ def ifft(a, n=None, axis=-1, norm=None, out=None): the default option ``"backward"``. Default: ``"backward"``. out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional - If provided, the result will be placed in this array. It should be - of the appropriate shape and dtype. + If provided, the result will be placed in this array. It should be of + the appropriate shape (consistent with the choice of `n`) and dtype. Default: ``None``. Returns @@ -769,9 +761,8 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None, out=None): the default option ``"backward"``. Default: ``"backward"``. out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional - If provided, the result will be placed in this array. It should be - of the appropriate shape and dtype (and hence is incompatible with - passing in all but the trivial `s`). + If provided, the result will be placed in this array. It should be of + the appropriate shape (consistent with the choice of `s`) and dtype. Default: ``None``. Returns @@ -815,7 +806,9 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None, out=None): """ dpnp.check_supported_arrays_type(a) - return dpnp_fftn(a, forward=False, s=s, axes=axes, norm=norm, out=out) + return dpnp_fftn( + a, forward=False, real=False, s=s, axes=axes, norm=norm, out=out + ) def ifftn(a, s=None, axes=None, norm=None, out=None): @@ -867,9 +860,8 @@ def ifftn(a, s=None, axes=None, norm=None, out=None): the default option ``"backward"``. Default: ``"backward"``. out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional - If provided, the result will be placed in this array. It should be - of the appropriate shape and dtype (and hence is incompatible with - passing in all but the trivial `s`). + If provided, the result will be placed in this array. It should be of + the appropriate shape (consistent with the choice of `s`) and dtype. Default: ``None``. Returns @@ -912,7 +904,9 @@ def ifftn(a, s=None, axes=None, norm=None, out=None): """ dpnp.check_supported_arrays_type(a) - return dpnp_fftn(a, forward=False, s=s, axes=axes, norm=norm, out=out) + return dpnp_fftn( + a, forward=False, real=False, s=s, axes=axes, norm=norm, out=out + ) def ifftshift(x, axes=None): @@ -1147,95 +1141,209 @@ def irfft(a, n=None, axis=-1, norm=None, out=None): ) -def irfft2(x, s=None, axes=(-2, -1), norm=None): +def irfft2(a, s=None, axes=(-2, -1), norm=None, out=None): """ - Compute the 2-dimensional inverse discrete Fourier Transform for real input. - - Multi-dimensional arrays computed as batch of 1-D arrays. + Computes the inverse of :obj:`dpnp.fft.rfft2`. For full documentation refer to :obj:`numpy.fft.irfft2`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `norm` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array, can be complex. + s : {None, sequence of ints}, optional + Shape (length of each transformed axis) of the output + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + If it is ``-1``, the whole input is used (no padding/trimming). + If `s` is not given, the shape of the input along the axes + specified by `axes` is used. Except for the last axis which is taken to + be ``2*(m-1)`` where `m` is the length of the input along that axis. + If `s` is not ``None``, `axes` must not be ``None`` + Default: ``None``. + axes : {None, sequence of ints}, optional + Axes over which to compute the inverse FFT. If not given, the last + ``len(s)`` axes are used, or all axes if `s` is also not specified. + Repeated indices in `axes` means that the transform over that axis is + performed multiple times. If `s` is specified, the corresponding `axes` + to be transformed must be explicitly specified too. A one-element + sequence means that a one-dimensional FFT is performed. An empty + sequence means that no FFT is performed. + Default: ``(-2, -1)``. + norm : {None, "backward", "ortho", "forward"}, optional + Normalization mode (see :obj:`dpnp.fft`). + Indicates which direction of the forward/backward pair of transforms + is scaled and with what normalization factor. ``None`` is an alias of + the default option ``"backward"``. + Default: ``"backward"``. + out : {None, dpnp.ndarray or usm_ndarray}, optional + If provided, the result will be placed in this array. It should be of + the appropriate dtype and shape for the last transformation + (consistent with the choice of `s`). + Default: ``None``. - """ + Returns + ------- + out : dpnp.ndarray + The truncated or zero-padded input, transformed along the axes + indicated by `axes`, or the last two axes if `axes` is not given. - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - if x_desc: - if norm is not None: - pass - else: - return irfftn(x_desc.get_pyobj(), s, axes, norm) + See Also + -------- + :obj:`dpnp.fft` : Overall view of discrete Fourier transforms, with + definitions and conventions used. + :obj:`dpnp.fft.rfft2` : The forward two-dimensional FFT of real input, + of which :obj:`dpnp.fft.irfft2` is the inverse. + :obj:`dpnp.fft.rfft` : The one-dimensional FFT for real input. + :obj:`dpnp.fft.irfft` : The inverse of the one-dimensional FFT of + real input. + :obj:`dpnp.fft.irfftn` : The inverse of the *N*-dimensional FFT of + real input. - return call_origin(numpy.fft.irfft2, x, s, axes, norm) + Notes + ----- + :obj:`dpnp.fft.irfft2` is just :obj:`dpnp.fft.irfftn` with a different + default for `axes`. For more details see :obj:`dpnp.fft.irfftn`. + Examples + -------- + >>> import dpnp as np + >>> a = np.mgrid[:5, :5][0] + >>> A = np.fft.rfft2(a) + >>> np.fft.irfft2(A, s=a.shape) + array([[0., 0., 0., 0., 0.], + [1., 1., 1., 1., 1.], + [2., 2., 2., 2., 2.], + [3., 3., 3., 3., 3.], + [4., 4., 4., 4., 4.]]) -def irfftn(x, s=None, axes=None, norm=None): """ - Compute the N-dimensional inverse discrete Fourier Transform for real input. - Multi-dimensional arrays computed as batch of 1-D arrays. + dpnp.check_supported_arrays_type(a) + return dpnp_fftn( + a, forward=False, real=True, s=s, axes=axes, norm=norm, out=out + ) + + +def irfftn(a, s=None, axes=None, norm=None, out=None): + """ + Computes the inverse of :obj:`dpnp.fft.rfftn`. + + This function computes the inverse of the *N*-dimensional discrete Fourier + Transform for real input over any number of axes in an *M*-dimensional + array by means of the Fast Fourier Transform (FFT). In other words, + ``irfftn(rfftn(a), a.shape) == a`` to within numerical accuracy. (The + ``a.shape`` is necessary like ``len(a)`` is for :obj:`dpnp.fft.irfft`, + and for the same reason.) + + The input should be ordered in the same way as is returned by + :obj:`dpnp.fft.rfftn`, i.e. as for :obj:`dpnp.fft.irfft` for the final + transformation axis, and as for :obj:`dpnp.fft.irfftn` along all the other + axes. For full documentation refer to :obj:`numpy.fft.irfftn`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `norm` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array, can be complex. + s : {None, sequence of ints}, optional + Shape (length of each transformed axis) of the output + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). `s` is also the + number of input points used along this axis, except for the last axis, + where ``s[-1]//2+1`` points of the input are used. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + If it is ``-1``, the whole input is used (no padding/trimming). + If `s` is not given, the shape of the input along the axes + specified by axes is used. Except for the last axis which is taken to + be ``2*(m-1)`` where `m` is the length of the input along that axis. + If `s` is not ``None``, `axes` must not be ``None`` + Default: ``None``. + axes : {None, sequence of ints}, optional + Axes over which to compute the inverse FFT. If not given, the last + ``len(s)`` axes are used, or all axes if `s` is also not specified. + Repeated indices in `axes` means that the transform over that axis is + performed multiple times. If `s` is specified, the corresponding `axes` + to be transformed must be explicitly specified too. A one-element + sequence means that a one-dimensional FFT is performed. An empty + sequence means that no FFT is performed. + Default: ``None``. + norm : {None, "backward", "ortho", "forward"}, optional + Normalization mode (see :obj:`dpnp.fft`). + Indicates which direction of the forward/backward pair of transforms + is scaled and with what normalization factor. ``None`` is an alias of + the default option ``"backward"``. + Default: ``"backward"``. + out : {None, dpnp.ndarray or usm_ndarray}, optional + If provided, the result will be placed in this array. It should be of + the appropriate dtype and shape for the last transformation + (consistent with the choice of `s`). + Default: ``None``. + + Returns + ------- + out : dpnp.ndarray + The truncated or zero-padded input, transformed along the axes + indicated by `axes`, or by a combination of `s` and `a`, + as explained in the parameters section above. + The length of each transformed axis is as given by the corresponding + element of `s`, or the length of the input in every axis except for the + last one if `s` is not given. In the final transformed axis the length + of the output when `s` is not given is ``2*(m-1)`` where `m` is the + length of the final transformed axis of the input. To get an odd + number of output points in the final axis, `s` must be specified. + + See Also + -------- + :obj:`dpnp.fft` : Overall view of discrete Fourier transforms, with + definitions and conventions used. + :obj:`dpnp.fft.rfftn` : The `n`-dimensional FFT of real input. + :obj:`dpnp.fft.fft` : The one-dimensional FFT, with definitions and + conventions used. + :obj:`dpnp.fft.irfft` : The inverse of the one-dimensional FFT of + real input. + :obj:`dpnp.fft.irfft2` : The inverse of the two-dimensional FFT of + real input. + + Notes + ----- + See :obj:`dpnp.fft` for details, definitions and conventions used. + + See :obj:`dpnp.fft.rfft` for definitions and conventions used for real + input. + + The correct interpretation of the Hermitian input depends on the shape of + the original data, as given by `s`. This is because each input shape could + correspond to either an odd or even length signal. By default, + :obj:`dpnp.fft.irfftn` assumes an even output length which puts the last + entry at the Nyquist frequency; aliasing with its symmetric counterpart. + When performing the final complex to real transform, the last value is thus + treated as purely real. To avoid losing information, the correct shape of + the real input **must** be given. + + Examples + -------- + >>> import dpnp as np + >>> a = np.zeros((3, 2, 2)) + >>> a[0, 0, 0] = 3 * 2 * 2 + >>> np.fft.irfftn(a) + array([[[1., 1.], + [1., 1.]], + [[1., 1.], + [1., 1.]], + [[1., 1.], + [1., 1.]]]) """ - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - # TODO: enable implementation - # pylint: disable=condition-evals-to-constant - if x_desc and 0: - if s is None: - boundaries = tuple(x_desc.shape[i] for i in range(x_desc.ndim)) - else: - boundaries = s - - if axes is None: - axes_param = list(range(x_desc.ndim)) - else: - axes_param = axes - - if norm is not None: - pass - else: - x_iter = x - iteration_list = list(range(len(axes_param))) - iteration_list.reverse() # inplace operation - for it in iteration_list: - param_axis = axes_param[it] - try: - param_n = boundaries[param_axis] - except IndexError: - checker_throw_axis_error( - "fft.irfftn", - "is out of bounds", - param_axis, - f"< {len(boundaries)}", - ) - - x_iter_desc = dpnp.get_dpnp_descriptor(x_iter) - x_iter = irfft( - x_iter_desc.get_pyobj(), - n=param_n, - axis=param_axis, - norm=norm, - ) - - return x_iter - - return call_origin(numpy.fft.irfftn, x, s, axes, norm) + dpnp.check_supported_arrays_type(a) + return dpnp_fftn( + a, forward=False, real=True, s=s, axes=axes, norm=norm, out=out + ) def rfft(a, n=None, axis=-1, norm=None, out=None): @@ -1251,7 +1359,7 @@ def rfft(a, n=None, axis=-1, norm=None, out=None): Parameters ---------- a : {dpnp.ndarray, usm_ndarray} - Input array. + Input array, taken to be real. n : {None, int}, optional Number of points along transformation axis in the input to use. If `n` is smaller than the length of the input, the input is cropped. @@ -1329,32 +1437,83 @@ def rfft(a, n=None, axis=-1, norm=None, out=None): ) -def rfft2(x, s=None, axes=(-2, -1), norm=None): +def rfft2(a, s=None, axes=(-2, -1), norm=None, out=None): """ - Compute the 2-dimensional discrete Fourier Transform for real input. - - Multi-dimensional arrays computed as batch of 1-D arrays. + Compute the 2-dimensional FFT of a real array. For full documentation refer to :obj:`numpy.fft.rfft2`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `norm` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array, taken to be real. + s : {None, sequence of ints}, optional + Shape (length of each transformed axis) to use from the input. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). + The final element of `s` corresponds to `n` for ``rfft(x, n)``, while + for the remaining axes, it corresponds to `n` for ``fft(x, n)``. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + If it is ``-1``, the whole input is used (no padding/trimming). + If `s` is not given, the shape of the input along the axes specified + by `axes` is used. If `s` is not ``None``, `axes` must not be ``None`` + either. Default: ``None``. + axes : {None, sequence of ints}, optional + Axes over which to compute the FFT. If not given, the last two axes are + used. A repeated index in `axes` means the transform over that axis is + performed multiple times. If `s` is specified, the corresponding `axes` + to be transformed must be explicitly specified too. A one-element + sequence means that a one-dimensional FFT is performed. An empty + sequence means that no FFT is performed. + Default: ``(-2, -1)``. + norm : {None, "backward", "ortho", "forward"}, optional + Normalization mode (see :obj:`dpnp.fft`). + Indicates which direction of the forward/backward pair of transforms + is scaled and with what normalization factor. ``None`` is an alias of + the default option ``"backward"``. + Default: ``"backward"``. + out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional + If provided, the result will be placed in this array. It should be of + the appropriate dtype and shape for the last transformation + (consistent with the choice of `s`). + Default: ``None``. + + Returns + ------- + out : dpnp.ndarray of complex dtype + The truncated or zero-padded input, transformed along the axes + indicated by `axes`, or the last two axes if `axes` is not given. - """ + See Also + -------- + :obj:`dpnp.fft` : Overall view of discrete Fourier transforms, with + definitions and conventions used. + :obj:`dpnp.fft.rfft` : The one-dimensional FFT of real input. + :obj:`dpnp.fft.rfftn` : The `n`-dimensional FFT of real input. + :obj:`dpnp.fft.irfft2` : The inverse two-dimensional real FFT. + + Notes + ----- + This is just :obj:`dpnp.fft.rfftn` with different default behavior. + For more details see :obj:`dpnp.fft.rfftn`. - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - if x_desc: - if norm is not None: - pass - else: - return rfftn(x_desc.get_pyobj(), s, axes, norm) + Examples + -------- + >>> import dpnp as np + >>> a = np.mgrid[:5, :5][0] + >>> np.fft.rfft2(a) + array([[ 50. +0.j , 0. +0.j , 0. +0.j ], + [-12.5+17.20477401j, 0. +0.j , 0. +0.j ], + [-12.5 +4.0614962j , 0. +0.j , 0. +0.j ], + [-12.5 -4.0614962j , 0. +0.j , 0. +0.j ], + [-12.5-17.20477401j, 0. +0.j , 0. +0.j ]]) - return call_origin(numpy.fft.rfft2, x, s, axes, norm) + """ + + dpnp.check_supported_arrays_type(a) + return dpnp_fftn( + a, forward=True, real=True, s=s, axes=axes, norm=norm, out=out + ) def rfftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None): @@ -1448,10 +1607,10 @@ def rfftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None): if not dpnp.isscalar(d): raise ValueError("`d` should be an scalar") val = 1.0 / (n * d) - N = n // 2 + 1 + m = n // 2 + 1 results = dpnp.arange( 0, - N, + m, dtype=dpnp.intp, device=device, usm_type=usm_type, @@ -1460,68 +1619,104 @@ def rfftfreq(n, d=1.0, device=None, usm_type=None, sycl_queue=None): return results * val -def rfftn(x, s=None, axes=None, norm=None): +def rfftn(a, s=None, axes=None, norm=None, out=None): """ - Compute the N-dimensional discrete Fourier Transform for real input. + Compute the *N*-dimensional discrete Fourier Transform for real input. - Multi-dimensional arrays computed as batch of 1-D arrays. + This function computes the *N*-dimensional discrete Fourier Transform over + any number of axes in an *M*-dimensional real array by means of the Fast + Fourier Transform (FFT). By default, all axes are transformed, with the + real transform performed over the last axis, while the remaining + transforms are complex. For full documentation refer to :obj:`numpy.fft.rfftn`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `norm` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array, taken to be real. + s : {None, sequence of ints}, optional + Shape (length of each transformed axis) to use from the input. + (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.). + The final element of `s` corresponds to `n` for ``rfft(x, n)``, while + for the remaining axes, it corresponds to `n` for ``fft(x, n)``. + Along each axis, if the given shape is smaller than that of the input, + the input is cropped. If it is larger, the input is padded with zeros. + If it is ``-1``, the whole input is used (no padding/trimming). + If `s` is not given, the shape of the input along the axes specified + by `axes` is used. If `s` is not ``None``, `axes` must not be ``None`` + either. Default: ``None``. + axes : {None, sequence of ints}, optional + Axes over which to compute the FFT. If not given, the last ``len(s)`` + axes are used, or all axes if `s` is also not specified. + Repeated indices in `axes` means that the transform over that axis is + performed multiple times. If `s` is specified, the corresponding `axes` + to be transformed must be explicitly specified too. A one-element + sequence means that a one-dimensional FFT is performed. An empty + sequence means that no FFT is performed. + Default: ``None``. + norm : {None, "backward", "ortho", "forward"}, optional + Normalization mode (see :obj:`dpnp.fft`). + Indicates which direction of the forward/backward pair of transforms + is scaled and with what normalization factor. ``None`` is an alias of + the default option ``"backward"``. + Default: ``"backward"``. + out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional + If provided, the result will be placed in this array. It should be of + the appropriate dtype and shape for the last transformation + (consistent with the choice of `s`). + Default: ``None``. + + Returns + ------- + out : dpnp.ndarray of complex dtype + The truncated or zero-padded input, transformed along the axes + indicated by `axes`, or by a combination of `s` and `a`, + as explained in the parameters section above. + The length of the last axis transformed will be ``s[-1]//2+1``, + while the remaining transformed axes will have lengths according to + `s`, or unchanged from the input. + + See Also + -------- + :obj:`dpnp.fft` : Overall view of discrete Fourier transforms, with + definitions and conventions used. + :obj:`dpnp.fft.irfftn` : The inverse of the *N*-dimensional FFT of + real input. + :obj:`dpnp.fft.fft` : The one-dimensional FFT of general (complex) input. + :obj:`dpnp.fft.rfft` : The one-dimensional FFT of real input. + :obj:`dpnp.fft.fftn` : The *N*-dimensional FFT. + :obj:`dpnp.fft.fftn` : The two-dimensional FFT. + + Notes + ----- + The transform for real input is performed over the last transformation + axis, as by :obj:`dpnp.fft.rfft`, then the transform over the remaining + axes is performed as by :obj:`dpnp.fft.fftn`. The order of the output + is as for :obj:`dpnp.fft.rfft` for the final transformation axis, and + as for :obj:`dpnp.fft.fftn` for the remaining transformation axes. + + See :obj:`dpnp.fft` for details, definitions and conventions used. + + Examples + -------- + >>> import dpnp as np + >>> a = np.ones((2, 2, 2)) + >>> np.fft.rfftn(a) + array([[[8.+0.j, 0.+0.j], # may vary + [0.+0.j, 0.+0.j]], + [[0.+0.j, 0.+0.j], + [0.+0.j, 0.+0.j]]]) + + >>> np.fft.rfftn(a, axes=(2, 0)) + array([[[4.+0.j, 0.+0.j], # may vary + [4.+0.j, 0.+0.j]], + [[0.+0.j, 0.+0.j], + [0.+0.j, 0.+0.j]]]) """ - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - # TODO: enable implementation - # pylint: disable=condition-evals-to-constant - if x_desc and 0: - if s is None: - boundaries = tuple(x_desc.shape[i] for i in range(x_desc.ndim)) - else: - boundaries = s - - if axes is None: - axes_param = list(range(x_desc.ndim)) - else: - axes_param = axes - - if norm is not None: - pass - elif len(axes) < 1: - pass # let fallback to handle exception - else: - x_iter = x - iteration_list = list(range(len(axes_param))) - iteration_list.reverse() # inplace operation - for it in iteration_list: - param_axis = axes_param[it] - try: - param_n = boundaries[param_axis] - except IndexError: - checker_throw_axis_error( - "fft.rfftn", - "is out of bounds", - param_axis, - f"< {len(boundaries)}", - ) - - x_iter_desc = dpnp.get_dpnp_descriptor( - x_iter, copy_when_nondefault_queue=False - ) - x_iter = rfft( - x_iter_desc.get_pyobj(), - n=param_n, - axis=param_axis, - norm=norm, - ) - - return x_iter - - return call_origin(numpy.fft.rfftn, x, s, axes, norm) + dpnp.check_supported_arrays_type(a) + return dpnp_fftn( + a, forward=True, real=True, s=s, axes=axes, norm=norm, out=out + ) diff --git a/dpnp/fft/dpnp_utils_fft.py b/dpnp/fft/dpnp_utils_fft.py index d1e015068c01..7aba49bf472c 100644 --- a/dpnp/fft/dpnp_utils_fft.py +++ b/dpnp/fft/dpnp_utils_fft.py @@ -71,7 +71,7 @@ def _check_norm(norm): ) -def _commit_descriptor(a, in_place, c2c, a_strides, index, axes): +def _commit_descriptor(a, forward, in_place, c2c, a_strides, index, batch_fft): """Commit the FFT descriptor for the input array.""" a_shape = a.shape @@ -91,16 +91,76 @@ def _commit_descriptor(a, in_place, c2c, a_strides, index, axes): dsc.fwd_strides = strides dsc.bwd_strides = dsc.fwd_strides dsc.transform_in_place = in_place - if axes is not None: # batch_fft + out_strides = dsc.bwd_strides[1:] + if batch_fft: dsc.fwd_distance = a_strides[0] - dsc.bwd_distance = dsc.fwd_distance - dsc.number_of_transforms = numpy.prod(a_shape[0]) + if c2c: + dsc.bwd_distance = dsc.fwd_distance + elif dsc.fwd_strides[-1] == 1: + if forward: + dsc.bwd_distance = shape[-1] // 2 + 1 + else: + dsc.bwd_distance = dsc.fwd_distance + else: + dsc.bwd_distance = dsc.fwd_distance + dsc.number_of_transforms = a_shape[0] # batch_size + out_strides.insert(0, dsc.bwd_distance) + dsc.commit(a.sycl_queue) - return dsc + return dsc, out_strides + + +def _complex_nd_fft(a, s, norm, out, forward, in_place, c2c, axes, batch_fft): + """Computes complex-to-complex FFT of the input N-D array.""" + + len_axes = len(axes) + # OneMKL supports up to 3-dimensional FFT on GPU + # repeated axis in OneMKL FFT is not allowed + if len_axes > 3 or len(set(axes)) < len_axes: + axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3) + for i, (s_chunk, a_chunk) in enumerate(zip(shape_chunk, axes_chunk)): + a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk) + # if out is used in an intermediate step, it will have memory + # overlap with input and cannot be used in the final step (a new + # result array will be created for the final step), so there is no + # benefit in using out in an intermediate step + if i == len(axes_chunk) - 1: + tmp_out = out + else: + tmp_out = None + a = _fft( + a, + norm=norm, + out=tmp_out, + forward=forward, + # TODO: in-place FFT is only implemented for c2c, see SAT-7154 + in_place=in_place and c2c, + c2c=c2c, + axes=a_chunk, + ) -def _compute_result(dsc, a, out, forward, c2c, a_strides): + return a + + a = _truncate_or_pad(a, s, axes) + if a.size == 0: + return dpnp.get_result_array(a, out=out, casting="same_kind") + + return _fft( + a, + norm=norm, + out=out, + forward=forward, + # TODO: in-place FFT is only implemented for c2c, see SAT-7154 + in_place=in_place and c2c, + c2c=c2c, + axes=axes, + batch_fft=batch_fft, + ) + + +def _compute_result(dsc, a, out, forward, c2c, out_strides): """Compute the result of the FFT.""" exec_q = a.sycl_queue @@ -119,7 +179,7 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides): else: if ( out is not None - and out.strides == a_strides + and out.strides == tuple(out_strides) and not ti._array_overlap(a_usm, dpnp.get_usm_ndarray(out)) ): res_usm = dpnp.get_usm_ndarray(out) @@ -150,7 +210,7 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides): result = dpnp_array( out_shape, dtype=out_dtype, - strides=a_strides, + strides=out_strides, usm_type=a.usm_type, sycl_queue=exec_q, ) @@ -165,7 +225,6 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides): return result -# TODO: c2r keyword is place holder for irfftn def _cook_nd_args(a, s=None, axes=None, c2r=False): if s is None: shapeless = True @@ -315,12 +374,13 @@ def _extract_axes_chunk(a, s, chunk_size=3): return a_chunks[::-1], s_chunks[::-1] -def _fft(a, norm, out, forward, in_place, c2c, axes=None): +def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True): """Calculates FFT of the input array along the specified axes.""" index = 0 - if axes is not None: # batch_fft - len_axes = 1 if isinstance(axes, int) else len(axes) + fft_1d = isinstance(axes, int) + if batch_fft: + len_axes = 1 if fft_1d else len(axes) local_axes = numpy.arange(-len_axes, 0) a = dpnp.moveaxis(a, axes, local_axes) a_shape_orig = a.shape @@ -329,11 +389,13 @@ def _fft(a, norm, out, forward, in_place, c2c, axes=None): index = 1 a_strides = _standardize_strides_to_nonzero(a.strides, a.shape) - dsc = _commit_descriptor(a, in_place, c2c, a_strides, index, axes) - res = _compute_result(dsc, a, out, forward, c2c, a_strides) + dsc, out_strides = _commit_descriptor( + a, forward, in_place, c2c, a_strides, index, batch_fft + ) + res = _compute_result(dsc, a, out, forward, c2c, out_strides) res = _scale_result(res, a.shape, norm, forward, index) - if axes is not None: # batch_fft + if batch_fft: tmp_shape = a_shape_orig[:-1] + (res.shape[-1],) res = dpnp.reshape(res, tmp_shape) res = dpnp.moveaxis(res, local_axes, axes) @@ -369,9 +431,6 @@ def _scale_result(res, a_shape, norm, forward, index): def _truncate_or_pad(a, shape, axes): """Truncating or zero-padding the input array along the specified axes.""" - shape = (shape,) if isinstance(shape, int) else shape - axes = (axes,) if isinstance(axes, int) else axes - for s, axis in zip(shape, axes): a_shape = list(a.shape) index = [slice(None)] * a.ndim @@ -408,7 +467,7 @@ def _truncate_or_pad(a, shape, axes): return a -def _validate_out_keyword(a, out, s, axes, c2r, r2c): +def _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c): """Validate out keyword argument.""" if out is not None: dpnp.check_supported_arrays_type(out) @@ -423,10 +482,14 @@ def _validate_out_keyword(a, out, s, axes, c2r, r2c): # validate out shape against the final shape, # intermediate shapes may vary expected_shape = list(a.shape) - for s_i, axis in zip(s[::-1], axes[::-1]): - expected_shape[axis] = s_i if r2c: - expected_shape[axes[-1]] = expected_shape[axes[-1]] // 2 + 1 + expected_shape[axes[-1]] = s[-1] // 2 + 1 + elif c2c: + expected_shape[axes[-1]] = s[-1] + for s_i, axis in zip(s[-2::-1], axes[-2::-1]): + expected_shape[axis] = s_i + if c2r: + expected_shape[axes[-1]] = s[-1] if out.shape != tuple(expected_shape): raise ValueError( @@ -494,8 +557,8 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None): raise ValueError(f"Invalid number of FFT data points ({n}) specified") _check_norm(norm) - a = _truncate_or_pad(a, n, axis) - _validate_out_keyword(a, out, (n,), (axis,), c2r, r2c) + a = _truncate_or_pad(a, (n,), (axis,)) + _validate_out_keyword(a, out, (n,), (axis,), c2c, c2r, r2c) # if input array is copied, in-place FFT can be used a, in_place = _copy_array(a, c2c or c2r) if not in_place and out is not None: @@ -505,9 +568,6 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None): if a.size == 0: return dpnp.get_result_array(a, out=out, casting="same_kind") - # non-batch FFT - axis = None if a_ndim == 1 else axis - return _fft( a, norm=norm, @@ -517,16 +577,20 @@ def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None): in_place=in_place and c2c, c2c=c2c, axes=axis, + batch_fft=a_ndim != 1, ) -def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None): +def dpnp_fftn(a, forward, real, s=None, axes=None, norm=None, out=None): """Calculates N-D FFT of the input array along axes""" - _check_norm(norm) - if isinstance(axes, (list, tuple)) and len(axes) == 0: + if isinstance(axes, Sequence) and len(axes) == 0: + if real: + raise IndexError("Empty axes.") + return a + _check_norm(norm) if a.ndim == 0: if axes is not None: raise IndexError( @@ -535,54 +599,75 @@ def dpnp_fftn(a, forward, s=None, axes=None, norm=None, out=None): return a + c2c = not real # complex-to-complex FFT + r2c = real and forward # real-to-complex FFT + c2r = real and not forward # complex-to-real FFT + if r2c and dpnp.issubdtype(a.dtype, dpnp.complexfloating): + raise TypeError("Input array must be real") + _validate_s_axes(a, s, axes) - s, axes = _cook_nd_args(a, s, axes) - # TODO: False and False are place holder for future development of - # rfft2, irfft2, rfftn, irfftn - _validate_out_keyword(a, out, s, axes, False, False) - # TODO: True is place holder for future development of - # rfft2, irfft2, rfftn, irfftn - a, in_place = _copy_array(a, True) + s, axes = _cook_nd_args(a, s, axes, c2r) + _validate_out_keyword(a, out, s, axes, c2c, c2r, r2c) + a, in_place = _copy_array(a, c2c or c2r) len_axes = len(axes) - # OneMKL supports up to 3-dimensional FFT on GPU - # repeated axis in OneMKL FFT is not allowed - if len_axes > 3 or len(set(axes)) < len_axes: - axes_chunk, shape_chunk = _extract_axes_chunk(axes, s, chunk_size=3) - for s_chunk, a_chunk in zip(shape_chunk, axes_chunk): - a = _truncate_or_pad(a, shape=s_chunk, axes=a_chunk) - if out is not None and out.shape == a.shape: - tmp_out = out - else: - tmp_out = None - a = _fft( - a, - norm=norm, - out=tmp_out, - forward=forward, - in_place=in_place, - # TODO: c2c=True is place holder for future development of - # rfft2, irfft2, rfftn, irfftn - c2c=True, - axes=a_chunk, - ) - return a + if len_axes == 1: + a = _truncate_or_pad(a, (s[-1],), (axes[-1],)) + return _fft( + a, norm, out, forward, in_place and c2c, c2c, axes[0], a.ndim != 1 + ) - a = _truncate_or_pad(a, s, axes) - if a.size == 0: - return dpnp.get_result_array(a, out=out, casting="same_kind") - if a.ndim == len_axes: - # non-batch FFT - axes = None + if r2c: + # a 1D real-to-complext FFT is performed on the last axis and then + # an N-D complex-to-complex FFT over the remaining axes + a = _truncate_or_pad(a, (s[-1],), (axes[-1],)) + a = _fft( + a, + norm=norm, + # if out is used in an intermediate step, it will have memory + # overlap with input and cannot be used in the final step (a new + # result array will be created for the final step), so there is no + # benefit in using out in an intermediate step + out=None, + forward=forward, + in_place=in_place and c2c, + c2c=c2c, + axes=axes[-1], + batch_fft=a.ndim != 1, + ) + return _complex_nd_fft( + a, + s=s, + norm=norm, + out=out, + forward=forward, + in_place=in_place, + c2c=True, + axes=axes[:-1], + batch_fft=a.ndim != len_axes - 1, + ) - return _fft( - a, - norm=norm, - out=out, - forward=forward, - in_place=in_place, - # TODO: c2c=True is place holder for future development of - # rfft2, irfft2, rfftn, irfftn - c2c=True, - axes=axes, + if c2r: + # an N-D complex-to-complex FFT is performed on all axes except the + # last one then a 1D complex-to-real FFT is performed on the last axis + a = _complex_nd_fft( + a, + s=s, + norm=norm, + # out has real dtype and cannot be used in intermediate steps + out=None, + forward=forward, + in_place=in_place, + c2c=True, + axes=axes[:-1], + batch_fft=a.ndim != len_axes - 1, + ) + a = _truncate_or_pad(a, (s[-1],), (axes[-1],)) + return _fft( + a, norm, out, forward, in_place and c2c, c2c, axes[-1], a.ndim != 1 + ) + + # c2c + return _complex_nd_fft( + a, s, norm, out, forward, in_place, c2c, axes, a.ndim != len_axes ) diff --git a/tests/test_fft.py b/tests/test_fft.py index 3ded3b091026..08e34efcc05a 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -12,6 +12,7 @@ assert_dtype_allclose, get_all_dtypes, get_complex_dtypes, + get_float_complex_dtypes, get_float_dtypes, ) @@ -42,7 +43,9 @@ class TestFft: def setup_method(self): numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_none=True) + ) @pytest.mark.parametrize( "shape", [(64,), (8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)] ) @@ -59,7 +62,9 @@ def test_fft_ndim(self, dtype, shape, norm): dpnp_res = dpnp.fft.ifft(dpnp_data, norm=norm) assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_complex=True) + ) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) def test_fft_1D(self, dtype, n, norm): @@ -361,7 +366,9 @@ class TestFft2: def setup_method(self): numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_complex=True) + ) def test_fft2(self, dtype): x1 = numpy.random.uniform(-10, 10, 24) a_np = numpy.array(x1, dtype=dtype).reshape(2, 3, 4) @@ -521,16 +528,18 @@ def test_fftn_out(self, axes, s): out_shape = list(a.shape) for s_i, axis in zip(s[::-1], axes[::-1]): out_shape[axis] = s_i - result = dpnp.empty(out_shape, dtype=a.dtype) - dpnp.fft.fftn(a, out=result, s=s, axes=axes) + out = dpnp.empty(out_shape, dtype=a.dtype) + result = dpnp.fft.fftn(a, out=out, s=s, axes=axes) + assert out is result # Intel® NumPy ignores repeated axes, handle it one by one expected = a_np for jj, ii in zip(s[::-1], axes[::-1]): expected = numpy.fft.fft(expected, n=jj, axis=ii) assert_dtype_allclose(result, expected, check_only_type_kind=True) - iresult = dpnp.empty(out_shape, dtype=a.dtype) - dpnp.fft.ifftn(result, out=iresult, s=s, axes=axes) + out = dpnp.empty(out_shape, dtype=a.dtype) + iresult = dpnp.fft.ifftn(result, out=out, s=s, axes=axes) + assert out is iresult iexpected = expected for jj, ii in zip(s[::-1], axes[::-1]): iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii) @@ -557,7 +566,7 @@ def test_fftn_empty_array(self): expected = numpy.fft.fftn(a_np, axes=(0, 1, 2), s=(5, 2, 4)) assert_dtype_allclose(result, expected, check_only_type_kind=True) - @pytest.mark.parametrize("dtype", get_all_dtypes()) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_fftn_0D(self, dtype): a = dpnp.array(3, dtype=dtype) # 0-D input @@ -580,7 +589,7 @@ def test_fftn_0D(self, dtype): # IndexError, while Intel® NumPy raises ZeroDivisionError assert_raises(IndexError, dpnp.fft.fftn, a, axes=(0,)) - @pytest.mark.parametrize("dtype", get_all_dtypes()) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) def test_fftn_empty_axes(self, dtype): a = dpnp.ones((2, 3, 4), dtype=dtype) @@ -633,7 +642,9 @@ class TestHfft: def setup_method(self): numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_complex=True) + ) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) def test_hfft_1D(self, dtype, n, norm): @@ -662,7 +673,7 @@ def test_hfft_1D_complex(self, dtype, n, norm): assert_dtype_allclose(result, expected, check_only_type_kind=True) @pytest.mark.parametrize( - "dtype", get_all_dtypes(no_bool=True, no_complex=True) + "dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) ) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) @@ -690,7 +701,9 @@ class TestIrfft: def setup_method(self): numpy.random.seed(42) - @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_complex=True) + ) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) def test_fft_1D(self, dtype, n, norm): @@ -834,7 +847,7 @@ def setup_method(self): numpy.random.seed(42) @pytest.mark.parametrize( - "dtype", get_all_dtypes(no_bool=True, no_complex=True) + "dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) ) @pytest.mark.parametrize( "shape", [(64,), (8, 8), (4, 16), (4, 4, 4), (2, 4, 4, 2)] @@ -849,7 +862,7 @@ def test_fft_rfft(self, dtype, shape): assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) @pytest.mark.parametrize( - "dtype", get_all_dtypes(no_bool=True, no_complex=True) + "dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True) ) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) @@ -964,3 +977,192 @@ def test_fft_validate_out(self): a = dpnp.ones((10,), dtype=dpnp.float32) out = dpnp.empty((10,), dtype=dpnp.complex64) assert_raises(ValueError, dpnp.fft.rfft, a, out=out) + + +class TestRfft2: + def setup_method(self): + numpy.random.seed(42) + + # TODO: add other axes when mkl_fft gh-119 is addressed + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_none=True, no_complex=True) + ) + @pytest.mark.parametrize("axes", [(0, 1)]) # (1, 2),(0, 2),(2, 1),(2, 0) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_rfft2(self, dtype, axes, norm, order): + x1 = numpy.random.uniform(-10, 10, 24) + a_np = numpy.array(x1, dtype=dtype).reshape(2, 3, 4, order=order) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfft2(a, axes=axes, norm=norm) + expected = numpy.fft.rfft2(a_np, axes=axes, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + s = (a.shape[axes[0]], a.shape[axes[1]]) + result = dpnp.fft.irfft2(result, s=s, axes=axes, norm=norm) + expected = numpy.fft.irfft2(expected, s=s, axes=axes, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_none=True)) + def test_irfft2(self, dtype): + # x1 is Hermitian symmetric + x1 = numpy.array([[0, 1, 2], [5, 4, 6], [5, 7, 6]]) + a_np = numpy.array(x1, dtype=dtype) + a = dpnp.asarray(a_np) + + result = dpnp.fft.irfft2(a) + expected = numpy.fft.irfft2(a_np) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("s", [None, (3, 3), (10, 10), (3, 10)]) + def test_rfft2_s(self, s): + x1 = numpy.random.uniform(-10, 10, 48) + a_np = numpy.array(x1, dtype=numpy.float32).reshape(6, 8) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfft2(a, s=s) + expected = numpy.fft.rfft2(a_np, s=s) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + result = dpnp.fft.irfft2(result, s=s) + expected = numpy.fft.irfft2(expected, s=s) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_fft_error(self, xp): + a = xp.ones((2, 3)) + # empty axes + assert_raises(IndexError, xp.fft.rfft2, a, axes=()) + + a = xp.ones((2, 3), dtype=xp.complex64) + # Input array must be real + # Stock NumPy 2.0 raises TypeError + # while stock NumPy 1.26 ignores imaginary part + if xp == dpnp: + assert_raises(TypeError, xp.fft.rfft2, a) + + +class TestRfftn: + def setup_method(self): + numpy.random.seed(42) + + # TODO: add additional axes when mkl_fft gh-119 is addressed + @pytest.mark.parametrize("dtype", get_float_dtypes()) + @pytest.mark.parametrize( + "axes", [(0, 1, 2), (-2, -4, -1, -3)] # (-1, -4, -2) + ) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_rfftn(self, dtype, axes, norm, order): + x1 = numpy.random.uniform(-10, 10, 120) + a_np = numpy.array(x1, dtype=dtype).reshape(2, 3, 4, 5, order=order) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfftn(a, axes=axes, norm=norm) + expected = numpy.fft.rfftn(a_np, axes=axes, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + s = [] + for axis in axes: + s.append(a.shape[axis]) + iresult = dpnp.fft.irfftn(result, s=s, axes=axes, norm=norm) + iexpected = numpy.fft.irfftn(expected, s=s, axes=axes, norm=norm) + assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True) + + @pytest.mark.parametrize( + "axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)] + ) + def test_rfftn_repeated_axes(self, axes): + x1 = numpy.random.uniform(-10, 10, 120) + a_np = numpy.array(x1, dtype=numpy.float32).reshape(2, 3, 4, 5) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfftn(a, axes=axes) + # Intel® NumPy ignores repeated axes, handle it one by one + expected = numpy.fft.rfft(a_np, axis=axes[-1]) + # need to pass shape for c2c FFT since expected and a_np + # do not have the same shape after calling rfft + shape = [] + for axis in axes: + shape.append(a_np.shape[axis]) + for jj, ii in zip(shape[-2::-1], axes[-2::-1]): + expected = numpy.fft.fft(expected, n=jj, axis=ii) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + iresult = dpnp.fft.irfftn(result, axes=axes) + iexpected = expected + for ii in axes[-2::-1]: + iexpected = numpy.fft.ifft(iexpected, axis=ii) + iexpected = numpy.fft.irfft(iexpected, axis=axes[-1]) + assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True) + + @pytest.mark.parametrize("axes", [(2, 3, 3, 2), (0, 0, 3, 3)]) + @pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 9)]) + def test_rfftn_repeated_axes_with_s(self, axes, s): + x1 = numpy.random.uniform(-10, 10, 120) + a_np = numpy.array(x1, dtype=numpy.float32).reshape(2, 3, 4, 5) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfftn(a, s=s, axes=axes) + # Intel® NumPy ignores repeated axes, handle it one by one + expected = numpy.fft.rfft(a_np, n=s[-1], axis=axes[-1]) + for jj, ii in zip(s[-2::-1], axes[-2::-1]): + expected = numpy.fft.fft(expected, n=jj, axis=ii) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + iresult = dpnp.fft.irfftn(result, s=s, axes=axes) + iexpected = expected + for jj, ii in zip(s[-2::-1], axes[-2::-1]): + iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii) + iexpected = numpy.fft.irfft(iexpected, n=s[-1], axis=axes[-1]) + assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True) + + @pytest.mark.parametrize("axes", [(0, 1, 2, 3), (1, 2, 1, 2), (2, 2, 2, 3)]) + @pytest.mark.parametrize("s", [(2, 3, 4, 5), (5, 6, 7, 9), (2, 5, 1, 2)]) + def test_rfftn_out(self, axes, s): + x1 = numpy.random.uniform(-10, 10, 120) + a_np = numpy.array(x1, dtype=numpy.float32).reshape(2, 3, 4, 5) + a = dpnp.asarray(a_np) + + out_shape = list(a.shape) + out_shape[axes[-1]] = s[-1] // 2 + 1 + for s_i, axis in zip(s[-2::-1], axes[-2::-1]): + out_shape[axis] = s_i + out = dpnp.empty(out_shape, dtype=numpy.complex64) + + result = dpnp.fft.rfftn(a, out=out, s=s, axes=axes) + assert out is result + # Intel® NumPy ignores repeated axes, handle it one by one + expected = numpy.fft.rfft(a_np, n=s[-1], axis=axes[-1]) + for jj, ii in zip(s[-2::-1], axes[-2::-1]): + expected = numpy.fft.fft(expected, n=jj, axis=ii) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + out_shape = list(a.shape) + for s_i, axis in zip(s[-2::-1], axes[-2::-1]): + out_shape[axis] = s_i + out_shape[axes[-1]] = s[-1] + out = dpnp.empty(out_shape, dtype=numpy.float32) + + iresult = dpnp.fft.irfftn(result, out=out, s=s, axes=axes) + assert out is iresult + + iexpected = expected + for jj, ii in zip(s[-2::-1], axes[-2::-1]): + iexpected = numpy.fft.ifft(iexpected, n=jj, axis=ii) + iexpected = numpy.fft.irfft(iexpected, n=s[-1], axis=axes[-1]) + assert_dtype_allclose(iresult, iexpected, check_only_type_kind=True) + + def test_rfftn_1d_array(self): + x1 = numpy.random.uniform(-10, 10, 20) + a_np = numpy.array(x1, dtype=numpy.float32) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfftn(a) + expected = numpy.fft.rfftn(a_np) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + result = dpnp.fft.irfftn(a) + expected = numpy.fft.irfftn(a_np) + assert_dtype_allclose(result, expected, check_only_type_kind=True) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index f520a0f19282..89b50151421f 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1289,22 +1289,53 @@ def test_fft(func, device): assert_sycl_queue_equal(result_queue, expected_queue) -@pytest.mark.parametrize("func", ["fftn", "ifftn"]) @pytest.mark.parametrize( "device", valid_devices, ids=[device.filter_string for device in valid_devices], ) -def test_fftn(func, device): - data = numpy.arange(24, dtype=numpy.complex128).reshape(2, 3, 4) +def test_fftn(device): + data = numpy.arange(24, dtype=numpy.complex64).reshape(2, 3, 4) dpnp_data = dpnp.array(data, device=device) - expected = getattr(numpy.fft, func)(data) - result = getattr(dpnp.fft, func)(dpnp_data) - assert_dtype_allclose(result, expected) + expected = numpy.fft.fftn(data) + result = dpnp.fft.fftn(dpnp_data) + assert_dtype_allclose(result, expected, check_only_type_kind=True) - expected_queue = dpnp_data.get_array().sycl_queue - result_queue = result.get_array().sycl_queue + expected_queue = dpnp_data.sycl_queue + result_queue = result.sycl_queue + assert_sycl_queue_equal(result_queue, expected_queue) + + expected = numpy.fft.ifftn(expected) + result = dpnp.fft.ifftn(result) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + result_queue = result.sycl_queue + assert_sycl_queue_equal(result_queue, expected_queue) + + +@pytest.mark.parametrize( + "device", + valid_devices, + ids=[device.filter_string for device in valid_devices], +) +def test_rfftn(device): + data = numpy.arange(24, dtype=numpy.float32).reshape(2, 3, 4) + dpnp_data = dpnp.array(data, device=device) + + expected = numpy.fft.rfftn(data) + result = dpnp.fft.rfftn(dpnp_data) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + expected_queue = dpnp_data.sycl_queue + result_queue = result.sycl_queue + assert_sycl_queue_equal(result_queue, expected_queue) + + expected = numpy.fft.irfftn(expected) + result = dpnp.fft.irfftn(result) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + result_queue = result.sycl_queue assert_sycl_queue_equal(result_queue, expected_queue) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 7e2dfd305185..3ba9585655ec 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -985,15 +985,27 @@ def test_fft(func, usm_type): assert result.usm_type == usm_type -@pytest.mark.parametrize("func", ["fftn", "ifftn"]) @pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) -def test_fftn(func, usm_type): - dpnp_data = dp.arange(24, usm_type=usm_type, dtype=dp.complex64).reshape( - 2, 3, 4 - ) - result = getattr(dp.fft, func)(dpnp_data) +def test_fftn(usm_type): + dpnp_data = dp.arange(24, usm_type=usm_type).reshape(2, 3, 4) + assert dpnp_data.usm_type == usm_type + + result = dp.fft.fftn(dpnp_data) + assert result.usm_type == usm_type + + result = dp.fft.ifftn(result) + assert result.usm_type == usm_type + +@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) +def test_rfftn(usm_type): + dpnp_data = dp.arange(24, usm_type=usm_type).reshape(2, 3, 4) assert dpnp_data.usm_type == usm_type + + result = dp.fft.rfftn(dpnp_data) + assert result.usm_type == usm_type + + result = dp.fft.irfftn(result) assert result.usm_type == usm_type diff --git a/tests/third_party/cupy/fft_tests/test_fft.py b/tests/third_party/cupy/fft_tests/test_fft.py index a23d042f9dc3..10822b30422b 100644 --- a/tests/third_party/cupy/fft_tests/test_fft.py +++ b/tests/third_party/cupy/fft_tests/test_fft.py @@ -129,7 +129,6 @@ def test_ifft(self, xp, dtype): {"shape": (3, 4), "s": None, "axes": (-2, -1)}, {"shape": (3, 4), "s": None, "axes": (-1, -2)}, # {"shape": (3, 4), "s": None, "axes": (0,)}, # mkl_fft gh-109 - {"shape": (3, 4), "s": None, "axes": None}, # {"shape": (3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 {"shape": (2, 3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": (1, 4, 4), "axes": (0, 1, 2)}, @@ -137,7 +136,6 @@ def test_ifft(self, xp, dtype): {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, # {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, # mkl_fft gh-109 - {"shape": (2, 3, 4), "s": None, "axes": None}, # {"shape": (2, 3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 # {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, # mkl_fft gh-109 {"shape": (2, 3, 4, 5), "s": None, "axes": None}, @@ -219,7 +217,6 @@ def test_ifft2(self, xp, dtype, order): {"shape": (3, 4), "s": None, "axes": [-1, -2]}, # {"shape": (3, 4), "s": None, "axes": (0,)}, # mkl_fft gh-109 # {"shape": (3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 - {"shape": (3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": None, "axes": None}, {"shape": (2, 3, 4), "s": (1, 4, 4), "axes": (0, 1, 2)}, {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (0, 1, 2)}, @@ -227,7 +224,6 @@ def test_ifft2(self, xp, dtype, order): {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, # {"shape": (2, 3, 4), "s": None, "axes": (-1, -3)}, # mkl_fft gh-109 # {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, # mkl_fft gh-109 - {"shape": (2, 3, 4), "s": None, "axes": None}, # {"shape": (2, 3, 4), "s": None, "axes": ()}, # mkl_fft gh-108 # {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, # mkl_fft gh-109 {"shape": (2, 3, 4), "s": (4, 3, 2), "axes": (2, 0, 1)}, @@ -338,11 +334,83 @@ def test_irfft(self, xp, dtype): return out +@pytest.mark.usefixtures("skip_forward_backward") +@testing.parameterize( + *( + testing.product_dict( + [ + # some of the following cases are modified, since in NumPy 2.0.0 + # `s` must contain only integer `s`, not None values, and + # If `s` is not None, `axes`` must not be None either. + {"shape": (3, 4), "s": None, "axes": None}, + {"shape": (3, 4), "s": (1, 4), "axes": (0, 1)}, + {"shape": (3, 4), "s": (1, 5), "axes": (0, 1)}, + {"shape": (3, 4), "s": None, "axes": (-2, -1)}, + {"shape": (3, 4), "s": None, "axes": (-1, -2)}, + {"shape": (3, 4), "s": None, "axes": (0,)}, + # {"shape": (2, 3, 4), "s": None, "axes": None}, # mkl_fft gh-116 + # {"shape": (2, 3, 4), "s": (1, 4, 4), "axes": (0, 1, 2)}, # mkl_fft gh-115 + # {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (0, 1, 2)}, # mkl_fft gh-115 + # {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, # mkl_fft gh-116 + # {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, # mkl_fft gh-116 + {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, + {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, + # {"shape": (2, 3, 4, 5), "s": None, "axes": None}, # mkl_fft gh-109 and gh-116 + ], + testing.product( + {"norm": [None, "backward", "ortho", "forward", ""]} + ), + ) + ) +) +class TestRfft2: + @testing.for_orders("CF") + @testing.for_all_dtypes(no_complex=True) + @testing.numpy_cupy_allclose( + rtol=1e-4, + atol=1e-7, + accept_error=ValueError, + contiguous_check=False, + type_check=has_support_aspect64(), + ) + def test_rfft2(self, xp, dtype, order): + a = testing.shaped_random(self.shape, xp, dtype) + if order == "F": + a = xp.asfortranarray(a) + out = xp.fft.rfft2(a, s=self.s, axes=self.axes, norm=self.norm) + + if xp is np and dtype in [np.float16, np.float32, np.complex64]: + out = out.astype(np.complex64) + + return out + + @testing.for_orders("CF") + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose( + rtol=1e-4, + atol=1e-7, + accept_error=ValueError, + contiguous_check=False, + type_check=has_support_aspect64(), + ) + def test_irfft2(self, xp, dtype, order): + if self.s is None and self.axes in [None, (-2, -1)]: + pytest.skip("Input is not Hermitian Symmetric") + a = testing.shaped_random(self.shape, xp, dtype) + if order == "F": + a = xp.asfortranarray(a) + out = xp.fft.irfft2(a, s=self.s, axes=self.axes, norm=self.norm) + + if xp is np and dtype in [np.float16, np.float32, np.complex64]: + out = out.astype(np.float32) + + return out + + @testing.parameterize( {"shape": (3, 4), "s": None, "axes": (), "norm": None}, {"shape": (2, 3, 4), "s": None, "axes": (), "norm": None}, ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestRfft2EmptyAxes: @testing.for_all_dtypes(no_complex=True) def test_rfft2(self, dtype): @@ -359,11 +427,83 @@ def test_irfft2(self, dtype): xp.fft.irfft2(a, s=self.s, axes=self.axes, norm=self.norm) +@pytest.mark.usefixtures("skip_forward_backward") +@testing.parameterize( + *( + testing.product_dict( + [ + # some of the following cases are modified, since in NumPy 2.0.0 + # `s` must contain only integer `s`, not None values, and + # If `s` is not None, `axes`` must not be None either. + {"shape": (3, 4), "s": None, "axes": None}, + {"shape": (3, 4), "s": (1, 4), "axes": (0, 1)}, + {"shape": (3, 4), "s": (1, 5), "axes": (0, 1)}, + {"shape": (3, 4), "s": None, "axes": (-2, -1)}, + {"shape": (3, 4), "s": None, "axes": (-1, -2)}, + {"shape": (3, 4), "s": None, "axes": (0,)}, + # {"shape": (2, 3, 4), "s": None, "axes": None}, # mkl_fft gh-116 + # {"shape": (2, 3, 4), "s": (1, 4, 4), "axes": (0, 1, 2)}, # mkl_fft gh-115 + # {"shape": (2, 3, 4), "s": (1, 4, 10), "axes": (0, 1, 2)}, # mkl_fft gh-115 + # {"shape": (2, 3, 4), "s": None, "axes": (-3, -2, -1)}, # mkl_fft gh-116 + # {"shape": (2, 3, 4), "s": None, "axes": (-1, -2, -3)}, # mkl_fft gh-116 + {"shape": (2, 3, 4), "s": None, "axes": (0, 1)}, + {"shape": (2, 3, 4), "s": (2, 3), "axes": (0, 1, 2)}, + # {"shape": (2, 3, 4, 5), "s": None, "axes": None}, # mkl_fft gh-109 and gh-116 + ], + testing.product( + {"norm": [None, "backward", "ortho", "forward", ""]} + ), + ) + ) +) +class TestRfftn: + @testing.for_orders("CF") + @testing.for_all_dtypes(no_complex=True) + @testing.numpy_cupy_allclose( + rtol=1e-4, + atol=1e-7, + accept_error=ValueError, + contiguous_check=False, + type_check=has_support_aspect64(), + ) + def test_rfftn(self, xp, dtype, order): + a = testing.shaped_random(self.shape, xp, dtype) + if order == "F": + a = xp.asfortranarray(a) + out = xp.fft.rfftn(a, s=self.s, axes=self.axes, norm=self.norm) + + if xp is np and dtype in [np.float16, np.float32, np.complex64]: + out = out.astype(np.complex64) + + return out + + @testing.for_orders("CF") + @testing.for_all_dtypes() + @testing.numpy_cupy_allclose( + rtol=1e-4, + atol=1e-7, + accept_error=ValueError, + contiguous_check=False, + type_check=has_support_aspect64(), + ) + def test_irfftn(self, xp, dtype, order): + if self.s is None and self.axes in [None, (-2, -1)]: + pytest.skip("Input is not Hermitian Symmetric") + a = testing.shaped_random(self.shape, xp, dtype) + if order == "F": + a = xp.asfortranarray(a) + out = xp.fft.irfftn(a, s=self.s, axes=self.axes, norm=self.norm) + + if xp is np and dtype in [np.float16, np.float32, np.complex64]: + out = out.astype(np.float32) + + return out + + @testing.parameterize( {"shape": (3, 4), "s": None, "axes": (), "norm": None}, {"shape": (2, 3, 4), "s": None, "axes": (), "norm": None}, ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestRfftnEmptyAxes: @testing.for_all_dtypes(no_complex=True) def test_rfftn(self, dtype):