diff --git a/src/Zygote.jl b/src/Zygote.jl index a092c1f11..5489bb591 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -5,9 +5,9 @@ using LinearAlgebra: copytri!, AbstractTriangular using ArrayLayouts: MemoryLayout, AbstractColumnMajor import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty -using ZygoteRules: differential2legacy, legacy2differential, legacytype_warn, gradtuple1 +using ZygoteRules: ZygoteRules, differential2legacy, legacy2differential, legacytype_warn, diffgradtuple1 -using ChainRules: ChainRules, rrule, unthunk, AbstractZero, Zero, DoesNotExist +using ChainRules: ChainRules, rrule, unthunk, AbstractZero, Zero, DoesNotExist, Composite using IRTools using MacroTools, Requires using MacroTools: @forward diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 679554964..b94735c76 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -43,6 +43,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us # Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing. @inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = Zero() @inline wrap_chainrules_output(x::ChainRules.AbstractZero) = x +#= for T_outer in (:Tuple, :NamedTuple) # we create separate methods rather than using a `Union` + an `if` so that we avoid a # branch that changes output type, because nested AD on that kinda thing makes Zygote less @@ -52,6 +53,7 @@ for T_outer in (:Tuple, :NamedTuple) convert($T_outer, xp) end end +=# """ wrap_chainrules_input(x) @@ -59,8 +61,9 @@ end Convert `x` from the format Zygote uses internally to differentials types ChainRules uses. 
""" @inline wrap_chainrules_input(x) = x -@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero() +@inline wrap_chainrules_input(::Nothing) = (legacytype_warn(Nothing); return ChainRules.Zero()) @inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) + legacytype_warn(typeof(xs)) xp = map(wrap_chainrules_input, xs) ChainRules.Composite{Any, typeof(xp)}(xp) end @@ -77,7 +80,7 @@ end @inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) -@inline (s::ZBack)(::Nothing) = (legacytype_warn(); return Zero()) +@inline (s::ZBack)(::Nothing) = (legacytype_warn(Nothing); return Zero()) """ chain_rrule(f, args...) diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 15659d58e..80683daee 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -41,9 +41,28 @@ tailmemaybe(::Nothing) = nothing tailmemaybe(x::Tuple) = Base.tail(x) function pullback(f, args...) - y, back = _pullback(f, args...) - y, Δ -> tailmemaybe(differential2legacy(back(legacy2differential(Δ)))) + y, _back = _pullback(f, args...) + y, Δ -> tailmemaybe(differential2legacy(_back(ZygoteRules.l2d(Δ, y)))) +end + +#== +function _pullback(__context__::AContext, ::typeof(pullback), f, args...) + lesser_y, _lesser_back = _pullback(__context__, f, args...) + lesser_back = Δ -> tailmemaybe(differential2legacy(_lesser_back(ZygoteRules.l2d(Δ, lesser_y)))) + @show f + greater_y = lesser_y, lesser_back + @show 2 + function greater_back(greater_Δ) + @show 3 + Δlesser_y, Δlesser_back = greater_Δ + greater_y_again, second_back = _pullback(__context__, _lesser_back, Δlesser_y) + Δf_and_Δargs = second_back(greater_Δ) + Δf, Δargs = Iterators.peel(Δf_and_Δargs) + (Zero(), Δf, Δargs...) 
+ end + return greater_y, greater_back end +==# sensitivity(y::Number) = one(y) sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.") @@ -169,12 +188,12 @@ end function pullback(f, ps::Params) cx = Context() - y, back = _pullback(cx, f) + y, _back = _pullback(cx, f) y, function (Δ) for p in ps cache(cx)[p] = nothing end - differential2legacy(back(legacy2differential(Δ))) # has non-local effects via accum (?) + differential2legacy(_back(ZygoteRules.l2d(Δ, y))) # has non-local effects via accum (?) Grads(cx.cache, ps) # TODO make a copy end end diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index c33369ae7..c289dad79 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -42,7 +42,7 @@ end end if g == nothing # No IR found Δ <: AbstractZero && return :(Δ) - Δ == Nothing && (legacytype_warn(); return :(DoesNotExist())) + Δ == Nothing && (legacytype_warn(Nothing); return :(DoesNotExist())) return :(error("Non-differentiable function $(repr(j.t[1]))")) end meta, _, back = g diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl index d3c683919..010669b05 100644 --- a/src/compiler/reverse.jl +++ b/src/compiler/reverse.jl @@ -12,7 +12,7 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk, iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n) gradindex(x, i) = x[i] -gradindex(::Nothing, i) = (legacytype_warn(); return DoesNotExist()) +gradindex(::Nothing, i) = (legacytype_warn(Nothing); return DoesNotExist()) gradindex(x::AbstractZero, i) = x xgetindex(x, i...) = xcall(Base, :getindex, x, i...) 
xgradindex(x, i) = xcall(Zygote, :gradindex, x, i) @@ -237,12 +237,10 @@ function adjoint(pr::Primal) for v in reverse(keys(b)) ex = b[v].expr if haskey(pr.pullbacks, v) - g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), - line = b[v].line)) + g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), line = b[v].line)) for (i, x) in enumerate(ex.args) x isa Variable || continue - grad(x, push!(rb, stmt(xgradindex(g, i), - line = b[v].line))) + grad(x, push!(rb, stmt(xgradindex(g, i), line = b[v].line))) end elseif ex isa Core.PiNode grads[ex.val] = grads[v] diff --git a/src/lib/array.jl b/src/lib/array.jl index 1a916b56f..2d40f3056 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -164,6 +164,16 @@ function unzip(tuples) _unzip(tuples, Val(N)) end +# if we have an iterator, unzip may (will) lose track of what the outer type should be +# First arg is the primal iterator type to reconstruct to +reconstruct_differential_from_iterator(::Type{T}, diff_iter) where T<:Union{UnitRange, StepRange} = diff_iter +reconstruct_differential_from_iterator(::Type{T}, diff_iter) where T<:AbstractArray = convert(T, diff_iter) +reconstruct_differential_from_iterator(::Type{T}, diff_iter) where T<:Tuple = Composite{T}(diff_iter...) +reconstruct_differential_from_iterator(::Type{T}, diff_iter) where T<:NamedTuple = Composite{T}(;NamedTuple{fieldnames(T)}(diff_iter)...) + +# TODO: piracy, move to ChainRulesCore.jl +Base.convert(::Type{T}, x::AbstractZero) where T <: Real = zero(T) + # Reverse iteration order when ∇map is applied to vector, # needed for stateful functions. 
# See https://github.com/FluxML/Flux.jl/issues/1209 @@ -172,39 +182,53 @@ _tryreverse(m, backs, Δ) = backs, Δ function _tryreverse(m::typeof(map), backs, Δ::Union{AbstractVector, Tuple}) return reverse(backs), reverse(Δ) end +function _tryreverse(m::typeof(map), backs, Δ::Composite) + return reverse(backs), _tryreverse(m, Δ) +end _tryreverse(m, x) = x _tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x) +_tryreverse(m::typeof(map), c::Composite) = Composite{typeof(reverse(c.backing)), typeof(reverse(c.backing))}(reverse(c.backing)) for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap),(:vmap,:∇vmap)] @eval function $∇mapfunc(cx, f, args...) ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...) if isempty(ys_and_backs) - ys_and_backs, _ -> nothing + return ys_and_backs, _ -> Zero() else ys, backs = unzip(ys_and_backs) ys, function (Δ) # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful. Δf_and_args_zipped = $mapfunc( - (f, δ) -> differential2legacy(f(legacy2differential(δ))), + (f, δ) -> f(δ), _tryreverse($mapfunc, backs, Δ)... ) Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped)) Δf = reduce(accum, Δf_and_args[1]) - (Δf, Δf_and_args[2:end]...) + Δargs_raw = Δf_and_args[2:end] + Δargs = ntuple(length(args)) do i + @show typeof(args[i]) + @show Δargs_raw[i] + return reconstruct_differential_from_iterator(typeof(args[i]), Δargs_raw[i]) + end + return (Δf, Δargs...) end end end - @eval @adjoint function $mapfunc(f, args::Union{AbstractArray,Tuple}...) - $∇mapfunc(__context__, f, args...) + @eval function _pullback(__context__::AContext, ::typeof($mapfunc), f, args::Union{AbstractArray,Tuple}...) + ys, _f_back = $∇mapfunc(__context__, f, args...) 
+ _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = diffgradtuple1(_f_back(Δ)) + return ys, _back end end function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator) - y, back = ∇map(cx, g.f, g.iter) + y, _back = ∇map(cx, g.f, g.iter) y, function (ȳ) - f̄, x̄ = legacy2differential(back(differential2legacy(ȳ))) - (DoesNotExist(), (f = f̄, iter = x̄),) + f̄, x̄ = _back(ȳ) + (DoesNotExist(), Composite{typeof(g)}(f = f̄, iter = x̄),) end end @@ -456,7 +480,7 @@ end Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B) return Y, function(Ȳ) Ā_factors, B̄ = back(Ȳ) - return ((uplo=nothing, status=nothing, factors=Ā_factors), B̄) + return ((factors=Ā_factors, uplo=nothing, info=nothing), B̄) end end @@ -513,7 +537,7 @@ end C = cholesky(Σ, check = check) return C, Δ::NamedTuple -> begin issuccess(C) || throw(PosDefException(C.info)) - return Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing + return (Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)),) end end @@ -752,30 +776,30 @@ end # Various sensitivities for `literal_getproperty`, depending on the 2nd argument. @adjoint function literal_getproperty(C::Cholesky, ::Val{:uplo}) return literal_getproperty(C, Val(:uplo)), function(Δ) - return ((uplo=nothing, info=nothing, factors=nothing),) + return ((factors=nothing, uplo=nothing, info=nothing), nothing) end end -@adjoint function literal_getproperty(C::Cholesky, ::Val{:info}) +@adjoint function literal_getproperty(C::Cholesky, ::Val{:info}) # TODO make sure these work by changing the @adjoint macro return literal_getproperty(C, Val(:info)), function(Δ) - return ((uplo=nothing, info=nothing, factors=nothing),) + return ((factors=nothing, uplo=nothing, info=nothing), nothing) end end @adjoint function literal_getproperty(C::Cholesky, ::Val{:U}) return literal_getproperty(C, Val(:U)), function(Δ) Δ_factors = C.uplo == 'U' ? 
UpperTriangular(Δ) : LowerTriangular(copy(Δ')) - return ((uplo=nothing, info=nothing, factors=Δ_factors),) + return ((factors=Δ_factors, uplo=nothing, info=nothing), nothing) end end @adjoint function literal_getproperty(C::Cholesky, ::Val{:L}) return literal_getproperty(C, Val(:L)), function(Δ) Δ_factors = C.uplo == 'L' ? LowerTriangular(Δ) : UpperTriangular(copy(Δ')) - return ((uplo=nothing, info=nothing, factors=Δ_factors),) + return ((factors=Δ_factors, uplo=nothing, info=nothing), nothing) end end @adjoint function logdet(C::Cholesky) return logdet(C), function(Δ) - return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),) + return ((factors=Diagonal(2 .* Δ ./ diag(C.factors)), uplo=nothing, info=nothing),) end end diff --git a/src/lib/base.jl b/src/lib/base.jl index 1e5344fdb..6c1d9705e 100644 --- a/src/lib/base.jl +++ b/src/lib/base.jl @@ -53,7 +53,7 @@ grad_mut(ch::Channel) = Channel(ch.sz_max) @adjoint! function put!(ch::Channel, x) put!(ch, x), function (ȳ) x̄ = take!(grad_mut(__context__, ch)) - (nothing, accum(x̄, ȳ), nothing) + return (nothing, accum(x̄, ȳ)) end end @@ -64,36 +64,36 @@ end end end -@adjoint! function Task(f) +function _pullback(__context__::AContext, ::Type{<:Task}, f) t = Task(f) t.code = function () - y, back = _pullback(__context__, f) - cache(__context__)[t] = Task(back) + y, _back = _pullback(__context__, f) + cache(__context__)[t] = Task(_back) # when `fetch`ed, this returns a tuple: (f̄,) return y end - t, _ -> fetch(cache(__context__)[t]) + return t, _ -> (DoesNotExist(), first(fetch(cache(__context__)[t]))) end -function runadjoint(cx, t, ȳ = nothing) +function runadjoint(cx, t, ȳ = DoesNotExist()) t̄ = cache(cx)[t] f = t̄.code - t̄.code = () -> differential2legacy(f(legacy2differential(ȳ))) + t̄.code = () -> f(ȳ) @static if VERSION > v"1.3-" t̄.sticky = t.sticky end schedule(t̄) end -@adjoint! 
function wait(t::Task) - wait(t), _ -> (runadjoint(__context__, t); nothing) +function _pullback(__context__::AContext, ::typeof(wait), t::Task) + wait(t), _ -> (runadjoint(__context__, t); DoesNotExist()) end -@adjoint! function fetch(t::Task) - fetch(t), ȳ -> (runadjoint(__context__, t, ȳ); nothing) +function _pullback(__context__::AContext, ::typeof(fetch), t::Task) + fetch(t), ȳ -> (runadjoint(__context__, t, ȳ); DoesNotExist()) end -@adjoint! function Base.sync_end(refs) - Base.sync_end(refs), _ -> foreach(t -> runadjoint(__context__, t), refs) +function _pullback(__context__::AContext, ::typeof(Base.sync_end), refs) + Base.sync_end(refs), _ -> (foreach(t -> runadjoint(__context__, t), refs); DoesNotExist()) end # Make @sync work diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 312a5e176..5440a72f1 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -35,8 +35,8 @@ accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims) # Work around reducedim_init issue # https://github.com/JuliaLang/julia/issues/31427 -accum_sum(xs::Nothing; dims = :) = nothing -accum_sum(xs::AbstractArray{Nothing}; dims = :) = nothing +accum_sum(xs::AbstractZero; dims = :) = xs +accum_sum(xs::AbstractArray{<:AbstractZero}; dims = :) = Zero() accum_sum(xs::AbstractArray{<:Number}; dims = :) = sum(xs, dims = dims) accum_sum(xs::AbstractArray{<:AbstractArray{<:Number}}; dims = :) = sum(xs, dims = dims) accum_sum(xs::Number; dims = :) = xs @@ -57,7 +57,7 @@ unbroadcast(x::Number, x̄) = accum_sum(x̄) unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),) unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),) -unbroadcast(x::AbstractArray, x̄::Nothing) = nothing +unbroadcast(x::AbstractArray, x̄::AbstractZero) = x̄ # Split Reverse Mode # ================== @@ -68,18 +68,33 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing Numeric{T<:Number} = Union{T,AbstractArray{<:T}} -@adjoint broadcasted(::typeof(+), xs::Numeric...) 
= - broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(+), xs::Numeric...) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), map(x -> unbroadcast(x, Δ), xs)...) + return broadcast(+, xs...), _back +end -@adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y, - Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ)) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(-), x::Numeric, y::Numeric) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ), -unbroadcast(y, Δ)) + return x .- y, _back +end -@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y, - z̄ -> (nothing, unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x))) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(*), x::Numeric, y::Numeric) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x))) + x.*y, _back +end -@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(/), x::Numeric, y::Numeric) res = x ./ y - res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, -Δ .* conj.(res ./ y))) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, -Δ .* conj.(res ./ y))) + res, _back end @adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p @@ -127,29 +142,34 @@ end _broadcast(f::F, x...) 
where F = materialize(broadcasted(f, x...)) _get(x::Tuple, i) = x[i] -_get(::Nothing, i) = nothing -collapse_nothings(xs::Vector{Nothing}) = nothing -collapse_nothings(xs) = xs +_get(x::AbstractZero, i) = x +collapse_zeros(xs::Vector{<:AbstractZero}) = DoesNotExist() +collapse_zeros(xs) = xs -@adjoint function broadcasted(::AbstractArrayStyle, f, args...) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::AbstractArrayStyle, f, args...) len = inclen(args) y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...) y = map(x -> x[1], y∂b) ∂b = map(x -> x[2], y∂b) - y, function (ȳ) - dxs_zip = map((∂b, ȳ) -> differential2legacy(∂b(legacy2differential(ȳ))), ∂b, ȳ) - dxs = collapse_nothings.(ntuple(i -> map(x -> _get(x, i), dxs_zip), len)) - (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + function _back(ȳ) + dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ) + dxs = collapse_zeros.(ntuple(i -> map(x -> _get(x, i), dxs_zip), len)) + return (DoesNotExist(), DoesNotExist(), accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...) end + return y, _back end -@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...) - len = inclen(args) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::AbstractArrayStyle{0}, f, args...) y, ∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...) - y, function (ȳ) - dxs = differential2legacy(∂b(legacy2differential(ȳ))) - (nothing, dxs...) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + function _back(ȳ) + dxs = ∂b(ȳ) + return (DoesNotExist(), DoesNotExist(), dxs...) end + return y, _back end # Use the `map` adjoint in this special case, which is the same but applies @@ -162,9 +182,10 @@ end # end @adjoint! function (b::typeof(broadcast))(f, args...) - _pb = _pullback(__context__, broadcasted, f, args...) 
- Δ -> differential2legacy(_pb(legacy2differential(Δ))) + y, _back = _pullback(__context__, broadcasted, f, args...) + y, Δ -> differential2legacy(_back(legacy2differential(Δ, y))) end + # Forward Mode (mainly necessary for CUDA) import ForwardDiff @@ -190,7 +211,9 @@ end out = dual_function(f).(args...) eltype(out) <: Dual || return (out, _ -> nothing) y = map(x -> x.value, out) - _back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out)) + _back(ȳ, i) = differential2legacy( # unbroadcast returns differential types + unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out)) + ) back(ȳ) = ntuple(i -> _back(ȳ, i), N) return y, back end diff --git a/src/lib/lib.jl b/src/lib/lib.jl index b5d18c233..03ce85bc8 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -42,16 +42,17 @@ end isbitstype(x) && return :(Δ) quote if haskey(cache(cx), x) - cache(cx)[x] = accum(cache(cx)[x],Δ) - return + cache(cx)[x] = accum(cache(cx)[x], Δ) + return Zero() # we have already accumulated it into the cache so nothing left to accumulate else - return Δ + return Δ # we are not accumulating it into the cache so it must be accumulated later end end end function accum_global(cx::Context, ref, x̄) - (x̄ === nothing || isconst(ref.mod, ref.name)) && return + x̄ isa Nothing && legacytype_warn(Nothing) + (x̄ isa AbstractZero || isconst(ref.mod, ref.name)) && return x̄ gs = cache(cx) gs[ref] = accum(get(gs, ref, Zero()), x̄) return @@ -59,13 +60,23 @@ end unwrap(x) = x -@adjoint unwrap(x) = unwrap(x), x̄ -> (accum_param(__context__, x, x̄),) +function _pullback(__context__::AContext, ::typeof(unwrap), x) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(x̄) = diffgradtuple1(accum_param(__context__, x, x̄)) + return unwrap(x), _back +end unwrap(ref, x) = x -@adjoint unwrap(ref, x) = unwrap(x), function (x̄) - accum_global(__context__, ref, x̄) - (accum_param(__context__, x, x̄),) +function _pullback(__context__::AContext, 
::typeof(unwrap), ref, x) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + function _back(x̄) + accum_global(__context__, ref, x̄) + return diffgradtuple1((accum_param(__context__, x, x̄),)) + end + return unwrap(x), _back end function global_set(ref, val) @@ -88,30 +99,39 @@ using Base: tail @adjoint tuple(xs...) = xs, identity -@adjoint function literal_getindex(xs::NTuple{N,Any}, ::Val{i}) where {N,i} +function _pullback(__context__::AContext, ::typeof(literal_getindex), xs::NTuple{N,Any}, ::Val{i}) where {N,i} val = xs[i] - function back(Δ) - accum_param(__context__, val, Δ) === nothing && return - return ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing + function _back(Δ) + Δ_accum = accum_param(__context__, val, Δ) + Δ_accum isa AbstractZero && return Δ_accum + nt = ntuple(j -> i == j ? Δ : Zero(), Val(N)) + return diffgradtuple1((Composite{typeof(xs), typeof(nt)}(nt), DoesNotExist())) end - val, back + val, _back end -@adjoint function getindex(xs::NTuple{N,Any}, i::Integer) where N +function _pullback(__context__::AContext, ::typeof(getindex), xs::NTuple{N,Any}, i::Integer) where N val = xs[i] - function back(Δ) - accum_param(__context__, val, Δ) === nothing && return - return ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing + function _back(Δ) + Δ_accum = accum_param(__context__, val, Δ) + Δ_accum isa AbstractZero && return Δ_accum + nt = ntuple(j -> i == j ? Δ : Zero(), Val(N)) + return diffgradtuple1((Composite{typeof(xs), typeof(nt)}(nt), DoesNotExist())) end - return val, back + val, _back end -@adjoint getindex(xs::NTuple{N,Any}, r::AbstractUnitRange) where N = - (xs[r], Δ -> (ntuple(j -> j in r ? Δ[findfirst(isequal(j), r)] : nothing, Val(N)), nothing)) +function _pullback(__context__::AContext, ::typeof(getindex), xs::NTuple{N,Any}, r::AbstractUnitRange) where N + function _back(Δ) + t = ntuple(j -> j in r ? 
Δ[findfirst(isequal(j), r)] : Zero(), Val(N)) + return diffgradtuple1((Composite{typeof(xs), typeof(t)}(t), DoesNotExist())) + end + return xs[r], _back +end function _pullback(cx::Context, ::typeof(literal_indexed_iterate), xs::Tuple, ::Val{i}) where i y, b = _pullback(cx, literal_getindex, xs, Val(i)) - back(::Nothing) = (legacytype_warn(); return Zero()) + back(::Nothing) = (legacytype_warn(Nothing); return Zero()) back(x::AbstractZero) = x back(ȳ) = b(ȳ[1]) (y, i+1), back @@ -119,18 +139,28 @@ end function _pullback(cx::Context, ::typeof(literal_indexed_iterate), xs::Tuple, ::Val{i}, st) where i y, b = _pullback(cx, literal_getindex, xs, Val(i)) - back(::Nothing) = (legacytype_warn(); return Zero()) + back(::Nothing) = (legacytype_warn(Nothing); return Zero()) back(x::AbstractZero) = x back(ȳ) = (b(ȳ[1])..., Zero()) (y, i+1), back end # Needed for iteration lowering -@adjoint Core.getfield(xs::NTuple{N,Any}, i::Integer) where N = - (xs[i], Δ -> (ntuple(j -> i == j ? Δ : nothing, Val(N)), nothing)) +function _pullback(cx::Context, ::typeof(Core.getfield), xs::NTuple{N,Any}, i::Integer) where N + function _back(Δ) + t = ntuple(j -> i == j ? Δ : Zero(), Val(N)) + return diffgradtuple1((Composite{typeof(xs), typeof(t)}(t), DoesNotExist())) + end + return xs[i], _back +end -@adjoint Core.getfield(xs::NamedTuple{K,<:NTuple{N,Any}}, i::Integer) where {K,N} = - (xs[i], Δ -> (NamedTuple{K}(ntuple(j -> i == j ? Δ : nothing, Val(N))), nothing)) +function _pullback(cx::Context, ::typeof(Core.getfield), xs::NamedTuple{K,<:NTuple{N,Any}}, i::Integer) where {K,N} + function _back(Δ) + t = NamedTuple{K}(ntuple(j -> i == j ? 
Δ : Zero(), Val(N))) + return diffgradtuple1((Composite{typeof(xs), typeof(t)}(t), DoesNotExist())) + end + return xs[i], _back +end @adjoint function Base.first(xs::Tuple) drest = map(_->nothing, tail(xs)) @@ -140,7 +170,7 @@ end @adjoint Base.tail(xs::Tuple) = tail(xs), x̄s -> ((nothing, x̄s...),) _empty(x) = length(x) -_empty(x::Union{Tuple,NamedTuple}) = map(_->nothing, x) +_empty(x::Union{Tuple,NamedTuple}) = map(_->DoesNotExist(), x) _unapply(t::Integer, xs) = xs[1:t], xs[t+1:end] _unapply(t, xs) = first(xs), tail(xs) @@ -159,29 +189,39 @@ end unapply(t, xs) = _unapply(t, xs)[1] -@adjoint! function Core._apply(f, args...) - y, back = Core._apply(_pullback, (__context__, f), args...) +function _pullback(__context__::AContext, ::typeof(Core._apply), f, args...) + y, _back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) - y, function (Δ) - Δ = differential2legacy(back(legacy2differential(Δ))) - if Δ === nothing - return nothing + y, function (ȳ) + Δ = _back(ȳ) + if Δ isa AbstractZero + return Δ else - (first(Δ), unapply(st, Base.tail(Δ))...) + tuple_grads = unapply(st, Base.tail(Δ)) + composite_grads = ntuple( + i -> Composite{typeof(tuple_grads[i]), typeof(tuple_grads[i])}(tuple_grads[i]), + length(tuple_grads) + ) + return (DoesNotExist(), first(Δ), composite_grads...) end end end if VERSION >= v"1.4.0-DEV.304" - @adjoint! function Core._apply_iterate(::typeof(iterate), f, args...) - y, back = Core._apply(_pullback, (__context__, f), args...) + function _pullback(__context__::AContext, ::typeof(Core._apply_iterate), f, args...) + y, _back = Core._apply(_pullback, (__context__, f), args...) st = map(_empty, args) - y, function (Δ) - Δ = differential2legacy(back(legacy2differential(Δ))) - if Δ === nothing - return nothing + y, function (ȳ) + Δ = _back(ȳ) + if Δ isa AbstractZero + return Δ else - (nothing, first(Δ), unapply(st, Base.tail(Δ))...) 
+ tuple_grads = unapply(st, Base.tail(Δ)) + composite_grads = ntuple( + i -> Composite{typeof(tuple_grads[i]), typeof(tuple_grads[i])}(tuple_grads[i]), + length(tuple_grads) + ) + return (DoesNotExist(), DoesNotExist(), first(Δ), composite_grads...) end end end @@ -197,26 +237,26 @@ function deref!(x::Ref) return d end -@generated nt_nothing(x) = Expr(:tuple, [:($f=nothing) for f in fieldnames(x)]...) - @generated nt_zero(x) = Expr(:tuple, [:($f=Zero()) for f in fieldnames(x)]...) @generated pair(::Val{k}, v) where k = :($k = v,) -@adjoint function literal_getproperty(x, ::Val{f}) where f - val = getproperty(x, f) - function back(Δ) - accum_param(__context__, val, Δ) === nothing && return - if isimmutable(x) - ((;nt_nothing(x)...,pair(Val(f), Δ)...), nothing) - else - dx = grad_mut(__context__, x) - dx[] = (;dx[]...,pair(Val(f),accum(getfield(dx[], f), Δ))...) - return (dx,nothing) +function _pullback(__context__::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f + val = getproperty(x, f) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + function _back(Δ) + accum_param(__context__, val, Δ) isa AbstractZero && return Zero() + if isimmutable(x) + return (DoesNotExist(), Composite{typeof(x)}(;f => Δ), DoesNotExist()) + else + dx = grad_mut(__context__, x) + dx[] += Composite{typeof(x)}(;f => Δ) # is += the right thing to do? 
(a=1, b=2, :a=>3) gives (a = 3, b = 2) + return (DoesNotExist(), dx, DoesNotExist()) + end end + unwrap(val), _back end - unwrap(val), back -end _pullback(cx::Context, ::typeof(getproperty), x, f::Symbol) = _pullback(cx, literal_getproperty, x, Val(f)) @@ -230,7 +270,7 @@ _pullback(cx::Context, ::typeof(literal_getindex), x::NamedTuple, ::Val{f}) wher _pullback(cx::Context, ::typeof(literal_getproperty), x::Tuple, ::Val{f}) where f = _pullback(cx, literal_getindex, x, Val(f)) -grad_mut(x) = Ref{Any}(nt_zero(x)) +grad_mut(x::T) where T = Ref{Any}(Composite{T}()) function grad_mut(cx::Context, x) ch = cache(cx) @@ -241,17 +281,17 @@ function grad_mut(cx::Context, x) end end -@adjoint! function setfield!(x, f, val) +function _pullback(__context__::AContext, ::typeof(setfield!), x, f, val) y = setfield!(x, f, val) g = grad_mut(__context__, x) - y, function (_) - Δ = differential2legacy(getfield(g[], f)) - g[] = (;g[]...,pair(Val(f),Zero())...) - (nothing, nothing, Δ) + y, function _back(_) + Δ = getproperty(g[], f) + g[] += Composite{typeof(x)}(;f => -Δ) # i.e. g[].f = Zero(), but that is not implemented + return (DoesNotExist(), DoesNotExist(), DoesNotExist(), Δ) end end -struct Jnew{T,G,splat} +struct Jnew{T,G,splat} # T is the primal type, G is the gradient type g::G end @@ -260,49 +300,63 @@ Jnew{T}(g) where T = Jnew{T,typeof(g)}(g) function _pullback(__context__::AContext, ::typeof(__new__), ::Type{T}, args...) where T x = __new__(T, args...) g = !T.mutable || fieldcount(T) == 0 ? Zero() : grad_mut(__context__, x) - return x, Δ -> gradtuple1(Jnew{T,typeof(g),false}(g)(Δ)) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = diffgradtuple1(Jnew{T,typeof(g),false}(g)(Δ)) + return x, _back end function _pullback(__context__::AContext, ::typeof(__splatnew__), ::Type{T}, args) where T x = __splatnew__(T, args) g = !T.mutable || fieldcount(T) == 0 ? 
Zero() : grad_mut(__context__, x) - return x, Δ -> gradtuple1(Jnew{T,typeof(g),true}(g)(Δ)) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = diffgradtuple1(Jnew{T,typeof(g),true}(g)(Δ)) + return x, _back end +const allowed_gradient_T = Union{ + NamedTuple, + Nothing, + AbstractZero, + RefValue, + ChainRules.Composite +} + # TODO captured mutables + multiple calls to `back` -@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,AbstractZero,RefValue}) where {T,G} - Δ == Nothing && legacytype_warn() +@generated function (back::Jnew{T,G,false})(Δ::allowed_gradient_T) where {T,G} + Δ <: Union{Nothing, NamedTuple} && legacytype_warn(Δ) if !T.mutable Δ <: AbstractZero && return :Δ end Δ_expr = if G <: AbstractZero :Δ elseif Δ <: RefValue - :(back.g[]) # TODO: is this right? Why don't we need to accum? + :(back.g[]) # TODO: is this right? Why don't we need to accum? else :(accum(back.g[], Δ)) end quote x̄ = $Δ_expr $(G <: AbstractZero || :(back.g[] = nt_zero($Δ_expr))) - return (DoesNotExist(), $(map(f -> :(x̄.$f), fieldnames(T))...)) + return (DoesNotExist(), $(map(fn -> :(x̄.$fn), fieldnames(T))...)) end end -@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,AbstractZero,RefValue}) where {T,G} - Δ == Nothing && legacytype_warn() +@generated function (back::Jnew{T,G,true})(Δ::allowed_gradient_T) where {T,G} + Δ == Union{Nothing, NamedTuple} && legacytype_warn(Δ) if !T.mutable Δ <: AbstractZero && return :Δ end if G <: AbstractZero quote - (DoesNotExist(), ($(map(f -> :(Δ.$f), fieldnames(T))...),)) + return (DoesNotExist(), ($(map(fn -> :(Δ.$fn), fieldnames(T))...),)) end else # TODO is this dead code? 
back is an (immutable) struct quote x̄ = back.g back.g = nt_zero(back.g) - (DoesNotExist(), ($(map(f -> :(x̄.$f), fieldnames(T))...),)) + return (DoesNotExist(), ($(map(fn -> :(x̄.$fn), fieldnames(T))...),)) end end end diff --git a/src/lib/number.jl b/src/lib/number.jl index 4097863c8..71e469fd6 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -6,7 +6,15 @@ Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing) @adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ) -@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ) + +function _pullback(__context__::AContext, ::Type{T}, x::Real) where T<:Real + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + # Nonsense follows: + # extra DoesNotExist at the start because this is a `:new` not a `:call` + _back(Δ) = (DoesNotExist(), DoesNotExist(), Δ) + return T(x), _back +end for T in Base.uniontypes(Core.BuiltinInts) @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,) @@ -18,8 +26,6 @@ end # Complex Numbers -@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄)) - # we define these here because ChainRules.jl only defines them for x::Union{Real,Complex} @adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) diff --git a/src/lib/statsfuns.jl b/src/lib/statsfuns.jl index 85916cae8..9fa0431de 100644 --- a/src/lib/statsfuns.jl +++ b/src/lib/statsfuns.jl @@ -8,10 +8,12 @@ using Base.Broadcast: broadcasted back(δ) = (dx * δ,) return result, back end -@adjoint function broadcasted(::typeof(xlogx), x::Numeric) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(xlogx), x::Numeric) result, dx = ∇xlogx(x) - back(δ) = (nothing, unbroadcast(x, δ .* dx)) - return result, back + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ .* dx)) + return result, _back end function 
∇xlogx(x::Numeric) logx = log.(x) @@ -34,9 +36,12 @@ end dx = ∂log1pexp(x) return log1pexp(x), δ -> (δ * dx,) end -@adjoint function broadcasted(::typeof(log1pexp), x::Numeric) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(log1pexp), x::Numeric) dx = ∂log1pexp.(x) - return log1pexp.(x), δ -> (nothing, unbroadcast(x, δ .* dx)) + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ .* dx)) + return log1pexp.(x), _back end ∂log1pexp(x::Real) = x < 18.0 ? logistic(x) : x < 33.3 ? one(x) - exp(-x) : oftype(exp(x), 1) ∂log1pexp(x::Float32) = x < 9f0 ? logistic(x) : x < 16f0 ? one(x) - exp(-x) : oftype(exp(x), 1) @@ -51,10 +56,12 @@ end back(δ) = (δ * dx, δ * dy) return result, back end -@adjoint function broadcasted(::typeof(xlogy), x::Numeric, y::Numeric) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(xlogy), x::Numeric, y::Numeric) result, dx, dy = ∇xlogy(x, y) - back(δ) = (nothing, unbroadcast(x, δ .* dx), unbroadcast(y, δ .* dy)) - return result, back + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ .* dx), unbroadcast(y, Δ .* dy)) + return result, _back end function ∇xlogy(x::Numeric, y::Numeric) dx = logy = log.(y) @@ -69,10 +76,12 @@ end back(δ) = (δ * dx, δ * dy) return result, back end -@adjoint function broadcasted(::typeof(logaddexp), x::Numeric, y::Numeric) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(logaddexp), x::Numeric, y::Numeric) result, dx, dy = ∇logaddexp(x, y) - back(δ) = (nothing, unbroadcast(x, δ .* dx), unbroadcast(y, δ .* dy)) - return result, back + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ .* dx), unbroadcast(y, Δ .* dy)) + return result, _back 
end function ∇logaddexp(x::Numeric, y::Numeric) result = logaddexp.(x, y) @@ -86,10 +95,12 @@ end back(δ) = (δ * dx, δ * dy) return result, back end -@adjoint function broadcasted(::typeof(logsubexp), x::Numeric, y::Numeric) +function _pullback(__context__::AContext, ::typeof(broadcasted), ::typeof(logsubexp), x::Numeric, y::Numeric) result, dx, dy = ∇logsubexp(x, y) - back(δ) = (nothing, unbroadcast(x, δ .* dx), unbroadcast(y, δ .* dy)) - return result, back + _back(::Nothing) = (legacytype_warn(Nothing); return Zero()) + _back(x::AbstractZero) = x + _back(Δ) = (DoesNotExist(), DoesNotExist(), unbroadcast(x, Δ .* dx), unbroadcast(y, Δ .* dy)) + return result, _back end function ∇logsubexp(x::Numeric, y::Numeric) result = logsubexp.(x, y) diff --git a/test/features.jl b/test/features.jl index 85c1a0356..1ef5356b1 100644 --- a/test/features.jl +++ b/test/features.jl @@ -230,12 +230,14 @@ end[1] == 5 x end[1] == 2 -# Gradient of closure -grad_closure(x) = 2x +# test using call overload form of `@adjoint` +grad_callover(x) = 2x -Zygote.@adjoint (f::typeof(grad_closure))(x) = f(x), Δ -> (1, 2) +# Equivalent to: +# Zygote.@adjoint grad_callover(x) = grad_callover(x), Δ -> (2Δ,) +Zygote.@adjoint (f::typeof(grad_callover))(x) = f(x), Δ -> (1, 2Δ,) # set the gradient wrong so that it is tested -@test gradient((f, x) -> f(x), grad_closure, 5) == (1, 2) +@test gradient((f, x) -> f(x), grad_callover, 5) == (1, 2) invokable(x) = 2x invokable(x::Integer) = 3x diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 021b78704..944a2cf73 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -343,7 +343,23 @@ for mapfunc in [map,pmap,vmap] x = randn(3) _, pb = Zygote.pullback(x -> map(abs2, x), x) Δy = randn(3) - @test first(pb((Δy..., ))) ≈ first(pb(Δy)) + @test length(first(pb(Δy))) == 3 + end + + @testset "_pullback outputs differential type for tuple" begin + _, pb = Zygote._pullback(map, identity, (1,1,1)) + @test pb(Composite{Tuple{Int, Int, Int}}(1,2,3)) == + pb((1, 2, 
3)) == + (DoesNotExist(), Zero(), Composite{Tuple{Int64,Int64,Int64}}(1, 2, 3)) + end + + @testset "_pullback outputs differential type for named tuple" begin + _, pb = Zygote._pullback(map, identity, (a=1, b=1, c=1)) + nt = (a=1, b=2, c=3) + cnt = Composite{typeof(nt), typeof(nt)}(nt) + @test pb(nt) == + pb(cnt) == + (Zero(), Zero(), Composite{NamedTuple{(:a, :b, :c),Tuple{Int64,Int64,Int64}}}(a = 1, b = 2, c = 3)) end end