Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
name = "FastChebInterp"
uuid = "cf66c380-9a80-432c-aff8-4f9c79c0bdde"
version = "1.0"
version = "1.1"

[deps]
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

[compat]
ChainRulesCore = "1"
ChainRulesTestUtils = "1"
FFTW = "1.0"
StaticArrays = "0.12, 1.0"
julia = "1.3"

[extras]
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["ChainRulesTestUtils", "Random", "Test"]
4 changes: 3 additions & 1 deletion src/FastChebInterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ A multidimensional Chebyshev-polynomial interpolation object.
Given a `c::ChebPoly`, you can evaluate it at a point `x`
with `c(x)`, where `x` is a vector (or a scalar if `c` is 1d).
"""
struct ChebPoly{N,T,Td<:Real}
struct ChebPoly{N,T,Td<:Real} <: Function
coefs::Array{T,N} # chebyshev coefficients
lb::SVector{N,Td} # lower/upper bounds
ub::SVector{N,Td} # of the domain
Expand All @@ -48,9 +48,11 @@ function Base.show(io::IO, c::ChebPoly)
end
end
Base.ndims(c::ChebPoly) = ndims(c.coefs)
Base.zero(c::ChebPoly{N,T,Td}) where {N,T,Td} = ChebPoly{N,T,Td}(zero(c.coefs), c.lb, c.ub)

include("interp.jl")
include("regression.jl")
include("eval.jl")
include("chainrules.jl")

end # module
39 changes: 39 additions & 0 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import ChainRulesCore
using ChainRulesCore: ProjectTo, NoTangent, @not_implemented

function ChainRulesCore.rrule(c::ChebPoly{1}, x::Real)
project_x = ProjectTo(x)
y, ∇y = chebgradient(c, x)
chebpoly_pullback(∂y) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(real(∇y' * ∂y))
y, chebpoly_pullback
end

function ChainRulesCore.rrule(c::ChebPoly, x::AbstractVector{<:Real})
project_x = ProjectTo(x)
y, J = chebjacobian(c, x)
chebpoly_pullback(Δy) = @not_implemented("no rrule for changes in ChebPoly itself"), project_x(vec(real(J' * Δy)))
y, chebpoly_pullback
end

ChainRulesCore.frule((Δself, Δx), c::ChebPoly{1}, x::Real) =
ChainRulesCore.frule((Δself, SVector{1}(Δx)), c, SVector{1}(x))

function ChainRulesCore.frule((Δself, Δx), c::ChebPoly, x::AbstractVector)
y, J = chebjacobian(c, x)
if Δself isa ChainRulesCore.AbstractZero # Δself == 0
Δy = J * Δx
return y, y isa Number ? Δy[1] : Δy
else # need derivatives with respect to changes in c
# additional Δx from changes in bound:
# --- recall x0 = @. (x - c.lb) * 2 / (c.ub - c.lb) - 1,
# but note that J already includes 2 / (c.ub - c.lb)
d2 = @. (x - c.lb) / (c.ub - c.lb)
Δx′ = @. Δx + (d2 - 1) * Δself.lb - d2 * Δself.ub
Δy = J * Δx′

# dependence on coefs is linear
Δcoefs = typeof(c)(Δself.coefs, c.lb, c.ub)

return y, (y isa Number ? Δy[1] : Δy) + Δcoefs(x)
end
end
10 changes: 9 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
using Test, FastChebInterp, StaticArrays
using Test, FastChebInterp, StaticArrays, Random, ChainRulesTestUtils

# similar to ≈, but acts elementwise on tuples
≈′(a::Tuple, b::Tuple; kws...) where {N} = length(a) == length(b) && all(xy -> isapprox(xy[1],xy[2]; kws...), zip(a,b))

Random.seed!(314159) # make chainrules tests deterministic

@testset "1d test" begin
lb,ub = -0.3, 0.9
f(x) = exp(x) / (1 + 2x^2)
Expand All @@ -14,6 +16,8 @@ using Test, FastChebInterp, StaticArrays
x1 = 0.2
@test interp(x1) ≈ f(x1)
@test chebgradient(interp, x1) ≈′ (f(x1), f′(x1))
test_frule(interp, x1)
test_rrule(interp, x1)
end

@testset "2d test" begin
Expand All @@ -29,6 +33,8 @@ end
@test interp(x1) ≈ interp0(x1) rtol=1e-15
@test all(n -> n[1] < n[2], zip(size(interp.coefs), size(interp0.coefs)))
@test chebgradient(interp, x1) ≈′ (f(x1), ∇f(x1))
test_frule(interp, x1)
test_rrule(interp, x1)

# univariate function in 2d should automatically drop down to univariate polynomial
f1(x) = exp(x[1]) / (1 + 2x[1]^2)
Expand All @@ -42,6 +48,8 @@ end
interp2 = chebinterp(f2.(x), lb, ub)
@test interp2(x1) ≈ f2(x1)
@test chebjacobian(interp2, x1) ≈′ (f2(x1), ∇f2(x1))
test_frule(interp2, x1)
test_rrule(interp2, x1)

# chebinterp_v1
av1 = Array{ComplexF64}(undef, 2, size(x)...)
Expand Down