Inferring Any on gradient w.r.t. wrapper of recursive type #1514

Open
MilesCranmer opened this issue Jul 1, 2024 · 0 comments

I'm trying to fix type-inference issues within a Zygote gradient for DynamicExpressions.jl, with the goal of using fast AD in SymbolicRegression.jl. Right now `Zygote.gradient` infers `Any` as its return type, and I can't figure out why. The weird thing is that inference works fine on the internal functions (which have a custom chain rule); only the outermost wrapper function fails to infer.

For reference, you can get the exact version of the package I'm debugging with the following command. I tried to produce a smaller MWE but couldn't reduce it further than what I show below.

```
]add https://github.com/SymbolicML/DynamicExpressions.jl#c9eaedf63e36a227db702b4ea3257938892447d5
```

Basically, I have a recursive binary tree structure `Node{T}` (docs). I don't want Zygote to walk through the whole tree and turn it into a tuple, which would be hugely inefficient, so instead I have this custom `NodeTangent` type (in `src/ChainRules.jl`):

```julia
using ChainRulesCore: AbstractTangent, ZeroTangent

struct NodeTangent{T,N<:AbstractExpressionNode{T},A<:AbstractArray{T}} <: AbstractTangent
    tree::N
    gradient::A
end
Base.:+(a::NodeTangent, b::NodeTangent) = NodeTangent(a.tree, a.gradient + b.gradient)
Base.:*(a::Number, b::NodeTangent) = NodeTangent(b.tree, a * b.gradient)
Base.:*(a::NodeTangent, b::Number) = NodeTangent(a.tree, a.gradient * b)
Base.zero(::Union{Type{NodeTangent},NodeTangent}) = ZeroTangent()
```

I then have a chain rule for evaluation that returns this `NodeTangent`, defined as follows:

```julia
import ChainRulesCore as CRC
using ChainRulesCore: NoTangent

function CRC.rrule(
    ::typeof(eval_tree_array),
    tree::AbstractExpressionNode,
    X::AbstractMatrix,
    operators::OperatorEnum;
    kws...,
)
    primal, complete = eval_tree_array(tree, X, operators; kws...)

    if !complete
        primal .= NaN
    end

    return (primal, complete), EvalPullback(tree, X, operators)
end

# Wrap in a struct rather than a closure to avoid boxing the captured variables
struct EvalPullback{N,A,O} <: Function
    tree::N
    X::A
    operators::O
end

# TODO: Preferable to use the primal in the pullback somehow
function (e::EvalPullback)((dY, _))
    _, dX_constants_dY, complete = eval_grad_tree_array(
        e.tree, e.X, e.operators; variable=Val(:both)
    )

    if !complete
        dX_constants_dY .= NaN
    end

    nfeatures = size(e.X, 1)
    dX_dY = @view dX_constants_dY[1:nfeatures, :]
    dconstants_dY = @view dX_constants_dY[(nfeatures + 1):end, :]

    dtree = NodeTangent(
        e.tree, sum(j -> dconstants_dY[:, j] * dY[j], eachindex(dY, axes(dconstants_dY, 2)))
    )

    dX = dX_dY .* reshape(dY, 1, length(dY))

    return (NoTangent(), dtree, dX, NoTangent())
end
```

This actually works fine. The derivatives are correct, and inference looks good:

```julia
using DynamicExpressions
using Zygote
using Test

const operators = OperatorEnum(;
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)

tree = x1 * cos(x2 - 3.2)

julia> Test.@inferred Zygote.gradient(
           t -> eval_tree_array(t, ones(2, 1), operators)[1][1],
           tree
       )
(NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]),)
```

It returns a `NodeTangent`, which prevents Zygote from walking the tree.

However, when I then try to use my new `Expression` type, which is nothing but a `Node{T}` plus a named tuple of operators and variable names:

```julia
struct Expression{T,N<:AbstractExpressionNode{T},D<:NamedTuple} <: AbstractExpression{T,N}
    tree::N
    metadata::Metadata{D}
end
```

I no longer get successful inference. Here is the wrapper method for evaluation:

```julia
function eval_tree_array(
    ex::AbstractExpression,
    cX::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
)
    return eval_tree_array(get_tree(ex), cX, get_operators(ex, operators); kws...)
end
```

So it basically just unpacks `ex -> ex.tree` and `ex -> ex.metadata.operators`.

Now, say I take the gradient with respect to `ex` instead. Unlike the internal `eval_tree_array` call, I do not define a custom chain rule for this one (since the wrapper is so simple).

```julia
julia> ex = Expression(tree; operators, variable_names=["x1", "x2"])
x1 * cos(x2 - 3.2)

julia> Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
((tree = NodeTangent{Float64, Node{Float64}, Vector{Float64}}(x1 * cos(x2 - 3.2), [-0.8084964038195901]), metadata = nothing),)

julia> Test.@inferred Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
ERROR: return type Tuple{@NamedTuple{tree::NodeTangent{Float64, Node{Float64}, Vector{Float64}}, metadata::Nothing}} does not match inferred return type Tuple{Any}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] top-level scope
   @ REPL[25]:1
```

Even though all `eval_tree_array(::AbstractExpression, ...)` does is a few `getproperty` calls before passing the results to a call that I know works, inference on this wrapper fails.
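To narrow down where `Any` first appears, one option is to check the forward pass and the pullback separately instead of the whole `Zygote.gradient` call. The sketch below reuses the setup above; the helper name `f` is hypothetical, and it relies only on `Zygote.pullback`, `Base.return_types`, and `@code_warntype`, which print inferred types rather than throwing like `Test.@inferred`.

```julia
using DynamicExpressions
using Zygote
using InteractiveUtils: @code_warntype

const operators = OperatorEnum(;
    binary_operators=[+, -, *],
    unary_operators=[cos],
)
x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)
ex = Expression(x1 * cos(x2 - 3.2); operators, variable_names=["x1", "x2"])

# Hypothetical helper: the same loss as in the REPL session above
f(ex) = eval_tree_array(ex, ones(2, 1))[1][1]

# 1. Forward pass: inferred type of `(primal, pullback)`
@show Base.return_types(Zygote.pullback, (typeof(f), typeof(ex)))

# 2. Reverse pass only: inferred type of the gradient tuple
y, back = Zygote.pullback(f, ex)
@show Base.return_types(back, (Float64,))

# 3. If step 2 reports `Any`, inspect the pullback body in detail
#    (Cthulhu's `@descend back(1.0)` can then walk into individual calls)
@code_warntype back(1.0)
```

If step 1 is concrete but step 2 is not, the problem sits in the reverse pass (for example, in how the `NodeTangent` is accumulated into the wrapper's structural gradient) rather than in the primal.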

Questions:

1. Any guesses as to where the issue comes from?
2. Do I need to define a custom tangent type for `Expression`, and what is the actual interface for `AbstractTangent`?
   - For the record, I did try defining `zero_tangent(::Expression)`, but this didn't seem to fix the issue. Maybe there's some other function I need to define?
   - Perhaps I need to declare `NodeTangent` for some other function symbols so that Zygote doesn't try descending at the outermost call? If so, which methods need to be implemented?
3. (General) How does one go about debugging type-inference issues in Zygote when inference on the primal is fine? I can't seem to use Cthulhu.jl effectively, though perhaps I am descending the wrong tree.
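Regarding question 2, one possible workaround is to give the wrapper its own `rrule` that forwards to the existing node-level rule, so Zygote never differentiates through the field accesses at all. This is an untested sketch, not an existing DynamicExpressions API: it assumes `get_tree`, `get_operators`, `AbstractExpression`, and `AbstractOperatorEnum` can be imported from DynamicExpressions at the commit above, and that Zygote converts the returned `ChainRulesCore.Tangent` into the same `(tree = ..., metadata = nothing)` NamedTuple shown earlier. Whether it actually restores inference is not verified here.

```julia
using DynamicExpressions
using DynamicExpressions: AbstractExpression, AbstractOperatorEnum, get_tree, get_operators
using Zygote
import ChainRulesCore as CRC

# Hypothetical workaround (not part of DynamicExpressions): reuse the node-level rrule
# so Zygote never traces the `get_tree`/`get_operators` field accesses.
function CRC.rrule(
    ::typeof(eval_tree_array),
    ex::AbstractExpression,
    X::AbstractMatrix,
    operators::Union{AbstractOperatorEnum,Nothing};
    kws...,
)
    result, tree_pullback = CRC.rrule(
        eval_tree_array, get_tree(ex), X, get_operators(ex, operators); kws...
    )
    function expression_pullback(Δ)
        _, dtree, dX, _ = tree_pullback(Δ)
        # Structural tangent for the wrapper: only the `tree` field carries a gradient
        dex = CRC.Tangent{typeof(ex)}(; tree=dtree, metadata=CRC.NoTangent())
        return (CRC.NoTangent(), dex, dX, CRC.NoTangent())
    end
    return result, expression_pullback
end

# Same rule when `operators` is omitted: drop the tangent for the missing argument
function CRC.rrule(
    ::typeof(eval_tree_array), ex::AbstractExpression, X::AbstractMatrix; kws...
)
    result, pullback_with_ops = CRC.rrule(eval_tree_array, ex, X, nothing; kws...)
    function drop_operators_tangent(Δ)
        dself, dex, dX, _ = pullback_with_ops(Δ)
        return (dself, dex, dX)
    end
    return result, drop_operators_tangent
end

const operators = OperatorEnum(; binary_operators=[+, -, *], unary_operators=[cos])
x1 = Node{Float64}(feature=1)
x2 = Node{Float64}(feature=2)
ex = Expression(x1 * cos(x2 - 3.2); operators, variable_names=["x1", "x2"])

g = Zygote.gradient(ex -> eval_tree_array(ex, ones(2, 1))[1][1], ex)
```

The second method exists because a default positional argument would make the three-argument call return four tangents, one more than the number of arguments Zygote expects.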

Cross-posted from https://discourse.julialang.org/t/problems-with-ad-inference-on-wrapper-function/116454?u=milescranmer in the hope of finding the right person to help debug this. This is very important for my work, so I can collect as much info as you need; please ask away.
