Because of how irrational numbers are used, the adjoints of some functions mistakenly promote the numerical precision of the derivative/gradient. This happens because certain implementations first operate on the irrational number, which by default usually converts it to Float64. E.g. for erfc we first call sqrt(π), which returns a Float64. Instead of the Irrational being promoted to the expected output type, the output type ends up promoted to Float64 (if we're using floats with lower precision):
julia> using SpecialFunctions, ChainRulesCore
julia> y, ȳ = ChainRulesCore.frule((ChainRulesCore.NO_FIELDS, 1f0), SpecialFunctions.erfc, 1f0)
(0.1572992f0, -0.41510750774498784)
julia> typeof(y), typeof(ȳ)
(Float32, Float64)
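To make the mechanism concrete, here is a minimal sketch in plain Julia (Base only) of the derivative of erfc, -2 exp(-x^2) / sqrt(π), written two ways. The function names `erfc_deriv_naive` and `erfc_deriv_typed` are hypothetical and are not part of SpecialFunctions.jl or ChainRulesCore.jl. Converting π to the input's float type before calling sqrt is just one possible way to avoid the promotion, not necessarily how a fix should be implemented.

```julia
# Hypothetical helpers for illustration only; not an API of any package.

# Naive form: sqrt(π) is evaluated first and returns a Float64,
# so dividing a Float32 by it promotes the result to Float64.
erfc_deriv_naive(x) = -2 * exp(-x^2) / sqrt(π)

# Typed form: convert π to the float type of x before taking the square root,
# so every intermediate value stays in that type.
function erfc_deriv_typed(x::T) where {T<:AbstractFloat}
    return -2 * exp(-x^2) / sqrt(T(π))
end

typeof(erfc_deriv_naive(1f0))  # Float64
typeof(erfc_deriv_typed(1f0))  # Float32
```

The typed form only covers AbstractFloat inputs; other number types would need their own handling.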
This is essentially the same issue as in DiffRules (JuliaDiff/DiffRules.jl#55).
Anyone got a better idea on what to do here, or should I just make a similar PR to SpecialFunctions.jl?