From b37bb2dd50045136659152d85d62bade12891ad8 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 28 Jun 2024 19:38:11 -0500 Subject: [PATCH 1/7] implement rfft and irfft --- dpnp/backend/extensions/fft/fft_py.cpp | 32 ++- dpnp/backend/extensions/fft/fft_utils.hpp | 38 ++- dpnp/backend/extensions/fft/in_place.cpp | 8 +- dpnp/backend/extensions/fft/out_of_place.cpp | 71 ++++- dpnp/backend/include/dpnp_iface_fptr.hpp | 20 +- dpnp/backend/kernels/dpnp_krnl_fft.cpp | 25 -- dpnp/dpnp_algo/dpnp_algo.pxd | 1 - dpnp/fft/dpnp_algo_fft.pyx | 60 ---- dpnp/fft/dpnp_iface_fft.py | 276 ++++++++++++------- dpnp/fft/dpnp_utils_fft.py | 128 ++++++--- tests/test_fft.py | 259 ++++++++++++++++- tests/test_sycl_queue.py | 6 +- tests/test_usm_type.py | 6 +- tests/third_party/cupy/fft_tests/test_fft.py | 15 +- 14 files changed, 660 insertions(+), 285 deletions(-) diff --git a/dpnp/backend/extensions/fft/fft_py.cpp b/dpnp/backend/extensions/fft/fft_py.cpp index bee35c553b85..0b7d8d8f2f3f 100644 --- a/dpnp/backend/extensions/fft/fft_py.cpp +++ b/dpnp/backend/extensions/fft/fft_py.cpp @@ -68,39 +68,57 @@ void register_descriptor(py::module &m, const char *name) PYBIND11_MODULE(_fft_impl, m) { constexpr mkl_dft::domain complex_dom = mkl_dft::domain::COMPLEX; + constexpr mkl_dft::domain real_dom = mkl_dft::domain::REAL; constexpr mkl_dft::precision single_prec = mkl_dft::precision::SINGLE; constexpr mkl_dft::precision double_prec = mkl_dft::precision::DOUBLE; register_descriptor(m, "Complex64Descriptor"); register_descriptor(m, "Complex128Descriptor"); + register_descriptor(m, "Real32Descriptor"); + register_descriptor(m, "Real64Descriptor"); - // out-of-place c2c FFT, both SINGLE and DOUBLE precisions are supported - // with overloading of "_fft_out_of_place" function on python side - m.def("_fft_out_of_place", + // out-of-place FFT, all possible combination (single/double precisions and + // real/complex domains) are supported with overloading of + // "_fft_out_of_place" function on 
python side + m.def("_fft_out_of_place", // single precision c2c out-of-place FFT &fft_ns::compute_fft_out_of_place, "Compute out-of-place complex-to-complex fft using OneMKL DFT " "library for complex64 data types.", py::arg("descriptor"), py::arg("input"), py::arg("output"), py::arg("is_forward"), py::arg("depends") = py::list()); - m.def("_fft_out_of_place", + m.def("_fft_out_of_place", // double precision c2c out-of-place FFT &fft_ns::compute_fft_out_of_place, "Compute out-of-place complex-to-complex fft using OneMKL DFT " "library for complex128 data types.", py::arg("descriptor"), py::arg("input"), py::arg("output"), py::arg("is_forward"), py::arg("depends") = py::list()); - // in-place c2c FFT, both SINGLE and DOUBLE precisions are supported with + m.def("_fft_out_of_place", // single precision r2c/c2r out-of-place FFT + &fft_ns::compute_fft_out_of_place, + "Compute out-of-place real-to-complex fft using OneMKL DFT library " + "for float32 data types.", + py::arg("descriptor"), py::arg("input"), py::arg("output"), + py::arg("is_forward"), py::arg("depends") = py::list()); + + m.def("_fft_out_of_place", // double precision r2c/c2r out-of-place FFT + &fft_ns::compute_fft_out_of_place, + "Compute out-of-place real-to-complex fft using OneMKL DFT library " + "for float64 data types.", + py::arg("descriptor"), py::arg("input"), py::arg("output"), + py::arg("is_forward"), py::arg("depends") = py::list()); + + // in-place c2c FFT, both single and double precisions are supported with // overloading of "_fft_in_place" function on python side - m.def("_fft_in_place", + m.def("_fft_in_place", // single precision c2c in-place FFT &fft_ns::compute_fft_in_place, "Compute in-place complex-to-complex fft using OneMKL DFT library " "for complex64 data types.", py::arg("descriptor"), py::arg("input-output"), py::arg("is_forward"), py::arg("depends") = py::list()); - m.def("_fft_in_place", + m.def("_fft_in_place", // double precision c2c in-place FFT 
&fft_ns::compute_fft_in_place, "Compute in-place complex-to-complex fft using OneMKL DFT library " "for complex128 data types.", diff --git a/dpnp/backend/extensions/fft/fft_utils.hpp b/dpnp/backend/extensions/fft/fft_utils.hpp index 790353de74f7..cb25eb4ac949 100644 --- a/dpnp/backend/extensions/fft/fft_utils.hpp +++ b/dpnp/backend/extensions/fft/fft_utils.hpp @@ -31,21 +31,43 @@ namespace dpnp::extensions::fft { namespace mkl_dft = oneapi::mkl::dft; -template +template struct ScaleType { - using value_type = void; + using type_in = void; + using type_out = void; +}; + +// for r2c FFT, type_in is real and type_out is complex +// is_forward is true +template +struct ScaleType +{ + using prec_type = typename std:: + conditional::type; + using type_in = prec_type; + using type_out = std::complex; }; -template <> -struct ScaleType +// for c2r FFT, type_in is complex and type_out is real +// is_forward is false +template +struct ScaleType { - using value_type = float; + using prec_type = typename std:: + conditional::type; + using type_in = std::complex; + using type_out = prec_type; }; -template <> -struct ScaleType +// for c2c FFT, both type_in and type_out are complex +// regardless of is_fwd +template +struct ScaleType { - using value_type = double; + using prec_type = typename std:: + conditional::type; + using type_in = std::complex; + using type_out = std::complex; }; } // namespace dpnp::extensions::fft diff --git a/dpnp/backend/extensions/fft/in_place.cpp b/dpnp/backend/extensions/fft/in_place.cpp index fde29906b678..9256d022efcc 100644 --- a/dpnp/backend/extensions/fft/in_place.cpp +++ b/dpnp/backend/extensions/fft/in_place.cpp @@ -67,8 +67,10 @@ std::pair dpctl::tensor::validation::CheckWritable::throw_if_not_writable(in_out); - using ScaleT = typename ScaleType::value_type; - std::complex *in_out_ptr = in_out.get_data>(); + // in-place is only used for c2c FFT at this time, passing true or false is + // indifferent + using ScaleT = typename 
ScaleType::type_in; + ScaleT *in_out_ptr = in_out.get_data(); sycl::event fft_event = {}; std::stringstream error_msg; @@ -104,6 +106,7 @@ std::pair } // Explicit instantiations +// single precision c2c FFT template std::pair compute_fft_in_place( DescriptorWrapper &descr, @@ -111,6 +114,7 @@ template std::pair compute_fft_in_place( const bool is_forward, const std::vector &depends); +// double precision c2c FFT template std::pair compute_fft_in_place( DescriptorWrapper &descr, diff --git a/dpnp/backend/extensions/fft/out_of_place.cpp b/dpnp/backend/extensions/fft/out_of_place.cpp index 06f014d49a7d..70b4164387b0 100644 --- a/dpnp/backend/extensions/fft/out_of_place.cpp +++ b/dpnp/backend/extensions/fft/out_of_place.cpp @@ -84,20 +84,44 @@ std::pair "execution queue of the descriptor."); } - py::ssize_t in_size = in.get_size(); - py::ssize_t out_size = out.get_size(); - if (in_size != out_size) { - throw py::value_error("The size of the input vector must be " - "equal to the size of the output vector."); + const py::ssize_t *in_shape = in.get_shape_raw(); + const py::ssize_t *out_shape = out.get_shape_raw(); + const std::int64_t m = in_shape[in_nd - 1]; + const std::int64_t n = out_shape[out_nd - 1]; + + std::int64_t in_size = 1; + if (in_nd > 1) { + for (int i = 0; i < in_nd - 1; ++i) { + if (in_shape[i] != out_shape[i]) { + throw py::value_error("The shape of the input and output " + "arrays must be the same."); + } + in_size *= in_shape[i]; + } } - size_t src_nelems = in_size; - dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out); - dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(out, src_nelems); + std::int64_t N; + if (dom == mkl_dft::domain::REAL && is_forward) { + // r2c FFT + N = m / 2 + 1; // integer divide + if (n != N) { + throw py::value_error("The shape of the output array is not " + "correct for real to complex transform."); + } + } + else { + // c2c and c2r FFT. 
For c2r FFT, input is zero-padded in python side to + // have the same size as output before calling this function + N = m; + if (n != N) { + throw py::value_error("The shape of the input array must be " + "the same as the shape of the output array."); + } + } - using ScaleT = typename ScaleType::value_type; - std::complex *in_ptr = in.get_data>(); - std::complex *out_ptr = out.get_data>(); + const std::size_t n_elems = in_size * N; + dpctl::tensor::validation::CheckWritable::throw_if_not_writable(out); + dpctl::tensor::validation::AmpleMemory::throw_if_not_ample(out, n_elems); sycl::event fft_event = {}; std::stringstream error_msg; @@ -105,10 +129,18 @@ std::pair try { if (is_forward) { + using ScaleT_in = typename ScaleType::type_in; + using ScaleT_out = typename ScaleType::type_out; + ScaleT_in *in_ptr = in.get_data(); + ScaleT_out *out_ptr = out.get_data(); fft_event = mkl_dft::compute_forward(descr.get_descriptor(), in_ptr, out_ptr, depends); } else { + using ScaleT_in = typename ScaleType::type_in; + using ScaleT_out = typename ScaleType::type_out; + ScaleT_in *in_ptr = in.get_data(); + ScaleT_out *out_ptr = out.get_data(); fft_event = mkl_dft::compute_backward(descr.get_descriptor(), in_ptr, out_ptr, depends); } @@ -133,6 +165,7 @@ std::pair } // Explicit instantiations +// single precision c2c FFT template std::pair compute_fft_out_of_place( DescriptorWrapper &descr, @@ -141,6 +174,7 @@ template std::pair compute_fft_out_of_place( const bool is_forward, const std::vector &depends); +// double precision c2c FFT template std::pair compute_fft_out_of_place( DescriptorWrapper &descr, @@ -149,4 +183,19 @@ template std::pair compute_fft_out_of_place( const bool is_forward, const std::vector &depends); +// single precision r2c/c2r FFT +template std::pair compute_fft_out_of_place( + DescriptorWrapper &descr, + const dpctl::tensor::usm_ndarray &in, + const dpctl::tensor::usm_ndarray &out, + const bool is_forward, + const std::vector &depends); + +// double 
precision r2c/c2r FFT +template std::pair compute_fft_out_of_place( + DescriptorWrapper &descr, + const dpctl::tensor::usm_ndarray &in, + const dpctl::tensor::usm_ndarray &out, + const bool is_forward, + const std::vector &depends); } // namespace dpnp::extensions::fft diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 9f9b7a89143f..79d1f18cc3c7 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -84,17 +84,15 @@ enum class DPNPFuncName : size_t DPNP_FN_DOT, /**< Used in numpy.dot() impl */ DPNP_FN_DOT_EXT, /**< Used in numpy.dot() impl, requires extra parameters */ DPNP_FN_EDIFF1D, /**< Used in numpy.ediff1d() impl */ - DPNP_FN_EDIFF1D_EXT, /**< Used in numpy.ediff1d() impl, requires extra - parameters */ - DPNP_FN_ERF, /**< Used in scipy.special.erf impl */ - DPNP_FN_ERF_EXT, /**< Used in scipy.special.erf impl, requires extra - parameters */ - DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() impl */ - DPNP_FN_FFT_FFT_EXT, /**< Used in numpy.fft.fft() impl, requires extra - parameters */ - DPNP_FN_FFT_RFFT, /**< Used in numpy.fft.rfft() impl */ - DPNP_FN_FFT_RFFT_EXT, /**< Used in numpy.fft.rfft() impl, requires extra - parameters */ + DPNP_FN_EDIFF1D_EXT, /**< Used in numpy.ediff1d() impl, requires extra + parameters */ + DPNP_FN_ERF, /**< Used in scipy.special.erf impl */ + DPNP_FN_ERF_EXT, /**< Used in scipy.special.erf impl, requires extra + parameters */ + DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() impl */ + DPNP_FN_FFT_FFT_EXT, /**< Used in numpy.fft.fft() impl, requires extra + parameters */ + DPNP_FN_FFT_RFFT, /**< Used in numpy.fft.rfft() impl */ DPNP_FN_INITVAL, /**< Used in numpy ones, ones_like, zeros, zeros_like impls */ DPNP_FN_INITVAL_EXT, /**< Used in numpy ones, ones_like, zeros, zeros_like diff --git a/dpnp/backend/kernels/dpnp_krnl_fft.cpp b/dpnp/backend/kernels/dpnp_krnl_fft.cpp index aec669a86990..ff4d3873c881 100644 --- 
a/dpnp/backend/kernels/dpnp_krnl_fft.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_fft.cpp @@ -698,20 +698,6 @@ void (*dpnp_fft_rfft_default_c)(const void *, const size_t) = dpnp_fft_rfft_c<_DataType_input, _DataType_output>; -template -DPCTLSyclEventRef (*dpnp_fft_rfft_ext_c)(DPCTLSyclQueueRef, - const void *, - void *, - const shape_elem_type *, - const shape_elem_type *, - size_t, - long, - long, - size_t, - const size_t, - const DPCTLEventVectorRef) = - dpnp_fft_rfft_c<_DataType_input, _DataType_output>; - void func_map_init_fft_func(func_map_t &fmap) { fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_INT][eft_INT] = { @@ -762,16 +748,5 @@ void func_map_init_fft_func(func_map_t &fmap) eft_C128, (void *)dpnp_fft_rfft_default_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_INT][eft_INT] = { - eft_C128, (void *)dpnp_fft_rfft_ext_c>, - eft_C64, (void *)dpnp_fft_rfft_ext_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_LNG][eft_LNG] = { - eft_C128, (void *)dpnp_fft_rfft_ext_c>, - eft_C64, (void *)dpnp_fft_rfft_ext_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_FLT][eft_FLT] = { - eft_C64, (void *)dpnp_fft_rfft_ext_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT_EXT][eft_DBL][eft_DBL] = { - eft_C128, (void *)dpnp_fft_rfft_ext_c>}; - return; } diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 3b5b23832260..c2d1e0fa3934 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -40,7 +40,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_EDIFF1D_EXT DPNP_FN_ERF_EXT DPNP_FN_FFT_FFT_EXT - DPNP_FN_FFT_RFFT_EXT DPNP_FN_MEDIAN_EXT DPNP_FN_MODF_EXT DPNP_FN_PARTITION_EXT diff --git a/dpnp/fft/dpnp_algo_fft.pyx b/dpnp/fft/dpnp_algo_fft.pyx index 6ce107df922e..ae1294591343 100644 --- a/dpnp/fft/dpnp_algo_fft.pyx +++ b/dpnp/fft/dpnp_algo_fft.pyx @@ -39,7 +39,6 @@ from dpnp.dpnp_algo cimport * __all__ = [ "dpnp_fft_deprecated", - "dpnp_rfft" ] ctypedef 
c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_fft_fft_t)(c_dpctl.DPCTLSyclQueueRef, void *, void * , @@ -104,62 +103,3 @@ cpdef utils.dpnp_descriptor dpnp_fft_deprecated(utils.dpnp_descriptor input, c_dpctl.DPCTLEvent_Delete(event_ref) return result - - -cpdef utils.dpnp_descriptor dpnp_rfft(utils.dpnp_descriptor input, - size_t input_boundarie, - size_t output_boundarie, - long axis, - size_t inverse, - size_t norm): - - cdef shape_type_c input_shape = input.shape - cdef shape_type_c output_shape = input_shape - - cdef long axis_norm = utils.normalize_axis((axis,), input_shape.size())[0] - output_shape[axis_norm] = output_boundarie - - # convert string type names (dtype) to C enum DPNPFuncType - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(input.dtype) - - # get the FPTR data structure - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_FFT_RFFT_EXT, param1_type, param1_type) - - input_obj = input.get_array() - - # get FPTR function and return type - cdef (DPNPFuncType, void *) ret_type_and_func = utils.get_ret_type_and_func(kernel_data, - input_obj.sycl_device.has_aspect_fp64) - cdef DPNPFuncType return_type = ret_type_and_func[0] - cdef fptr_dpnp_fft_fft_t func = < fptr_dpnp_fft_fft_t > ret_type_and_func[1] - - # create result array with type given by FPTR data - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(output_shape, - return_type, - None, - device=input_obj.sycl_device, - usm_type=input_obj.usm_type, - sycl_queue=input_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - # call FPTR function - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - input.get_data(), - result.get_data(), - input_shape.data(), - output_shape.data(), - input_shape.size(), - axis_norm, - input_boundarie, - inverse, - norm, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - 
c_dpctl.DPCTLEvent_Delete(event_ref) - - return result diff --git a/dpnp/fft/dpnp_iface_fft.py b/dpnp/fft/dpnp_iface_fft.py index 4bf26b1eba73..1f4bf0bdf0ea 100644 --- a/dpnp/fft/dpnp_iface_fft.py +++ b/dpnp/fft/dpnp_iface_fft.py @@ -52,7 +52,6 @@ ) from dpnp.fft.dpnp_algo_fft import ( dpnp_fft_deprecated, - dpnp_rfft, ) from .dpnp_utils_fft import ( @@ -166,7 +165,9 @@ def fft(a, n=None, axis=-1, norm=None, out=None): """ dpnp.check_supported_arrays_type(a) - return dpnp_fft(a, forward=True, n=n, axis=axis, norm=norm, out=out) + return dpnp_fft( + a, forward=True, c2c=True, n=n, axis=axis, norm=norm, out=out + ) def fft2(x, s=None, axes=(-2, -1), norm=None): @@ -539,7 +540,9 @@ def ifft(a, n=None, axis=-1, norm=None, out=None): """ dpnp.check_supported_arrays_type(a) - return dpnp_fft(a, forward=False, n=n, axis=axis, norm=norm, out=out) + return dpnp_fft( + a, forward=False, c2c=True, n=n, axis=axis, norm=norm, out=out + ) def ifft2(x, s=None, axes=(-2, -1), norm=None): @@ -744,66 +747,109 @@ def ihfft(x, n=None, axis=-1, norm=None): return call_origin(numpy.fft.ihfft, x, n, axis, norm) -def irfft(x, n=None, axis=-1, norm=None): +def irfft(a, n=None, axis=-1, norm=None, out=None): """ - Compute the one-dimensional inverse discrete Fourier Transform for real - input. + Computes the inverse of :obj:`dpnp.fft.rfft`. + + This function computes the inverse of the one-dimensional `n`-point + discrete Fourier Transform of real input computed by :obj:`dpnp.fft.rfft`. + In other words, ``irfft(rfft(a), len(a)) == a`` to within numerical + accuracy. (See Notes below for why ``len(a)`` is necessary here.) + + The input is expected to be in the form returned by :obj:`dpnp.fft.rfft`, + i.e. the real zero-frequency term followed by the complex positive + frequency terms in order of increasing frequency. 
Since the discrete + Fourier Transform of real input is Hermitian-symmetric, the negative + frequency terms are taken to be the complex conjugates of the corresponding + positive frequency terms. For full documentation refer to :obj:`numpy.fft.irfft`. - Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `norm` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array. + n : {None, int}, optional + Length of the transformed axis of the output. + For `n` output points, ``n//2+1`` input points are necessary. If the + input is longer than this, it is cropped. If it is shorter than this, + it is padded with zeros. If `n` is not given, it is taken to be + ``2*(m-1)`` where ``m`` is the length of the input along the axis + specified by `axis`. Default: ``None``. + axis : int, optional + Axis over which to compute the FFT. If not given, the last axis is + used. Default: ``-1``. + norm : {None, "backward", "ortho", "forward"}, optional + Normalization mode (see :obj:`dpnp.fft`). + Indicates which direction of the forward/backward pair of transforms + is scaled and with what normalization factor. ``None`` is an alias of + the default option ``"backward"``. + Default: ``"backward"``. + out : {None, dpnp.ndarray, usm_ndarray}, optional + If provided, the result will be placed in this array. It should be + of the appropriate shape and dtype. + Default: ``None``. - """ + Returns + ------- + out : dpnp.ndarray + The truncated or zero-padded input, transformed along the axis + indicated by `axis`, or the last one if `axis` is not specified. + The length of the transformed axis is `n`, or, if `n` is not given, + ``2*(m-1)`` where ``m`` is the length of the transformed axis of the + input. 
To get an odd number of output points, `n` must be specified. - x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - # TODO: enable implementation - # pylint: disable=condition-evals-to-constant - if x_desc and 0: - norm_ = get_validated_norm(norm) + See Also + -------- + :obj:`dpnp.fft` : For definition of the DFT and conventions used. + :obj:`dpnp.fft.rfft` : The one-dimensional FFT of real input, of which + :obj:`dpnp.fft.irfft` is inverse. + :obj:`dpnp.fft.fft` : The one-dimensional FFT of general (complex) input. + :obj:`dpnp.fft.irfft2` : The inverse of the two-dimensional FFT of + real input. + :obj:`dpnp.fft.irfftn` : The inverse of the `n`-dimensional FFT of + real input. - if axis is None: - axis_param = -1 # the most right dimension (default value) - else: - axis_param = axis + Notes + ----- + Returns the real valued `n`-point inverse discrete Fourier transform + of `a`, where `a` contains the non-negative frequency terms of a + Hermitian-symmetric sequence. `n` is the length of the result, not the + input. - if n is None: - input_boundarie = x_desc.shape[axis_param] - else: - input_boundarie = n + If you specify an `n` such that `a` must be zero-padded or truncated, the + extra/removed values will be added/removed at high frequencies. One can + thus resample a series to `m` points via Fourier interpolation by: + ``a_resamp = irfft(rfft(a), m)``. - if x_desc.size < 1: - pass # let fallback to handle exception - elif input_boundarie < 1: - pass # let fallback to handle exception - elif norm is not None: - pass - elif n is not None: - pass - else: - output_boundarie = 2 * (input_boundarie - 1) + The correct interpretation of the hermitian input depends on the length of + the original data, as given by `n`. This is because each input shape could + correspond to either an odd or even length signal.
By default, + :obj:`dpnp.fft.irfft` assumes an even output length which puts the last + entry at the Nyquist frequency; aliasing with its symmetric counterpart. + By Hermitian symmetry, the value is thus treated as purely real. To avoid + losing information, the correct length of the real input **must** be given. - result = dpnp_rfft( - x_desc, - input_boundarie, - output_boundarie, - axis_param, - True, - norm_.value, - ).get_pyobj() - # TODO: - # tmp = utils.create_output_array(result_shape, result_c_type, out) - # tmp = dpnp.ndarray(result.shape, dtype=dpnp.float64) - # for it in range(tmp.size): - # tmp[it] = result[it].real - return result + Examples + -------- + >>> import dpnp as np + >>> a = np.array([1, -1j, -1, 1j]) + >>> np.fft.ifft(a) + array([0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]) # may vary + >>> np.fft.irfft(a[:-1]) + array([0., 1., 0., 0.]) - return call_origin(numpy.fft.irfft, x, n, axis, norm) + Notice how the last term in the input to the ordinary :obj:`dpnp.fft.ifft` + is the complex conjugate of the second term, and the output has zero + imaginary part everywhere. When calling :obj:`dpnp.fft.irfft`, the negative + frequencies are not specified, and the output array is purely real. + + """ + + dpnp.check_supported_arrays_type(a) + return dpnp_fft( + a, forward=False, c2c=False, n=n, axis=axis, norm=norm, out=out + ) def irfft2(x, s=None, axes=(-2, -1), norm=None): @@ -897,68 +943,94 @@ def irfftn(x, s=None, axes=None, norm=None): return call_origin(numpy.fft.irfftn, x, s, axes, norm) -def rfft(x, n=None, axis=-1, norm=None): +def rfft(a, n=None, axis=-1, norm=None, out=None): """ Compute the one-dimensional discrete Fourier Transform for real input. + This function computes the one-dimensional `n`-point discrete Fourier + Transform (DFT) of a real-valued array by means of an efficient algorithm + called the Fast Fourier Transform (FFT). + For full documentation refer to :obj:`numpy.fft.rfft`. 
- Limitations - ----------- - Parameter `x` is supported either as :class:`dpnp.ndarray`. - Parameter `norm` is unsupported. - Only `dpnp.float64`, `dpnp.float32`, `dpnp.int64`, `dpnp.int32`, - `dpnp.complex128` data types are supported. - The `dpnp.bool` data type is not supported and will raise a `TypeError` - exception. - Otherwise the function will be executed sequentially on CPU. + Parameters + ---------- + a : {dpnp.ndarray, usm_ndarray} + Input array. + n : {None, int}, optional + Number of points along transformation axis in the input to use. + If `n` is smaller than the length of the input, the input is cropped. + If it is larger, the input is padded with zeros. If `n` is not given, + the length of the input along the axis specified by `axis` is used. + Default: ``None``. + axis : int, optional + Axis over which to compute the FFT. If not given, the last axis is + used. Default: ``-1``. + norm : {None, "backward", "ortho", "forward"}, optional + Normalization mode (see :obj:`dpnp.fft`). + Indicates which direction of the forward/backward pair of transforms + is scaled and with what normalization factor. ``None`` is an alias of + the default option ``"backward"``. + Default: ``"backward"``. + out : {None, dpnp.ndarray or usm_ndarray of complex dtype}, optional + If provided, the result will be placed in this array. It should be + of the appropriate shape and dtype. + Default: ``None``. - """ + Returns + ------- + out : dpnp.ndarray of complex dtype + The truncated or zero-padded input, transformed along the axis + indicated by `axis`, or the last one if `axis` is not specified. + If `n` is even, the length of the transformed axis is ``(n/2)+1``. + If `n` is odd, the length is ``(n+1)/2``. 
- x_desc = dpnp.get_dpnp_descriptor(x, copy_when_nondefault_queue=False) - if x_desc: - dt = x_desc.dtype - if dpnp.issubdtype(dt, dpnp.bool): - raise TypeError(f"The `{dt}` data type is unsupported.") + See Also + -------- + :obj:`dpnp.fft` : For definition of the DFT and conventions used. + :obj:`dpnp.fft.irfft` : The inverse of :obj:`dpnp.fft.rfft`. + :obj:`dpnp.fft.fft` : The one-dimensional FFT of general (complex) input. + :obj:`dpnp.fft.fftn` : The `n`-dimensional FFT. + :obj:`dpnp.fft.rfftn` : The `n`-dimensional FFT of real input. - norm_ = get_validated_norm(norm) + Notes + ----- + When the DFT is computed for purely real input, the output is + Hermitian-symmetric, i.e. the negative frequency terms are just the complex + conjugates of the corresponding positive-frequency terms, and the + negative-frequency terms are therefore redundant. This function does not + compute the negative frequency terms, and the length of the transformed + axis of the output is therefore ``n//2 + 1``. + + When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains + the zero-frequency term 0*fs, which is real due to Hermitian symmetry. + + If `n` is even, ``A[-1]`` contains the term representing both positive + and negative Nyquist frequency (+fs/2 and -fs/2), and must also be purely + real. If `n` is odd, there is no term at fs/2; ``A[-1]`` contains + the largest positive frequency (fs/2*(n-1)/n), and is complex in the + general case. 
- if axis is None: - axis_param = -1 # the most right dimension (default value) - else: - axis_param = axis + Examples + -------- + >>> import dpnp as np + >>> a = np.array([0, 1, 0, 0]) + >>> np.fft.fft(a) + array([ 1.+0.j, 0.-1.j, -1.+0.j, 0.+1.j]) # may vary + >>> np.fft.rfft(a) + array([ 1.+0.j, 0.-1.j, -1.+0.j]) # may vary - if n is None: - input_boundarie = x_desc.shape[axis_param] - else: - input_boundarie = n + Notice how the final element of the :obj:`dpnp.fft.fft` output is the + complex conjugate of the second element, for real input. + For :obj:`dpnp.fft.rfft`, this symmetry is exploited to compute only the + non-negative frequency terms. - if x_desc.size < 1: - pass # let fallback to handle exception - elif input_boundarie < 1: - pass # let fallback to handle exception - elif axis != -1: - pass - elif norm is not None: - pass - elif n is not None: - pass - elif x_desc.dtype in (numpy.complex128, numpy.complex64): - pass - else: - output_boundarie = ( - input_boundarie // 2 + 1 - ) # rfft specific requirenment - return dpnp_rfft( - x_desc, - input_boundarie, - output_boundarie, - axis_param, - False, - norm_.value, - ).get_pyobj() + """ - return call_origin(numpy.fft.rfft, x, n, axis, norm) + dpnp.check_supported_arrays_type(a) + return dpnp_fft( + a, forward=True, c2c=False, n=n, axis=axis, norm=norm, out=out + ) def rfft2(x, s=None, axes=(-2, -1), norm=None): diff --git a/dpnp/fft/dpnp_utils_fft.py b/dpnp/fft/dpnp_utils_fft.py index 413155239ef9..22696e30bd56 100644 --- a/dpnp/fft/dpnp_utils_fft.py +++ b/dpnp/fft/dpnp_utils_fft.py @@ -66,16 +66,22 @@ def _check_norm(norm): ) -def _commit_descriptor(a, in_place, a_strides, index, axes): +def _commit_descriptor(a, in_place, c2c, a_strides, index, axes): """Commit the FFT descriptor for the input array.""" a_shape = a.shape shape = a_shape[index:] strides = (0,) + a_strides[index:] - if a.dtype == dpnp.complex64: - dsc = fi.Complex64Descriptor(shape) - else: - dsc = fi.Complex128Descriptor(shape) + if 
c2c: # c2c FFT + if a.dtype == dpnp.complex64: + dsc = fi.Complex64Descriptor(shape) + else: + dsc = fi.Complex128Descriptor(shape) + else: # r2c/c2r FFT + if a.dtype in [dpnp.float32, dpnp.complex64]: + dsc = fi.Real32Descriptor(shape) + else: + dsc = fi.Real64Descriptor(shape) dsc.fwd_strides = strides dsc.bwd_strides = dsc.fwd_strides @@ -89,7 +95,7 @@ def _commit_descriptor(a, in_place, a_strides, index, axes): return dsc -def _compute_result(dsc, a, out, forward, a_strides): +def _compute_result(dsc, a, out, forward, c2c, a_strides): """Compute the result of the FFT.""" exec_q = a.sycl_queue @@ -99,6 +105,8 @@ def _compute_result(dsc, a, out, forward, a_strides): a_usm = dpnp.get_usm_ndarray(a) if dsc.transform_in_place: # in-place transform + # TODO: investigate the performance of in-place implementation + # for r2c/c2r ht_fft_event, fft_event = fi._fft_in_place( dsc, a_usm, forward, depends=dep_evs ) @@ -114,9 +122,29 @@ def _compute_result(dsc, a, out, forward, a_strides): else: # Result array that is used in OneMKL must have the exact same # stride as input array + + if c2c: # c2c FFT + out_shape = a.shape + out_dtype = a.dtype + else: + if forward: # r2c FFT + tmp = numpy.floor_divide(a.shape[-1], 2) + 1 + out_shape = a.shape[:-1] + (tmp,) + out_dtype = ( + dpnp.complex64 + if a.dtype == dpnp.float32 + else dpnp.complex128 + ) + else: # c2r FFT + out_shape = a.shape # a is already zero-padded + out_dtype = ( + dpnp.float32 + if a.dtype == dpnp.complex64 + else dpnp.float64 + ) result = dpnp_array( - a.shape, - dtype=a.dtype, + out_shape, + dtype=out_dtype, strides=a_strides, usm_type=a.usm_type, sycl_queue=exec_q, @@ -132,24 +160,29 @@ def _compute_result(dsc, a, out, forward, a_strides): return result -def _copy_array(x): +def _copy_array(x, complex_input): """ Creating a C-contiguous copy of input array if input array has a negative stride or it does not have a complex data types. In this situation, an in-place FFT can be performed. 
""" dtype = x.dtype - copy_flag = False if numpy.min(x.strides) < 0: # negative stride is not allowed in OneMKL FFT copy_flag = True - elif not dpnp.issubdtype(dtype, dpnp.complexfloating): - # if input is not complex, convert to complex + elif complex_input and not dpnp.issubdtype(dtype, dpnp.complexfloating): + # c2c/c2r FFT, if input is not complex, convert to complex copy_flag = True if dtype == dpnp.float32: dtype = dpnp.complex64 else: dtype = map_dtype_to_device(dpnp.complex128, x.sycl_device) + elif not complex_input and dtype not in [dpnp.float32, dpnp.float64]: + # r2c FFT, if input is not float dtype, convert to float + copy_flag = True + dtype = map_dtype_to_device(dpnp.float64, x.sycl_device) + else: + copy_flag = False if copy_flag: x_copy = dpnp.empty_like(x, dtype=dtype, order="C") @@ -171,12 +204,10 @@ def _copy_array(x): return x, copy_flag -def _fft(a, norm, out, forward, in_place, axes=None): +def _fft(a, norm, out, forward, in_place, c2c, axes=None): """Calculates FFT of the input array along the specified axes.""" index = 0 - if not in_place and out is not None: - in_place = dpnp.are_same_logical_tensors(a, out) if axes is not None: # batch_fft len_axes = 1 if isinstance(axes, int) else len(axes) local_axes = numpy.arange(-len_axes, 0) @@ -187,12 +218,13 @@ def _fft(a, norm, out, forward, in_place, axes=None): index = 1 a_strides = _standardize_strides_to_nonzero(a.strides, a.shape) - dsc = _commit_descriptor(a, in_place, a_strides, index, axes) - res = _compute_result(dsc, a, out, forward, a_strides) - res = _scale_result(res, norm, forward, index) + dsc = _commit_descriptor(a, in_place, c2c, a_strides, index, axes) + res = _compute_result(dsc, a, out, forward, c2c, a_strides) + res = _scale_result(res, a.shape, norm, forward, index) if axes is not None: # batch_fft - res = dpnp.reshape(res, a_shape_orig) + tmp_shape = a_shape_orig[:-1] + (res.shape[-1],) + res = dpnp.reshape(res, tmp_shape) res = dpnp.moveaxis(res, local_axes, axes) result = 
dpnp.get_result_array(res, out=out, casting="same_kind") @@ -204,9 +236,9 @@ def _fft(a, norm, out, forward, in_place, axes=None): return result -def _scale_result(res, norm, forward, index): +def _scale_result(res, a_shape, norm, forward, index): """Scale the result of the FFT according to `norm`.""" - scale = numpy.prod(res.shape[index:], dtype=res.real.dtype) + scale = numpy.prod(a_shape[index:], dtype=res.real.dtype) norm_factor = 1 if norm == "ortho": norm_factor = numpy.sqrt(scale) @@ -261,7 +293,7 @@ def _truncate_or_pad(a, shape, axes): return a -def _validate_out_keyword(a, out): +def _validate_out_keyword(a, out, axis, c2r, r2c): """Validate out keyword argument.""" if out is not None: dpnp.check_supported_arrays_type(out) @@ -273,31 +305,56 @@ def _validate_out_keyword(a, out): "Input and output allocation queues are not compatible" ) - if out.shape != a.shape: - raise ValueError("output array has incorrect shape.") - - if not dpnp.issubdtype(out.dtype, dpnp.complexfloating): - raise TypeError("output array has incorrect data type.") - - -def dpnp_fft(a, forward, n=None, axis=-1, norm=None, out=None): + # validate out shape + if r2c: + if out.shape[axis] != (a.shape[axis] // 2 + 1): + raise ValueError("output array has incorrect shape.") + else: # c2c/c2r FFT, for c2r input is already zero-padded + if out.shape != a.shape: + raise ValueError("output array has incorrect shape.") + + # validate out data type + if c2r: + if not dpnp.issubdtype(out.dtype, dpnp.floating): + raise TypeError( + "output array should have real floating data type." 
+ ) + else: # c2c/r2c FFT + if not dpnp.issubdtype(out.dtype, dpnp.complexfloating): + raise TypeError("output array should have complex data type.") + + +def dpnp_fft(a, forward, c2c, n=None, axis=-1, norm=None, out=None): """Calculates 1-D FFT of the input array along axis""" a_ndim = a.ndim if a_ndim == 0: raise ValueError("Input array must be at least 1D") + r2c = not c2c and forward + c2r = not c2c and not forward + if r2c and dpnp.issubdtype(a.dtype, dpnp.complexfloating): + raise TypeError("Input array must be real") + axis = normalize_axis_index(axis, a_ndim) if n is None: - n = a.shape[axis] + if c2r: + n = (a.shape[axis] - 1) * 2 + else: + n = a.shape[axis] elif not isinstance(n, int): raise TypeError("`n` should be None or an integer") if n < 1: raise ValueError(f"Invalid number of FFT data points ({n}) specified") + _check_norm(norm) a = _truncate_or_pad(a, n, axis) - _validate_out_keyword(a, out) - a, in_place = _copy_array(a) + _validate_out_keyword(a, out, axis, c2r, r2c) + # if input array is copied, in-place FFT can be used + a, in_place = _copy_array(a, c2c or c2r) + if not in_place and out is not None: + # if input is also given for out, in-place FFT can be used + in_place = dpnp.are_same_logical_tensors(a, out) if a.size == 0: return dpnp.get_result_array(a, out=out, casting="same_kind") @@ -305,12 +362,13 @@ def dpnp_fft(a, forward, n=None, axis=-1, norm=None, out=None): # non-batch FFT axis = None if a_ndim == 1 else axis - _check_norm(norm) return _fft( a, norm=norm, out=out, forward=forward, - in_place=in_place, + # TODO: currently in-place is only implemented for c2c + in_place=in_place and c2c, + c2c=c2c, axes=axis, ) diff --git a/tests/test_fft.py b/tests/test_fft.py index 4091eb6a8f8d..cd2674546f5c 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -6,10 +6,20 @@ from numpy.testing import assert_raises import dpnp +from dpnp.dpnp_utils import map_dtype_to_device -from .helper import assert_dtype_allclose, get_all_dtypes, 
get_complex_dtypes +from .helper import ( + assert_dtype_allclose, + get_all_dtypes, + get_complex_dtypes, + get_float_dtypes, +) +# TODO: `assert_dtype_allclose` calls in this file have `check_only_type_kind=True` +# since stock NumPy is currently used in public CI for code coverege which +# always returns complex128/float64 for FFT functions, but Intel® NumPy and +# dpnp return complex64/float32 if input is complex64/float32 class TestFft: def setup_method(self): numpy.random.seed(42) @@ -289,13 +299,13 @@ def test_fft_error(self, xp): # 0-D input a = xp.array(3) # dpnp and Intel® NumPy return ValueError - # vanilla NumPy return IndexError + # stock NumPy returns IndexError assert_raises((ValueError, IndexError), xp.fft.fft, a) # n is not int a = xp.ones((4, 3)) if xp == dpnp: - # dpnp and vanilla NumPy return TypeError + # dpnp and stock NumPy return TypeError # Intel® NumPy returns SystemError for Python 3.10 and 3.11 # and no error for Python 3.9 assert_raises(TypeError, xp.fft.fft, a, n=5.0) @@ -321,13 +331,16 @@ def test_fft_validate_out(self): out = dpnp.empty((11,), dtype=dpnp.complex64) assert_raises(ValueError, dpnp.fft.fft, a, out=out) - # Invalid dtype + # Invalid dtype for c2c or r2c FFT a = dpnp.ones((10,), dtype=dpnp.complex64) out = dpnp.empty((10,), dtype=dpnp.float32) assert_raises(TypeError, dpnp.fft.fft, a, out=out) class TestRfft: + def setup_method(self): + numpy.random.seed(42) + @pytest.mark.parametrize( "dtype", get_all_dtypes(no_bool=True, no_complex=True) ) @@ -344,16 +357,236 @@ def test_fft_rfft(self, dtype, shape): assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) @pytest.mark.parametrize( - "func_name", - [ - "rfft", - ], + "dtype", get_all_dtypes(no_bool=True, no_complex=True) ) - def test_fft_invalid_dtype(self, func_name): - a = dpnp.array([True, False, True]) - dpnp_func = getattr(dpnp.fft, func_name) - with pytest.raises(TypeError): - dpnp_func(a) + @pytest.mark.parametrize("n", [None, 5, 20]) + 
@pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + def test_fft_1D(self, dtype, n, norm): + x = dpnp.linspace(-1, 1, 11, dtype=dtype) + a = dpnp.sin(x) + a_np = dpnp.asnumpy(a) + + result = dpnp.fft.rfft(a, n=n, norm=norm) + expected = numpy.fft.rfft(a_np, n=n, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("n", [None, 5, 20]) + @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + def test_fft_bool(self, n, norm): + a = dpnp.ones(11, dtype=dpnp.bool) + a_np = dpnp.asnumpy(a) + + result = dpnp.fft.rfft(a, n=n, norm=norm) + expected = numpy.fft.rfft(a_np, n=n, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 8]) + @pytest.mark.parametrize("axis", [-1, 1, 0]) + @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_fft_1D_on_2D_array(self, dtype, n, axis, norm, order): + a_np = numpy.arange(12, dtype=dtype).reshape(3, 4, order=order) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfft(a, n=n, axis=axis, norm=norm) + expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 8]) + @pytest.mark.parametrize("axis", [0, 1, 2]) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_fft_1D_on_3D_array(self, dtype, n, axis, norm, order): + a_np = numpy.arange(24, dtype=dtype).reshape(2, 3, 4, order=order) + a = dpnp.asarray(a_np) + + result = dpnp.fft.rfft(a, n=n, axis=axis, norm=norm) + expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("n", [None, 5, 20]) + 
def test_fft_usm_ndarray(self, n): + x = dpt.linspace(-1, 1, 11) + a_usm = dpt.asarray(dpt.sin(x)) + a_np = dpt.asnumpy(a_usm) + out_shape = a_usm.shape[0] // 2 + 1 if n is None else n // 2 + 1 + out_dtype = map_dtype_to_device(dpnp.complex128, a_usm.sycl_device) + out = dpt.empty(out_shape, dtype=out_dtype) + + result = dpnp.fft.rfft(a_usm, n=n, out=out) + assert out is result.get_array() + expected = numpy.fft.rfft(a_np, n=n) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 20]) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + def test_fft_1D_out(self, dtype, n, norm): + x = dpnp.linspace(-1, 1, 11) + a = dpnp.sin(x) + 1j * dpnp.cos(x) + a = dpnp.asarray(a, dtype=dtype) + a_np = dpnp.asnumpy(a) + + out_shape = a.shape[0] // 2 + 1 if n is None else n // 2 + 1 + out_dtype = dpnp.complex64 if dtype == dpnp.float32 else dpnp.complex128 + out = dpnp.empty(out_shape, dtype=out_dtype) + + result = dpnp.fft.rfft(a, n=n, norm=norm, out=out) + assert out is result + expected = numpy.fft.rfft(a_np, n=n, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_float_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 8]) + @pytest.mark.parametrize("axis", [-1, 0]) + @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_fft_1D_on_2D_array_out(self, dtype, n, axis, norm, order): + a_np = numpy.arange(12, dtype=dtype).reshape(3, 4, order=order) + a = dpnp.asarray(a_np) + + out_shape = list(a.shape) + out_shape[axis] = a.shape[axis] // 2 + 1 if n is None else n // 2 + 1 + out_shape = tuple(out_shape) + out_dtype = dpnp.complex64 if dtype == dpnp.float32 else dpnp.complex128 + out = dpnp.empty(out_shape, dtype=out_dtype) + + result = dpnp.fft.rfft(a, n=n, axis=axis, norm=norm, out=out) + assert out is result + 
expected = numpy.fft.rfft(a_np, n=n, axis=axis, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("xp", [numpy, dpnp]) + def test_fft_error(self, xp): + a = xp.ones((4, 3), dtype=xp.complex64) + # invalid dtype of input array for r2c FFT + assert_raises(TypeError, xp.fft.rfft, a) + + def test_fft_validate_out(self): + # Invalid shape for r2c FFT + a = dpnp.ones((10,), dtype=dpnp.float32) + out = dpnp.empty((10,), dtype=dpnp.complex64) + assert_raises(ValueError, dpnp.fft.rfft, a, out=out) + + +class TestIrfft: + def setup_method(self): + numpy.random.seed(42) + + @pytest.mark.parametrize("dtype", get_all_dtypes(no_complex=True)) + @pytest.mark.parametrize("n", [None, 5, 20]) + @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + def test_fft_1D(self, dtype, n, norm): + x = dpnp.linspace(-1, 1, 11, dtype=dtype) + a = dpnp.sin(x) + a_np = dpnp.asnumpy(a) + + result = dpnp.fft.irfft(a, n=n, norm=norm) + expected = numpy.fft.irfft(a_np, n=n, norm=norm) + # check_only_type_kind=True since Intel® NumPy always returns float64 + # but dpnp return float32 if input is float32 + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 20]) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + def test_fft_1D_complex(self, dtype, n, norm): + x = dpnp.linspace(-1, 1, 11) + a = dpnp.sin(x) + 1j * dpnp.cos(x) + a = dpnp.asarray(a, dtype=dtype) + a_np = dpnp.asnumpy(a) + + result = dpnp.fft.irfft(a, n=n, norm=norm) + expected = numpy.fft.irfft(a_np, n=n, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 8]) + @pytest.mark.parametrize("axis", [-1, 1, 0]) + @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + @pytest.mark.parametrize("order", ["C", 
"F"]) + def test_fft_1D_on_2D_array(self, dtype, n, axis, norm, order): + a_np = numpy.arange(12, dtype=dtype).reshape(3, 4, order=order) + a = dpnp.asarray(a_np) + + result = dpnp.fft.irfft(a, n=n, axis=axis, norm=norm) + expected = numpy.fft.irfft(a_np, n=n, axis=axis, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 8]) + @pytest.mark.parametrize("axis", [0, 1, 2]) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_fft_1D_on_3D_array(self, dtype, n, axis, norm, order): + x1 = numpy.random.uniform(-10, 10, 24) + x2 = numpy.random.uniform(-10, 10, 24) + a_np = numpy.array(x1 + 1j * x2, dtype=dtype).reshape( + 2, 3, 4, order=order + ) + a = dpnp.asarray(a_np) + + result = dpnp.fft.irfft(a, n=n, axis=axis, norm=norm) + expected = numpy.fft.irfft(a_np, n=n, axis=axis, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("n", [None, 5, 20]) + def test_fft_usm_ndarray(self, n): + x = dpt.linspace(-1, 1, 11) + a = dpt.sin(x) + 1j * dpt.cos(x) + a_usm = dpt.asarray(a, dtype=dpt.complex64) + a_np = dpt.asnumpy(a_usm) + out_shape = n if n is not None else 2 * (a_usm.shape[0] - 1) + out = dpt.empty(out_shape, dtype=a_usm.real.dtype) + + result = dpnp.fft.irfft(a_usm, n=n, out=out) + assert out is result.get_array() + expected = numpy.fft.irfft(a_np, n=n) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 20]) + @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) + def test_fft_1D_out(self, dtype, n, norm): + x = dpnp.linspace(-1, 1, 11) + a = dpnp.sin(x) + 1j * dpnp.cos(x) + a = dpnp.asarray(a, dtype=dtype) + a_np = dpnp.asnumpy(a) + + out_shape = n if n is not None else 2 * 
(a.shape[0] - 1) + out = dpnp.empty(out_shape, dtype=a.real.dtype) + + result = dpnp.fft.irfft(a, n=n, norm=norm, out=out) + assert out is result + expected = numpy.fft.irfft(a_np, n=n, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + @pytest.mark.parametrize("n", [None, 5, 8]) + @pytest.mark.parametrize("axis", [-1, 0]) + @pytest.mark.parametrize("norm", [None, "forward", "ortho"]) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_fft_1D_on_2D_array_out(self, dtype, n, axis, norm, order): + a_np = numpy.arange(12, dtype=dtype).reshape(3, 4, order=order) + a = dpnp.asarray(a_np) + + out_shape = list(a.shape) + out_shape[axis] = 2 * (a.shape[axis] - 1) if n is None else n + out_shape = tuple(out_shape) + out = dpnp.empty(out_shape, dtype=a.real.dtype) + + result = dpnp.fft.irfft(a, n=n, axis=axis, norm=norm, out=out) + assert out is result + expected = numpy.fft.irfft(a_np, n=n, axis=axis, norm=norm) + assert_dtype_allclose(result, expected, check_only_type_kind=True) + + def test_fft_validate_out(self): + # Invalid dtype for c2r FFT + a = dpnp.ones((10,), dtype=dpnp.complex64) + out = dpnp.empty((18,), dtype=dpnp.complex64) + assert_raises(TypeError, dpnp.fft.irfft, a, out=out) class TestFftfreq: diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 4c6a001273a0..2ce79b7aa7c9 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1218,15 +1218,15 @@ def test_out_multi_dot(device): assert_sycl_queue_equal(result.sycl_queue, exec_q) -@pytest.mark.parametrize("func", ["fft", "ifft"]) +@pytest.mark.parametrize("func", ["fft", "ifft", "rfft", "irfft"]) @pytest.mark.parametrize( "device", valid_devices, ids=[device.filter_string for device in valid_devices], ) def test_fft(func, device): - data = numpy.arange(100, dtype=numpy.complex128) - + dtype = numpy.float64 if func == "rfft" else numpy.complex128 + data = numpy.arange(100, 
dtype=dtype) dpnp_data = dpnp.array(data, device=device) expected = getattr(numpy.fft, func)(data) diff --git a/tests/test_usm_type.py b/tests/test_usm_type.py index 6cc8d5edd39d..e6effa31623d 100644 --- a/tests/test_usm_type.py +++ b/tests/test_usm_type.py @@ -933,11 +933,11 @@ def test_eigenvalue(func, shape, usm_type): assert a.usm_type == dp_val.usm_type -@pytest.mark.parametrize("func", ["fft", "ifft"]) +@pytest.mark.parametrize("func", ["fft", "ifft", "rfft", "irfft"]) @pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types) def test_fft(func, usm_type): - - dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dp.complex64) + dtype = dp.float32 if func == "rfft" else dp.complex64 + dpnp_data = dp.arange(100, usm_type=usm_type, dtype=dtype) result = getattr(dp.fft, func)(dpnp_data) assert dpnp_data.usm_type == usm_type diff --git a/tests/third_party/cupy/fft_tests/test_fft.py b/tests/third_party/cupy/fft_tests/test_fft.py index 401a7bf26e95..be8f3cc98fe5 100644 --- a/tests/third_party/cupy/fft_tests/test_fft.py +++ b/tests/third_party/cupy/fft_tests/test_fft.py @@ -1,5 +1,4 @@ import functools -import string import unittest import numpy as np @@ -230,39 +229,47 @@ def test_ifftn(self, xp, dtype): return out +@pytest.mark.usefixtures("skip_forward_backward") @testing.parameterize( *testing.product( { "n": [None, 5, 10, 15], "shape": [(10,), (10, 10)], - "norm": [None, "ortho"], + "norm": [None, "backward", "ortho", "forward", ""], } ) ) -@pytest.mark.usefixtures("allow_fall_back_on_numpy") class TestRfft: @testing.for_all_dtypes(no_complex=True) @testing.numpy_cupy_allclose( rtol=1e-4, atol=1e-7, + accept_error=ValueError, type_check=False, ) def test_rfft(self, xp, dtype): a = testing.shaped_random(self.shape, xp, dtype) out = xp.fft.rfft(a, n=self.n, norm=self.norm) + if xp is np and dtype in [np.float16, np.float32, np.complex64]: + out = out.astype(np.complex64) + return out @testing.for_all_dtypes() @testing.numpy_cupy_allclose( 
rtol=1e-4, - atol=1e-7, + atol=2e-6, + accept_error=ValueError, type_check=has_support_aspect64(), ) def test_irfft(self, xp, dtype): a = testing.shaped_random(self.shape, xp, dtype) out = xp.fft.irfft(a, n=self.n, norm=self.norm) + if xp is np and dtype in [np.float16, np.float32, np.complex64]: + out = out.astype(np.float32) + return out From 92d4070e8b0f120b9c2f1deb061cea13f69aab11 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 15 Jul 2024 14:18:30 -0500 Subject: [PATCH 2/7] update tests --- tests/test_fft.py | 5 ++++- tests/third_party/cupy/fft_tests/test_fft.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_fft.py b/tests/test_fft.py index cd2674546f5c..c4d6eb076b07 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -462,7 +462,10 @@ def test_fft_1D_on_2D_array_out(self, dtype, n, axis, norm, order): def test_fft_error(self, xp): a = xp.ones((4, 3), dtype=xp.complex64) # invalid dtype of input array for r2c FFT - assert_raises(TypeError, xp.fft.rfft, a) + if xp == dpnp: + # stock NumPy-1.26 ignores imaginary part + # Intel® NumPy, dpnp, stock NumPy-2.0 return TypeError + assert_raises(TypeError, xp.fft.rfft, a) def test_fft_validate_out(self): # Invalid shape for r2c FFT diff --git a/tests/third_party/cupy/fft_tests/test_fft.py b/tests/third_party/cupy/fft_tests/test_fft.py index be8f3cc98fe5..52a2da0a0033 100644 --- a/tests/third_party/cupy/fft_tests/test_fft.py +++ b/tests/third_party/cupy/fft_tests/test_fft.py @@ -245,7 +245,8 @@ class TestRfft: rtol=1e-4, atol=1e-7, accept_error=ValueError, - type_check=False, + contiguous_check=False, + type_check=has_support_aspect64(), ) def test_rfft(self, xp, dtype): a = testing.shaped_random(self.shape, xp, dtype) @@ -261,6 +262,7 @@ def test_rfft(self, xp, dtype): rtol=1e-4, atol=2e-6, accept_error=ValueError, + contiguous_check=False, type_check=has_support_aspect64(), ) def test_irfft(self, xp, dtype): From d8b4856780510e42fa503ab36225125abc186f1f Mon Sep 17 
00:00:00 2001 From: Vahid Tavanashad Date: Mon, 22 Jul 2024 08:56:41 -0500 Subject: [PATCH 3/7] remove redundant test --- tests/test_sycl_queue.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/test_sycl_queue.py b/tests/test_sycl_queue.py index 2ce79b7aa7c9..7b25fbc48837 100644 --- a/tests/test_sycl_queue.py +++ b/tests/test_sycl_queue.py @@ -1231,34 +1231,10 @@ def test_fft(func, device): expected = getattr(numpy.fft, func)(data) result = getattr(dpnp.fft, func)(dpnp_data) - assert_dtype_allclose(result, expected) expected_queue = dpnp_data.get_array().sycl_queue result_queue = result.get_array().sycl_queue - - assert_sycl_queue_equal(result_queue, expected_queue) - - -@pytest.mark.parametrize("type", ["float32"]) -@pytest.mark.parametrize("shape", [(8, 8)]) -@pytest.mark.parametrize( - "device", - valid_devices, - ids=[device.filter_string for device in valid_devices], -) -def test_fft_rfft(type, shape, device): - np_data = numpy.arange(64, dtype=numpy.dtype(type)).reshape(shape) - dpnp_data = dpnp.array(np_data, device=device) - - np_res = numpy.fft.rfft(np_data) - dpnp_res = dpnp.fft.rfft(dpnp_data) - - assert_dtype_allclose(dpnp_res, np_res, check_only_type_kind=True) - - expected_queue = dpnp_data.get_array().sycl_queue - result_queue = dpnp_res.get_array().sycl_queue - assert_sycl_queue_equal(result_queue, expected_queue) From d3bd05a3be2c11074e67de90370fd3cf2244a90f Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 22 Jul 2024 09:14:33 -0500 Subject: [PATCH 4/7] clean-up --- dpnp/backend/include/dpnp_iface_fptr.hpp | 1 - dpnp/backend/kernels/dpnp_krnl_fft.cpp | 139 ----------------------- 2 files changed, 140 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 79d1f18cc3c7..68120ff42ace 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -92,7 +92,6 @@ enum class DPNPFuncName : size_t 
DPNP_FN_FFT_FFT, /**< Used in numpy.fft.fft() impl */ DPNP_FN_FFT_FFT_EXT, /**< Used in numpy.fft.fft() impl, requires extra parameters */ - DPNP_FN_FFT_RFFT, /**< Used in numpy.fft.rfft() impl */ DPNP_FN_INITVAL, /**< Used in numpy ones, ones_like, zeros, zeros_like impls */ DPNP_FN_INITVAL_EXT, /**< Used in numpy ones, ones_like, zeros, zeros_like diff --git a/dpnp/backend/kernels/dpnp_krnl_fft.cpp b/dpnp/backend/kernels/dpnp_krnl_fft.cpp index ff4d3873c881..3d81dde39250 100644 --- a/dpnp/backend/kernels/dpnp_krnl_fft.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_fft.cpp @@ -571,133 +571,6 @@ DPCTLSyclEventRef (*dpnp_fft_fft_ext_c)(DPCTLSyclQueueRef, const DPCTLEventVectorRef) = dpnp_fft_fft_c<_DataType_input, _DataType_output>; -template -DPCTLSyclEventRef dpnp_fft_rfft_c(DPCTLSyclQueueRef q_ref, - const void *array1_in, - void *result_out, - const shape_elem_type *input_shape, - const shape_elem_type *result_shape, - size_t shape_size, - long, // axis - long, // input_boundary - size_t inverse, - const size_t norm, - const DPCTLEventVectorRef dep_event_vec_ref) -{ - static_assert(is_complex<_DataType_output>::value, - "Output data type must be a complex type."); - DPCTLSyclEventRef event_ref = nullptr; - - if (!shape_size || !array1_in || !result_out) { - return event_ref; - } - - const size_t result_size = - std::accumulate(result_shape, result_shape + shape_size, 1, - std::multiplies()); - const size_t input_size = - std::accumulate(input_shape, input_shape + shape_size, 1, - std::multiplies()); - - if constexpr (std::is_same<_DataType_output, std::complex>::value || - std::is_same<_DataType_output, std::complex>::value) - { - if constexpr (std::is_same<_DataType_input, double>::value && - std::is_same<_DataType_output, - std::complex>::value) - { - event_ref = - dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, double, - desc_dp_real_t>( - q_ref, array1_in, result_out, input_shape, result_shape, - shape_size, input_size, result_size, inverse, norm, 1); - } - 
/* real-to-complex, single precision */ - else if constexpr (std::is_same<_DataType_input, float>::value && - std::is_same<_DataType_output, - std::complex>::value) - { - event_ref = - dpnp_fft_fft_mathlib_real_to_cmplx_c<_DataType_input, float, - desc_sp_real_t>( - q_ref, array1_in, result_out, input_shape, result_shape, - shape_size, input_size, result_size, inverse, norm, 1); - } - else if constexpr (std::is_same<_DataType_input, int32_t>::value || - std::is_same<_DataType_input, int64_t>::value) - { - using CastType = typename _DataType_output::value_type; - - CastType *array1_copy = reinterpret_cast( - dpnp_memory_alloc_c(q_ref, input_size * sizeof(CastType))); - - shape_elem_type *copy_strides = reinterpret_cast( - dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type))); - *copy_strides = 1; - shape_elem_type *copy_shape = reinterpret_cast( - dpnp_memory_alloc_c(q_ref, sizeof(shape_elem_type))); - *copy_shape = input_size; - shape_elem_type copy_shape_size = 1; - event_ref = dpnp_copyto_c<_DataType_input, CastType>( - q_ref, array1_copy, input_size, copy_shape_size, copy_shape, - copy_strides, array1_in, input_size, copy_shape_size, - copy_shape, copy_strides, NULL, dep_event_vec_ref); - DPCTLEvent_WaitAndThrow(event_ref); - DPCTLEvent_Delete(event_ref); - - event_ref = dpnp_fft_fft_mathlib_real_to_cmplx_c< - CastType, CastType, - std::conditional_t::value, - desc_dp_real_t, desc_sp_real_t>>( - q_ref, array1_copy, result_out, input_shape, result_shape, - shape_size, input_size, result_size, inverse, norm, 1); - - DPCTLEvent_WaitAndThrow(event_ref); - DPCTLEvent_Delete(event_ref); - event_ref = nullptr; - - dpnp_memory_free_c(q_ref, array1_copy); - dpnp_memory_free_c(q_ref, copy_strides); - dpnp_memory_free_c(q_ref, copy_shape); - } - } - - return event_ref; -} - -template -void dpnp_fft_rfft_c(const void *array1_in, - void *result1, - const shape_elem_type *input_shape, - const shape_elem_type *output_shape, - size_t shape_size, - long axis, - long 
input_boundarie, - size_t inverse, - const size_t norm) -{ - DPCTLSyclQueueRef q_ref = reinterpret_cast(&DPNP_QUEUE); - DPCTLEventVectorRef dep_event_vec_ref = nullptr; - DPCTLSyclEventRef event_ref = - dpnp_fft_rfft_c<_DataType_input, _DataType_output>( - q_ref, array1_in, result1, input_shape, output_shape, shape_size, - axis, input_boundarie, inverse, norm, dep_event_vec_ref); - DPCTLEvent_WaitAndThrow(event_ref); - DPCTLEvent_Delete(event_ref); -} - -template -void (*dpnp_fft_rfft_default_c)(const void *, - void *, - const shape_elem_type *, - const shape_elem_type *, - size_t, - long, - long, - size_t, - const size_t) = - dpnp_fft_rfft_c<_DataType_input, _DataType_output>; - void func_map_init_fft_func(func_map_t &fmap) { fmap[DPNPFuncName::DPNP_FN_FFT_FFT][eft_INT][eft_INT] = { @@ -736,17 +609,5 @@ void func_map_init_fft_func(func_map_t &fmap) eft_C128, (void *)dpnp_fft_fft_ext_c, std::complex>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT][eft_INT][eft_INT] = { - eft_C128, - (void *)dpnp_fft_rfft_default_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT][eft_LNG][eft_LNG] = { - eft_C128, - (void *)dpnp_fft_rfft_default_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT][eft_FLT][eft_FLT] = { - eft_C64, (void *)dpnp_fft_rfft_default_c>}; - fmap[DPNPFuncName::DPNP_FN_FFT_RFFT][eft_DBL][eft_DBL] = { - eft_C128, - (void *)dpnp_fft_rfft_default_c>}; - return; } From b72de0074d22d9ca31887981093b7de03279eaaf Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 22 Jul 2024 11:06:11 -0500 Subject: [PATCH 5/7] separate backend structures --- dpnp/backend/extensions/fft/fft_utils.hpp | 27 ++++++++++++++++++----- 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/dpnp/backend/extensions/fft/fft_utils.hpp b/dpnp/backend/extensions/fft/fft_utils.hpp index cb25eb4ac949..d32b619a3510 100644 --- a/dpnp/backend/extensions/fft/fft_utils.hpp +++ b/dpnp/backend/extensions/fft/fft_utils.hpp @@ -31,6 +31,24 @@ namespace dpnp::extensions::fft { namespace mkl_dft = oneapi::mkl::dft; +// 
Structure to map MKL precision to float/double types +template +struct PrecisionType; + +template <> +struct PrecisionType +{ + using type = float; +}; + +template <> +struct PrecisionType +{ + using type = double; +}; + +// Structure to map combination of precision, domain, and is_forward flag to +// in/out types template struct ScaleType { @@ -43,8 +61,7 @@ struct ScaleType template struct ScaleType { - using prec_type = typename std:: - conditional::type; + using prec_type = typename PrecisionType::type; using type_in = prec_type; using type_out = std::complex; }; @@ -54,8 +71,7 @@ struct ScaleType template struct ScaleType { - using prec_type = typename std:: - conditional::type; + using prec_type = typename PrecisionType::type; using type_in = std::complex; using type_out = prec_type; }; @@ -65,8 +81,7 @@ struct ScaleType template struct ScaleType { - using prec_type = typename std:: - conditional::type; + using prec_type = typename PrecisionType::type; using type_in = std::complex; using type_out = std::complex; }; From c0630a23865618ba737419e87a7e284d02a4a37b Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 22 Jul 2024 10:54:57 -0500 Subject: [PATCH 6/7] address comments --- dpnp/fft/dpnp_iface_fft.py | 13 +++++++------ dpnp/fft/dpnp_utils_fft.py | 34 ++++++++++++++++++++-------------- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/dpnp/fft/dpnp_iface_fft.py b/dpnp/fft/dpnp_iface_fft.py index 1f4bf0bdf0ea..c117e2b19770 100644 --- a/dpnp/fft/dpnp_iface_fft.py +++ b/dpnp/fft/dpnp_iface_fft.py @@ -166,7 +166,7 @@ def fft(a, n=None, axis=-1, norm=None, out=None): dpnp.check_supported_arrays_type(a) return dpnp_fft( - a, forward=True, c2c=True, n=n, axis=axis, norm=norm, out=out + a, forward=True, real=False, n=n, axis=axis, norm=norm, out=out ) @@ -541,7 +541,7 @@ def ifft(a, n=None, axis=-1, norm=None, out=None): dpnp.check_supported_arrays_type(a) return dpnp_fft( - a, forward=False, c2c=True, n=n, axis=axis, norm=norm, out=out + a, 
forward=False, real=False, n=n, axis=axis, norm=norm, out=out ) @@ -848,7 +848,7 @@ def irfft(a, n=None, axis=-1, norm=None, out=None): dpnp.check_supported_arrays_type(a) return dpnp_fft( - a, forward=False, c2c=False, n=n, axis=axis, norm=norm, out=out + a, forward=False, real=True, n=n, axis=axis, norm=norm, out=out ) @@ -1002,8 +1002,9 @@ def rfft(a, n=None, axis=-1, norm=None, out=None): compute the negative frequency terms, and the length of the transformed axis of the output is therefore ``n//2 + 1``. - When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains - the zero-frequency term 0*fs, which is real due to Hermitian symmetry. + When ``A = dpnp.fft.rfft(a)`` and fs is the sampling frequency, ``A[0]`` + contains the zero-frequency term 0*fs, which is real due to Hermitian + symmetry. If `n` is even, ``A[-1]`` contains the term representing both positive and negative Nyquist frequency (+fs/2 and -fs/2), and must also be purely @@ -1029,7 +1030,7 @@ def rfft(a, n=None, axis=-1, norm=None, out=None): dpnp.check_supported_arrays_type(a) return dpnp_fft( - a, forward=True, c2c=False, n=n, axis=axis, norm=norm, out=out + a, forward=True, real=True, n=n, axis=axis, norm=norm, out=out ) diff --git a/dpnp/fft/dpnp_utils_fft.py b/dpnp/fft/dpnp_utils_fft.py index 22696e30bd56..01655731263b 100644 --- a/dpnp/fft/dpnp_utils_fft.py +++ b/dpnp/fft/dpnp_utils_fft.py @@ -106,7 +106,7 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides): if dsc.transform_in_place: # in-place transform # TODO: investigate the performance of in-place implementation - # for r2c/c2r + # for r2c/c2r, see SAT-7154 ht_fft_event, fft_event = fi._fft_in_place( dsc, a_usm, forward, depends=dep_evs ) @@ -128,7 +128,7 @@ def _compute_result(dsc, a, out, forward, c2c, a_strides): out_dtype = a.dtype else: if forward: # r2c FFT - tmp = numpy.floor_divide(a.shape[-1], 2) + 1 + tmp = a.shape[-1] // 2 + 1 out_shape = a.shape[:-1] + (tmp,) out_dtype = ( dpnp.complex64 @@ -178,7 
+178,8 @@ def _copy_array(x, complex_input): else: dtype = map_dtype_to_device(dpnp.complex128, x.sycl_device) elif not complex_input and dtype not in [dpnp.float32, dpnp.float64]: - # r2c FFT, if input is not float dtype, convert to float + # r2c FFT, if input is integer or float16 dtype, convert to + # float32 or float64 depending on device capabilities copy_flag = True dtype = map_dtype_to_device(dpnp.float64, x.sycl_device) else: @@ -198,9 +199,9 @@ def _copy_array(x, complex_input): depends=dep_evs, ) _manager.add_event_pair(ht_copy_ev, copy_ev) + x = x_copy - # if copying is done, FFT can be in-place (copy_flag = in_place flag) - return x_copy, copy_flag + # if copying is done, FFT can be in-place (copy_flag = in_place flag) return x, copy_flag @@ -306,12 +307,16 @@ def _validate_out_keyword(a, out, axis, c2r, r2c): ) # validate out shape + expected_shape = a.shape if r2c: - if out.shape[axis] != (a.shape[axis] // 2 + 1): - raise ValueError("output array has incorrect shape.") - else: # c2c/c2r FFT, for c2r input is already zero-padded - if out.shape != a.shape: - raise ValueError("output array has incorrect shape.") + expected_shape = list(a.shape) + expected_shape[axis] = a.shape[axis] // 2 + 1 + expected_shape = tuple(expected_shape) + if out.shape != expected_shape: + raise ValueError( + "output array has incorrect shape, expected " + f"{expected_shape}, got {out.shape}." 
+ ) # validate out data type if c2r: @@ -324,15 +329,16 @@ def _validate_out_keyword(a, out, axis, c2r, r2c): raise TypeError("output array should have complex data type.") -def dpnp_fft(a, forward, c2c, n=None, axis=-1, norm=None, out=None): +def dpnp_fft(a, forward, real, n=None, axis=-1, norm=None, out=None): """Calculates 1-D FFT of the input array along axis""" a_ndim = a.ndim if a_ndim == 0: raise ValueError("Input array must be at least 1D") - r2c = not c2c and forward - c2r = not c2c and not forward + c2c = not real # complex-to-complex FFT + r2c = real and forward # real-to-complex FFT + c2r = real and not forward # complex-to-real FFT if r2c and dpnp.issubdtype(a.dtype, dpnp.complexfloating): raise TypeError("Input array must be real") @@ -367,7 +373,7 @@ def dpnp_fft(a, forward, c2c, n=None, axis=-1, norm=None, out=None): norm=norm, out=out, forward=forward, - # TODO: currently in-place is only implemented for c2c + # TODO: currently in-place is only implemented for c2c, see SAT-7154 in_place=in_place and c2c, c2c=c2c, axes=axis, From 6c17c09ce65167c0a93a79947520d297955eeddd Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 26 Jul 2024 16:16:04 -0500 Subject: [PATCH 7/7] mute tests on gpu with fp64 support --- tests/test_fft.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_fft.py b/tests/test_fft.py index c4d6eb076b07..1f7b389ed75c 100644 --- a/tests/test_fft.py +++ b/tests/test_fft.py @@ -13,8 +13,14 @@ get_all_dtypes, get_complex_dtypes, get_float_dtypes, + is_cpu_device, ) +# aspects of default device: +_def_device = dpctl.SyclQueue().sycl_device +_def_dev_has_fp64 = _def_device.has_aspect_fp64 +is_gpu_with_fp64 = not is_cpu_device() and _def_dev_has_fp64 + # TODO: `assert_dtype_allclose` calls in this file have `check_only_type_kind=True` # since stock NumPy is currently used in public CI for code coverege which @@ -492,6 +498,7 @@ def test_fft_1D(self, dtype, n, norm): # but dpnp return float32 if input is float32 
assert_dtype_allclose(result, expected, check_only_type_kind=True) + @pytest.mark.skipif(is_gpu_with_fp64, reason="MKLD17702") @pytest.mark.parametrize("dtype", get_complex_dtypes()) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"]) @@ -518,6 +525,7 @@ def test_fft_1D_on_2D_array(self, dtype, n, axis, norm, order): expected = numpy.fft.irfft(a_np, n=n, axis=axis, norm=norm) assert_dtype_allclose(result, expected, check_only_type_kind=True) + @pytest.mark.skipif(is_gpu_with_fp64, reason="MKLD17702") @pytest.mark.parametrize("dtype", get_complex_dtypes()) @pytest.mark.parametrize("n", [None, 5, 8]) @pytest.mark.parametrize("axis", [0, 1, 2]) @@ -535,6 +543,7 @@ def test_fft_1D_on_3D_array(self, dtype, n, axis, norm, order): expected = numpy.fft.irfft(a_np, n=n, axis=axis, norm=norm) assert_dtype_allclose(result, expected, check_only_type_kind=True) + @pytest.mark.skipif(is_gpu_with_fp64, reason="MKLD17702") @pytest.mark.parametrize("n", [None, 5, 20]) def test_fft_usm_ndarray(self, n): x = dpt.linspace(-1, 1, 11) @@ -549,6 +558,7 @@ def test_fft_usm_ndarray(self, n): expected = numpy.fft.irfft(a_np, n=n) assert_dtype_allclose(result, expected, check_only_type_kind=True) + @pytest.mark.skipif(is_gpu_with_fp64, reason="MKLD17702") @pytest.mark.parametrize("dtype", get_complex_dtypes()) @pytest.mark.parametrize("n", [None, 5, 20]) @pytest.mark.parametrize("norm", ["forward", "backward", "ortho"])