diff --git a/src/Zygote.jl b/src/Zygote.jl index 05d0bd80e..7dbe112d0 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -3,7 +3,7 @@ module Zygote using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular -import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, +import ZygoteRules: ZygoteRules, @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty, literal_getfield, unthunk_tangent using ChainRulesCore diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index a19d7f230..237603380 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -109,10 +109,6 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us @inline wrap_chainrules_output(x) = x @inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks @inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) -# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. -@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing -@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing -@inline wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing for T_outer in (:Tuple, :NamedTuple) # we create separate methods rather than using a `Union` + an `if` so that we avoid a # branch that changes output type, because nested AD on that kinda thing makes Zygote less @@ -125,6 +121,8 @@ end wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs) + + #= # As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers @inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B} @@ -152,6 +150,7 @@ Convert `dx` from the format Zygote uses internally to differentials types Chain @inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent() @inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent() @inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent() +@inline wrap_chainrules_input(dxs::AbstractArray{T}) where {T<:AbstractZero} = first(dxs) @inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, dxs) # This produces Tangent{Any} since it does not get to see the primal, `x`. @@ -186,9 +185,12 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`. Safe to apply to arbitrary input. """ @inline function _project(x, dx) - wrap_chainrules_output(ProjectTo(x)(zygote2differential(dx, x))) + differential2zygote(ProjectTo(x)(zygote2differential(dx, x))) end +_project(_, dx::Nothing) = nothing +_project(x::Tuple, dx::Tuple) = map(_project, x, dx) + # Restore splatted arrays _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x))) @@ -350,3 +352,44 @@ z2d(dx::NamedTuple{L,S}, primal::AbstractDict) where {L,S<:Tuple{Vararg{Union{Nu end z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs + + +""" + differential2zygote(dx) + +Convert input `dx` from ChainRules differential types to the Zygote format. +This is similar to `wrap_chainrules_output(dx)`, but converts zero types, +and recursively converts Tangents. +""" +@inline differential2zygote(@nospecialize(x)) = x +@inline differential2zygote(::AbstractZero) = nothing +@inline differential2zygote(::ChainRulesCore.NotImplemented) = nothing +@inline differential2zygote(x::AbstractThunk) = differential2zygote(unthunk(x)) # For now we are just not going to deal with thunks +for T_outer in (:Tuple, :NamedTuple) + # we create separate methods rather than using a `Union` + an `if` so that we avoid a + # branch that changes output type, because nested AD on that kinda thing makes Zygote less + # than happy. + @eval @inline differential2zygote(x::$T_outer) = map(differential2zygote, x) + @eval @inline function differential2zygote(x::Tangent{<:Any, <:$T_outer}) + # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest + inner = ChainRulesCore.backing(canonicalize(x)) + return differential2zygote(inner) + end +end +# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. +@inline differential2zygote(::Tuple{Vararg{AbstractZero}}) = nothing +@inline differential2zygote(::Tuple{}) = () # Edge case split off from the above method + +differential2zygote(dxs::AbstractArray{<:Number}) = dxs +differential2zygote(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs +differential2zygote(dxs::AbstractArray) = map(differential2zygote, dxs) +differential2zygote(dxs::Dict) = Dict(k => differential2zygote(v) for (k, v) in dxs) + +# Mostly used in rule genfuncs +_iszerotype(T) = T === Nothing || T <: AbstractZero + +# Note: safe piracy to make @adjoint definitions work +ZygoteRules.gradtuple0(x::AbstractZero) = x +ZygoteRules.gradtuple1(x::AbstractZero) = x +ZygoteRules.gradtuple2(x::AbstractZero) = x +ZygoteRules.gradtuple3(x::AbstractZero) = x \ No newline at end of file diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index f350069f4..b5affcb4c 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -42,7 +42,8 @@ tailmemaybe(x::Tuple) = Base.tail(x) @inline pullback(f, args...) = pullback(f, Context(), args...) function pullback(f, cx::AContext, args...) y, back = _pullback(cx, f, args...) - y, Δ -> tailmemaybe(back(Δ)) + wrapped_back(Δ) = tailmemaybe(differential2zygote(back(Δ))) + y, wrapped_back end function pullback(cx::Context, f, args...) ChainRulesCore.ignore_derivatives() do @@ -95,7 +96,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d function gradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - isnothing(grad) ? nothing : map(_project, args, grad) + return _project(args, grad) end # Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy! @@ -131,8 +132,7 @@ julia> res.grad[w] function withgradient(f, args...) y, back = pullback(f, args...) grad = back(sensitivity(y)) - results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad) - (val=y, grad=results) + (val=y, grad=_project(args, grad)) end # Param-style wrappers @@ -184,7 +184,7 @@ Params(xs::Tuple) = Params(collect(xs)) Base.in(x, ps::Params) = x in ps.params -Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params) +_project(::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params) function Base.union!(ps::Params, itrs...) foreach(itr -> foreach(x -> push!(ps, x), itr), itrs) diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index 333323e83..27760c81e 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -6,6 +6,7 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk, @inline tuple_va(N, xs) = xs @inline tuple_va(N, x, xs...) = (x, tuple_va(N, xs...)...) @inline tuple_va(::Val{N}, ::Nothing) where N = ntuple(_ -> nothing, Val(N)) +@inline tuple_va(::Val{N}, x::AbstractZero) where N = ntuple(_ -> x, Val(N)) iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n) diff --git a/src/compiler/show.jl b/src/compiler/show.jl index 8e6797f15..7505af933 100644 --- a/src/compiler/show.jl +++ b/src/compiler/show.jl @@ -9,4 +9,7 @@ function funcname(T) end Base.show(io::IO, j::Pullback{S}) where S = print(io, "∂($(funcname(S.parameters[1])))") +function Base.show(io::IO, P::Type{<:Pullback{S}}) where S + @isdefined(S) ? print(io, "Pullback{", S, ", ...}") : print(io, "Pullback{S, T}") +end diff --git a/src/lib/array.jl b/src/lib/array.jl index a855f8946..bf8c9f306 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -39,24 +39,6 @@ end @adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,) @adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,) -@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds) - -@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds) - -∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin - if inds isa NTuple{N,Int} && T <: Number - dx = OneElement(dy, inds, axes(x)) - elseif inds isa NTuple{<:Any, Integer} - dx = _zero(x, typeof(dy)) - dx[inds...] = dy - else - dx = _zero(x, eltype(dy)) - dxv = view(dx, inds...) - dxv .= accum.(dxv, _droplike(dy, dxv)) - end - return (_project(x, dx), map(_->nothing, inds)...) -end - """ OneElement(val, ind, axes) <: AbstractArray @@ -247,10 +229,10 @@ reconstruct_if_dict(x̄, _keys::Nothing) = x̄ function reconstruct_if_dict(x̄, _keys) # This reverses `collect_if_dict`, which returns `_keys::Nothing` if x is not a Dict - @assert x̄ isa AbstractVector{<:Union{Nothing, NamedTuple{(:first,:second)}}} + @assert x̄ isa AbstractVector # {<:Union{Nothing, AbstractZero, NamedTuple{(:first,:second)}}} # we don't compute gradients with respect to keys # @assert all(x -> x === nothing || x[1] == 0 || x[1] === nothing, x̄) - d̄ = Dict(k => isnothing(x) ? nothing : x[2] for (x, k) in zip(x̄, _keys)) + d̄ = Dict(k => x === nothing || x isa AbstractZero ? x : x[2] for (x, k) in zip(x̄, _keys)) return d̄ end @@ -296,8 +278,9 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x) nd = _ndims(xs[n]) dims = ntuple(i -> i prod(f.(xs)), f, xs) end -@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),) -@adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),) +@adjoint function real(x::AbstractArray) + real_array_pullback(r̄::AbstractZero) = (r̄,) + real_array_pullback(r̄) = (real(r̄),) + return real(x), real_array_pullback +end +@adjoint function conj(x::AbstractArray) + conj_array_pullback(r̄::AbstractZero) = (r̄,) + conj_array_pullback(r̄) = (conj(r̄),) + return conj(x), conj_array_pullback +end @adjoint imag(x::AbstractArray) = imag(x), ī -> (complex.(0, real.(ī)),) @@ -445,6 +436,7 @@ _symmetric_back(Δ::LowerTriangular, uplo) = collect(uplo == 'U' ? transpose(Δ) @adjoint function Symmetric(A::AbstractMatrix, uplo=:U) S = Symmetric(A, uplo) + back(Δ::AbstractZero) = (Δ, nothing) back(Δ::AbstractMatrix) = (_symmetric_back(Δ, S.uplo), nothing) back(Δ::NamedTuple) = (_symmetric_back(Δ.data, S.uplo), nothing) return S, back @@ -469,15 +461,23 @@ end @adjoint function LinearAlgebra.Hermitian(A::AbstractMatrix, uplo=:U) H = Hermitian(A, uplo) + back(Δ::AbstractZero) = (Δ, nothing) back(Δ::AbstractMatrix) = (_hermitian_back(Δ, H.uplo), nothing) back(Δ::NamedTuple) = (_hermitian_back(Δ.data, H.uplo), nothing) return H, back end -@adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A), - Δ -> (nothing, convert(S, Δ),) -@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A), - Δ -> (convert(S, Δ),) +@adjoint function convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} + convert_Array_HermOrSym_callback(Δ::AbstractZero) = (nothing, Δ) + convert_Array_HermOrSym_callback(Δ) = (nothing, convert(S, Δ)) + return convert(R, A), convert_Array_HermOrSym_callback +end + +@adjoint function Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} + Matrix_HermOrSym_pullback(Δ::AbstractZero) = (Δ,) + Matrix_HermOrSym_pullback(Δ) = (convert(S, Δ),) + return Matrix(A), Matrix_HermOrSym_pullback +end @adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix) X = lyap(A, C) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 4508c3ca2..a33f94a79 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -284,6 +284,7 @@ end @inline function _broadcast_forward(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> value(x), out) + bc_fwd_back(ȳ::AbstractZero) = ȳ function bc_fwd_back(ȳ) dargs = ntuple(valN) do i unbroadcast(args[i], broadcast((y1, o1) -> y1 * partials(o1,i), ȳ, out)) @@ -297,6 +298,7 @@ end @inline function _broadcast_forward(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) + bc_fwd_back(ȳ::AbstractZero) = ȳ function bc_fwd_back(ȳ) dargs = ntuple(valN) do i unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*partials(real(o1),i) + imag(y1)*partials(imag(o1), i)), ȳ, out)) @@ -311,6 +313,7 @@ end @inline function _broadcast_forward_complex(::Type{<:Dual}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> value(x), out) + bc_fwd_back(ȳ::AbstractZero) = ȳ function bc_fwd_back(ȳ) dargs = ntuple(valN) do i unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(partials(o1, i), partials(o1, i+N)), ȳ, out)) @@ -335,6 +338,7 @@ end @inline function _broadcast_forward_complex(::Type{<:Complex}, out, args::Vararg{Any, N}) where {N} valN = Val(N) y = broadcast(x -> Complex(value(real(x)), value(imag(x))), out) + bc_fwd_back(ȳ::AbstractZero) = ȳ function bc_fwd_back(ȳ) dargs = ntuple(valN) do i unbroadcast(args[i], broadcast((y1, o1) -> _adjoint_complex(N, y1, o1, i), ȳ, out)) diff --git a/src/lib/lib.jl b/src/lib/lib.jl index eaa49ada2..717e9fa92 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -166,7 +166,11 @@ end first(xs), Δ -> ((Δ, drest...),) end -@adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),) +@adjoint function Base.tail(xs::Tuple) + Tuple_tail_pullback(x̄s::AbstractZero) = (x̄s,) + Tuple_tail_pullback(x̄s) = ((nothing, x̄s...),) + return tail(xs), Tuple_tail_pullback +end _empty(x) = length(x) _empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x) @@ -202,11 +206,12 @@ if VERSION >= v"1.4.0-DEV.304" @adjoint! function Core._apply_iterate(::typeof(iterate), f, args...) y, back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) - y, function (Δ) + function _apply_iterate_pullback(Δ) Δ = back(Δ) - Δ === nothing ? nothing : - (nothing, first(Δ), unapply(st, Base.tail(Δ))...) + Δ isa Union{Nothing,AbstractZero} && return Δ + return (nothing, first(Δ), unapply(st, Base.tail(Δ))...) end + return y, _apply_iterate_pullback end end @@ -229,11 +234,14 @@ end val = getfield(x, f) function back(Δ) accum_param(__context__, val, Δ) === nothing && return + # Const properties on modules are considered non-differentiable + x isa Module && isconst(x, f) && return if isimmutable(x) dx = (; nt_nothing(x)..., pair(Val(f), Δ, x)...) (_project(x, dx), nothing) else dx = grad_mut(__context__, x) + # @show dx dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...) return (dx,nothing) end @@ -305,24 +313,28 @@ end end # TODO captured mutables + multiple calls to `back` -@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G} - !ismutabletype(T) && Δ == Nothing && return :nothing - Δ = G == Nothing ? :Δ : - Δ <: RefValue ? :(back.g[]) : - :(accum(back.g[], Δ)) +@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue,AbstractZero}) where {T,G} + !ismutabletype(T) && _iszerotype(Δ) && return :Δ + Δ = if _iszerotype(G) + :Δ + elseif Δ <: RefValue + :(back.g[]) + else + :(accum(back.g[], Δ)) + end quote x̄ = $Δ - $(G == Nothing || :(back.g[] = nt_nothing($Δ))) + $(_iszerotype(G) || :(back.g[] = nt_nothing($Δ))) (nothing, $(map(f -> :(x̄.$f), fieldnames(T))...)) end end -@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G} - !ismutabletype(T) && Δ == Nothing && return :nothing - Δ = G == Nothing ? :Δ : :(back.g) +@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue,AbstractZero}) where {T,G} + !ismutabletype(T) && _iszerotype(Δ) && return :Δ + Δ = _iszerotype(G) ? :Δ : :(back.g) quote x̄ = $Δ - $(G == Nothing || :($Δ = nt_nothing($Δ))) + $(_iszerotype(G) || :($Δ = nt_nothing($Δ))) (nothing, ($(map(f -> :(x̄.$f), fieldnames(T))...),)) end end diff --git a/test/chainrules.jl b/test/chainrules.jl index 7bd66a4d2..d7a89cf6a 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -271,8 +271,8 @@ using Zygote: ZygoteRuleConfig @test Zygote.gradient(f_notimplemented, 0.1) === (nothing,) @test Zygote.gradient(x -> f_notimplemented(x[1]), 0.1) === (nothing,) if isdefined(Base, :only) - @test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === (nothing,) - @test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,) + @test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === ((nothing,),) + @test_broken Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,) end end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index b170aa045..f396c1718 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1,5 +1,5 @@ using Zygote, Test, Random, LinearAlgebra, Statistics, SparseArrays, FillArrays, - AbstractFFTs, FFTW, Distances + AbstractFFTs, FFTW, Distances, ChainRulesCore using Zygote: gradient using Base.Broadcast: broadcast_shape using Distributed: pmap, CachingPool, workers @@ -38,6 +38,7 @@ _joinreim(A) = A function _dropimaggrad(A) back(Δ) = real(Δ) + back(Δ::AbstractZero) = Δ back(Δ::Nothing) = nothing return Zygote.hook(back, A) end @@ -174,11 +175,11 @@ end # Ensure that nothings work with numeric types. _, back = Zygote.pullback(getindex, randn(4), [1]) - @test back([nothing]) == (zeros(4), nothing) + @test back([nothing]) === nothing # Ensure that nothings work with non-numeric types. _, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1]) - @test back([nothing]) == (nothing, nothing) + @test back([nothing]) === nothing end @testset "view" begin diff --git a/test/lib/array.jl b/test/lib/array.jl index d02e9f9d3..131eeb44b 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -24,8 +24,7 @@ end k = 2 i = findfirst(p -> p[1] == k, collect(d)) g = gradient(d -> collect(d)[i][2], d)[1] - @test g isa Dict{Int64, <:Union{Nothing, Int64}} - @test g[k] == 1 + @test g == Dict(k => 1, 1 => nothing) g = gradient(d -> sum(v^2 for (_,v) in collect(d)), d)[1] @test g isa Dict{Int,Int} diff --git a/test/utils.jl b/test/utils.jl index cb11437cf..4ece9c62b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,13 +3,8 @@ using Zygote: hessian_dual, hessian_reverse @testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse] - if hess == hessian_dual - @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] - @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version - else - @test_broken hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] # can't differentiate ∇getindex - @test_broken hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] - end + @test hess(x -> x[1]*x[2], randn(2)) ≈ [0 1; 1 0] + @test hess(((x,y),) -> x*y, randn(2)) ≈ [0 1; 1 0] # original docstring version @test hess(x -> sum(x.^3), [1 2; 3 4]) ≈ Diagonal([6, 18, 12, 24]) @test hess(sin, pi/2) ≈ -1