diff --git a/Project.toml b/Project.toml index 2004a4f..e9e4d61 100644 --- a/Project.toml +++ b/Project.toml @@ -1,18 +1,23 @@ name = "FastChebInterp" uuid = "cf66c380-9a80-432c-aff8-4f9c79c0bdde" -version = "1.0" +version = "1.1" [deps] FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" [compat] +ChainRulesCore = "1" +ChainRulesTestUtils = "1" FFTW = "1.0" StaticArrays = "0.12, 1.0" julia = "1.3" [extras] +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["ChainRulesTestUtils", "Random", "Test"] diff --git a/src/FastChebInterp.jl b/src/FastChebInterp.jl index 74f1805..cf8672d 100644 --- a/src/FastChebInterp.jl +++ b/src/FastChebInterp.jl @@ -34,7 +34,7 @@ A multidimensional Chebyshev-polynomial interpolation object. Given a `c::ChebPoly`, you can evaluate it at a point `x` with `c(x)`, where `x` is a vector (or a scalar if `c` is 1d). """ -struct ChebPoly{N,T,Td<:Real} +struct ChebPoly{N,T,Td<:Real} <: Function coefs::Array{T,N} # chebyshev coefficients lb::SVector{N,Td} # lower/upper bounds ub::SVector{N,Td} # of the domain @@ -48,9 +48,11 @@ function Base.show(io::IO, c::ChebPoly) end end Base.ndims(c::ChebPoly) = ndims(c.coefs) +Base.zero(c::ChebPoly{N,T,Td}) where {N,T,Td} = ChebPoly{N,T,Td}(zero(c.coefs), c.lb, c.ub) include("interp.jl") include("regression.jl") include("eval.jl") +include("chainrules.jl") end # module diff --git a/src/chainrules.jl b/src/chainrules.jl new file mode 100644 index 0000000..608d1b6 --- /dev/null +++ b/src/chainrules.jl @@ -0,0 +1,39 @@ +import ChainRulesCore +using ChainRulesCore: ProjectTo, NoTangent, @not_implemented + +function ChainRulesCore.rrule(c::ChebPoly{1}, x::Real) + project_x = ProjectTo(x) + y, ∇y = chebgradient(c, x) + chebpoly_pullback(∂y) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(real(∇y' * ∂y)) + y, chebpoly_pullback +end + +function ChainRulesCore.rrule(c::ChebPoly, x::AbstractVector{<:Real}) + project_x = ProjectTo(x) + y, J = chebjacobian(c, x) + chebpoly_pullback(Δy) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(vec(real(J' * Δy))) + y, chebpoly_pullback +end + +ChainRulesCore.frule((Δself, Δx), c::ChebPoly{1}, x::Real) = + ChainRulesCore.frule((Δself, SVector{1}(Δx)), c, SVector{1}(x)) + +function ChainRulesCore.frule((Δself, Δx), c::ChebPoly, x::AbstractVector) + y, J = chebjacobian(c, x) + if Δself isa ChainRulesCore.AbstractZero # Δself == 0 + Δy = J * Δx + return y, y isa Number ? Δy[1] : Δy + else # need derivatives with respect to changes in c + # additional Δx from changes in bound: + # --- recall x0 = @. (x - c.lb) * 2 / (c.ub - c.lb) - 1, + # but note that J already includes 2 / (c.ub - c.lb) + d2 = @. (x - c.lb) / (c.ub - c.lb) + Δx′ = @. Δx + (d2 - 1) * Δself.lb - d2 * Δself.ub + Δy = J * Δx′ + + # dependence on coefs is linear + Δcoefs = typeof(c)(Δself.coefs, c.lb, c.ub) + + return y, (y isa Number ? Δy[1] : Δy) + Δcoefs(x) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 9793e64..a40713c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,10 @@ -using Test, FastChebInterp, StaticArrays +using Test, FastChebInterp, StaticArrays, Random, ChainRulesTestUtils # similar to ≈, but acts elementwise on tuples ≈′(a::Tuple, b::Tuple; kws...) where {N} = length(a) == length(b) && all(xy -> isapprox(xy[1],xy[2]; kws...), zip(a,b)) +Random.seed!(314159) # make chainrules tests deterministic + @testset "1d test" begin lb,ub = -0.3, 0.9 f(x) = exp(x) / (1 + 2x^2) @@ -14,6 +16,8 @@ using Test, FastChebInterp, StaticArrays x1 = 0.2 @test interp(x1) ≈ f(x1) @test chebgradient(interp, x1) ≈′ (f(x1), f′(x1)) + test_frule(interp, x1) + test_rrule(interp, x1) end @testset "2d test" begin @@ -29,6 +33,8 @@ end @test interp(x1) ≈ interp0(x1) rtol=1e-15 @test all(n -> n[1] < n[2], zip(size(interp.coefs), size(interp0.coefs))) @test chebgradient(interp, x1) ≈′ (f(x1), ∇f(x1)) + test_frule(interp, x1) + test_rrule(interp, x1) # univariate function in 2d should automatically drop down to univariate polynomial f1(x) = exp(x[1]) / (1 + 2x[1]^2) @@ -42,6 +48,8 @@ end interp2 = chebinterp(f2.(x), lb, ub) @test interp2(x1) ≈ f2(x1) @test chebjacobian(interp2, x1) ≈′ (f2(x1), ∇f2(x1)) + test_frule(interp2, x1) + test_rrule(interp2, x1) # chebinterp_v1 av1 = Array{ComplexF64}(undef, 2, size(x)...)