Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add support for ChainRules Composite type #806

Open
wants to merge 38 commits into
base: chainrules_types
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8b4826c
do not convert chain rules output to named tuple
Sep 21, 2020
ee9cc0a
using Composite
Sep 30, 2020
407eb34
add the old changes
Oct 9, 2020
07598a0
Update with changes to mz/cr-types
Oct 9, 2020
bac7fa9
remove some changes
Oct 9, 2020
0d65140
add legacytype_warn
Oct 9, 2020
5cb1036
Merge branch 'mz/cr-types' into mz/cr-composite
Oct 9, 2020
dbb7675
add Composite to allowed gradients
Oct 9, 2020
33df078
Merge branch 'mz/cr-types' into mz/cr-composite
Oct 9, 2020
b84d0a9
__new__ and __splatnew__ add Composite support
Oct 9, 2020
a20a6e1
improve readability
Oct 14, 2020
457826b
remove __new__ changes
Oct 16, 2020
a02a649
move to new warnings with types passed
Oct 16, 2020
d4b7e49
fix the iterator
Oct 16, 2020
78a7170
add warnings to chainrules
Oct 17, 2020
a9c0b7a
Composite{Any} -> Composite{typeof(g)}
Oct 21, 2020
09652d9
allowed gradient types change
Oct 21, 2020
6f609e9
remove some legacy2differential instances
Oct 22, 2020
98cd45e
accum_sum and unbroadcast to differential types
Oct 26, 2020
d3bb7fa
remove some gradtuple1
Oct 26, 2020
22f5ae5
fix literal getproperty as _pullback
Oct 27, 2020
8d9810f
fix cholesky adjoints
Oct 28, 2020
2ae657f
fix call overload
Nov 3, 2020
a58fc37
fix Real constructor
Nov 3, 2020
96633d1
change to diffgradtuple
Nov 3, 2020
cae7b62
Core._apply tuples to Composites
Nov 3, 2020
3ec66c9
fix tasks
Nov 5, 2020
33ed358
fix the warnings
Nov 5, 2020
1528485
change the test
Nov 5, 2020
60c3d72
improve type stability
Nov 5, 2020
2040fc1
move to new version of l2d taking a primal rather than primal type
Nov 16, 2020
697792e
fix _back
Nov 17, 2020
48102d8
nested ad fix first draft
Nov 20, 2020
cb13531
fix map a bit
Nov 20, 2020
925b471
add map tests and fix reverse of composite
Nov 24, 2020
453df12
l2d
Nov 24, 2020
db0662f
move getindex to _pullback
Nov 24, 2020
5bdc8b1
Core.getfield to _pullback
Nov 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ using LinearAlgebra: copytri!, AbstractTriangular
using ArrayLayouts: MemoryLayout, AbstractColumnMajor

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty
using ZygoteRules: differential2legacy, legacy2differential, legacytype_warn, gradtuple1
using ZygoteRules: ZygoteRules, differential2legacy, legacy2differential, legacytype_warn, diffgradtuple1

using ChainRules: ChainRules, rrule, unthunk, AbstractZero, Zero, DoesNotExist
using ChainRules: ChainRules, rrule, unthunk, AbstractZero, Zero, DoesNotExist, Composite
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
7 changes: 5 additions & 2 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
# Zygote convention: even if there are many AbstractZero partials (i.e. a multi-input function), collapse them into a single `Zero()`.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = Zero()
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = x
#=
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
Expand All @@ -52,15 +53,17 @@ for T_outer in (:Tuple, :NamedTuple)
convert($T_outer, xp)
end
end
=#

"""
wrap_chainrules_input(x)

Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline wrap_chainrules_input(::Nothing) = (legacytype_warn(Nothing); return ChainRules.Zero())
# Tuple / NamedTuple gradients: emit `legacytype_warn` for the container type, convert
# every element recursively, and wrap the converted container in a `Composite{Any}`.
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
  legacytype_warn(typeof(xs))
  converted = map(wrap_chainrules_input, xs)
  return ChainRules.Composite{Any, typeof(converted)}(converted)
end
Expand All @@ -77,7 +80,7 @@ end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack)(::Nothing) = (legacytype_warn(); return Zero())
@inline (s::ZBack)(::Nothing) = (legacytype_warn(Nothing); return Zero())

"""
chain_rrule(f, args...)
Expand Down
27 changes: 23 additions & 4 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,28 @@ tailmemaybe(::Nothing) = nothing
tailmemaybe(x::Tuple) = Base.tail(x)

function pullback(f, args...)
y, back = _pullback(f, args...)
y, Δ -> tailmemaybe(differential2legacy(back(legacy2differential(Δ))))
y, _back = _pullback(f, args...)
y, Δ -> tailmemaybe(differential2legacy(_back(ZygoteRules.l2d(Δ, y))))
end

#==
function _pullback(__context__::AContext, ::typeof(pullback), f, args...)
lesser_y, _lesser_back = _pullback(__context__, f, args...)
lesser_back = Δ -> tailmemaybe(differential2legacy(_lesser_back(ZygoteRules.l2d(Δ, lesser_y))))
@show f
greater_y = lesser_y, lesser_back
@show 2
function greater_back(greater_Δ)
@show 3
Δlesser_y, Δlesser_back = greater_Δ
greater_y_again, second_back = _pullback(__context__, _lesser_back, Δlesser_y)
Δf_and_Δargs = second_back(greater_Δ)
Δf, Δargs = Iterators.peel(Δf_and_Δargs)
(Zero(), Δf, Δargs...)
end
return greater_y, greater_back
end
==#

sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
Expand Down Expand Up @@ -169,12 +188,12 @@ end

function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
y, _back = _pullback(cx, f)
y, function (Δ)
for p in ps
cache(cx)[p] = nothing
end
differential2legacy(back(legacy2differential(Δ))) # has non-local effects via accum (?)
differential2legacy(_back(ZygoteRules.l2d(Δ, y))) # has non-local effects via accum (?)
Grads(cx.cache, ps) # TODO make a copy
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
end
if g == nothing # No IR found
Δ <: AbstractZero && return :(Δ)
Δ == Nothing && (legacytype_warn(); return :(DoesNotExist()))
Δ == Nothing && (legacytype_warn(Nothing); return :(DoesNotExist()))
return :(error("Non-differentiable function $(repr(j.t[1]))"))
end
meta, _, back = g
Expand Down
8 changes: 3 additions & 5 deletions src/compiler/reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using IRTools: IR, Variable, Pipe, xcall, var, prewalk, postwalk,
iscall(x, m::Module, n::Symbol) = isexpr(x, :call) && x.args[1] == GlobalRef(m, n)

gradindex(x, i) = x[i]
gradindex(::Nothing, i) = (legacytype_warn(); return DoesNotExist())
gradindex(::Nothing, i) = (legacytype_warn(Nothing); return DoesNotExist())
gradindex(x::AbstractZero, i) = x
xgetindex(x, i...) = xcall(Base, :getindex, x, i...)
xgradindex(x, i) = xcall(Zygote, :gradindex, x, i)
Expand Down Expand Up @@ -237,12 +237,10 @@ function adjoint(pr::Primal)
for v in reverse(keys(b))
ex = b[v].expr
if haskey(pr.pullbacks, v)
g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)),
line = b[v].line))
g = push!(rb, stmt(Expr(:call, alpha(pr.pullbacks[v]), grad(v)), line = b[v].line))
for (i, x) in enumerate(ex.args)
x isa Variable || continue
grad(x, push!(rb, stmt(xgradindex(g, i),
line = b[v].line)))
grad(x, push!(rb, stmt(xgradindex(g, i), line = b[v].line)))
end
elseif ex isa Core.PiNode
grads[ex.val] = grads[v]
Expand Down
56 changes: 40 additions & 16 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,16 @@ function unzip(tuples)
_unzip(tuples, Val(N))
end

# `unzip` operates on plain iterators, so it may (will) lose track of what the outer type
# of the differential should be. These methods rebuild it: the first argument is the type
# of the primal iterator to reconstruct to, the second is the unzipped differential.

# Ranges: the differential is handed back unchanged.
function reconstruct_differential_from_iterator(
  ::Type{T}, diff_iter
) where {T<:Union{UnitRange, StepRange}}
  return diff_iter
end

# Arrays: convert the differential to the primal array type.
function reconstruct_differential_from_iterator(
  ::Type{T}, diff_iter
) where {T<:AbstractArray}
  return convert(T, diff_iter)
end

# Tuples: splat the elements into a `Composite` tagged with the primal tuple type.
function reconstruct_differential_from_iterator(::Type{T}, diff_iter) where {T<:Tuple}
  return Composite{T}(diff_iter...)
end

# NamedTuples: re-attach the primal's field names, then build the `Composite` by keyword.
function reconstruct_differential_from_iterator(::Type{T}, diff_iter) where {T<:NamedTuple}
  named = NamedTuple{fieldnames(T)}(diff_iter)
  return Composite{T}(; named...)
end

# TODO: piracy, move to ChainRulesCore.jl
# Converting an `AbstractZero` differential to a `Real` type yields that type's zero.
function Base.convert(::Type{T}, x::AbstractZero) where {T<:Real}
  return zero(T)
end

# Reverse iteration order when ∇map is applied to vector,
# needed for stateful functions.
# See https://github.com/FluxML/Flux.jl/issues/1209
Expand All @@ -172,39 +182,53 @@ _tryreverse(m, backs, Δ) = backs, Δ
# Three-argument form for `map` over vectors/tuples: reverse both the pullbacks and the
# incoming differential so the pullbacks are applied in reverse order.
function _tryreverse(m::typeof(map), backs, Δ::Union{AbstractVector, Tuple})
  return reverse(backs), reverse(Δ)
end

# Same for a `Composite` differential; the reversal of `Δ` is delegated to the
# two-argument method below.
function _tryreverse(m::typeof(map), backs, Δ::Composite)
  return reverse(backs), _tryreverse(m, Δ)
end

# Two-argument form: leave `x` untouched unless the mapping function is `map`.
_tryreverse(m, x) = x
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)

# Reverse the backing of a `Composite` and re-wrap it. The backing is reversed once and
# reused for both type parameters and the constructor argument.
# NOTE(review): the first type parameter is set to the reversed backing's type rather than
# to a primal type -- confirm this is the intended tag.
function _tryreverse(m::typeof(map), c::Composite)
  reversed = reverse(c.backing)
  return Composite{typeof(reversed), typeof(reversed)}(reversed)
end

for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap),(:vmap,:∇vmap)]
@eval function $∇mapfunc(cx, f, args...)
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> nothing
return ys_and_backs, _ -> Zero()
else
ys, backs = unzip(ys_and_backs)
ys, function (Δ)
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = $mapfunc(
(f, δ) -> differential2legacy(f(legacy2differential(δ))),
(f, δ) -> f(δ),
_tryreverse($mapfunc, backs, Δ)...
)
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
Δf = reduce(accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
Δargs_raw = Δf_and_args[2:end]
Δargs = ntuple(length(args)) do i
  # Rebuild the i-th argument's differential using the type of the matching primal argument.
  return reconstruct_differential_from_iterator(typeof(args[i]), Δargs_raw[i])
end
return (Δf, Δargs...)
end
end
end

@eval @adjoint function $mapfunc(f, args::Union{AbstractArray,Tuple}...)
$∇mapfunc(__context__, f, args...)
@eval function _pullback(__context__::AContext, ::typeof($mapfunc), f, args::Union{AbstractArray,Tuple}...)
ys, _f_back = $∇mapfunc(__context__, f, args...)
_back(::Nothing) = (legacytype_warn(Nothing); return Zero())
_back(x::AbstractZero) = x
_back(Δ) = diffgradtuple1(_f_back(Δ))
return ys, _back
end
end

function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
y, back = ∇map(cx, g.f, g.iter)
y, _back = ∇map(cx, g.f, g.iter)
y, function (ȳ)
f̄, x̄ = legacy2differential(back(differential2legacy(ȳ)))
(DoesNotExist(), (f = f̄, iter = x̄),)
f̄, x̄ = _back(ȳ)
(DoesNotExist(), Composite{typeof(g)}(f = f̄, iter = x̄),)
end
end

Expand Down Expand Up @@ -456,7 +480,7 @@ end
Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B)
return Y, function(Ȳ)
Ā_factors, B̄ = back(Ȳ)
return ((uplo=nothing, status=nothing, factors=Ā_factors), B̄)
return ((factors=Ā_factors, uplo=nothing, info=nothing), B̄)
end
end

Expand Down Expand Up @@ -513,7 +537,7 @@ end
C = cholesky(Σ, check = check)
return C, Δ::NamedTuple -> begin
issuccess(C) || throw(PosDefException(C.info))
return Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing
return (Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)),)
end
end

Expand Down Expand Up @@ -752,30 +776,30 @@ end
# Various sensitivities for `literal_getproperty`, depending on the 2nd argument.
@adjoint function literal_getproperty(C::Cholesky, ::Val{:uplo})
return literal_getproperty(C, Val(:uplo)), function(Δ)
return ((uplo=nothing, info=nothing, factors=nothing),)
return ((factors=nothing, uplo=nothing, info=nothing), nothing)
end
end
@adjoint function literal_getproperty(C::Cholesky, ::Val{:info})
@adjoint function literal_getproperty(C::Cholesky, ::Val{:info}) # TODO make sure these work by changing the @adjoint macro
return literal_getproperty(C, Val(:info)), function(Δ)
return ((uplo=nothing, info=nothing, factors=nothing),)
return ((factors=nothing, uplo=nothing, info=nothing), nothing)
end
end
@adjoint function literal_getproperty(C::Cholesky, ::Val{:U})
return literal_getproperty(C, Val(:U)), function(Δ)
Δ_factors = C.uplo == 'U' ? UpperTriangular(Δ) : LowerTriangular(copy(Δ'))
return ((uplo=nothing, info=nothing, factors=Δ_factors),)
return ((factors=Δ_factors, uplo=nothing, info=nothing), nothing)
end
end
@adjoint function literal_getproperty(C::Cholesky, ::Val{:L})
return literal_getproperty(C, Val(:L)), function(Δ)
Δ_factors = C.uplo == 'L' ? LowerTriangular(Δ) : UpperTriangular(copy(Δ'))
return ((uplo=nothing, info=nothing, factors=Δ_factors),)
return ((factors=Δ_factors, uplo=nothing, info=nothing), nothing)
end
end

@adjoint function logdet(C::Cholesky)
return logdet(C), function(Δ)
return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),)
return ((factors=Diagonal(2 .* Δ ./ diag(C.factors)), uplo=nothing, info=nothing),)
end
end

Expand Down
26 changes: 13 additions & 13 deletions src/lib/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ grad_mut(ch::Channel) = Channel(ch.sz_max)
@adjoint! function put!(ch::Channel, x)
put!(ch, x), function (ȳ)
x̄ = take!(grad_mut(__context__, ch))
(nothing, accum(x̄, ȳ), nothing)
return (nothing, accum(x̄, ȳ))
end
end

Expand All @@ -64,36 +64,36 @@ end
end
end

@adjoint! function Task(f)
function _pullback(__context__::AContext, ::Type{<:Task}, f)
t = Task(f)
t.code = function ()
y, back = _pullback(__context__, f)
cache(__context__)[t] = Task(back)
y, _back = _pullback(__context__, f)
cache(__context__)[t] = Task(_back) # when `fetch`ed, this returns a tuple: (f̄,)
return y
end
t, _ -> fetch(cache(__context__)[t])
return t, _ -> (DoesNotExist(), first(fetch(cache(__context__)[t])))
end

function runadjoint(cx, t, ȳ = nothing)
function runadjoint(cx, t, ȳ = DoesNotExist())
t̄ = cache(cx)[t]
f = t̄.code
t̄.code = () -> differential2legacy(f(legacy2differential(ȳ)))
t̄.code = () -> f(ȳ)
@static if VERSION > v"1.3-"
t̄.sticky = t.sticky
end
schedule(t̄)
end

@adjoint! function wait(t::Task)
wait(t), _ -> (runadjoint(__context__, t); nothing)
# Pullback for `wait(t)`: the forward pass waits on the task; the backward pass ignores the
# incoming differential, calls `runadjoint` for `t`, and returns `DoesNotExist()`.
function _pullback(__context__::AContext, ::typeof(wait), t::Task)
  result = wait(t)
  function wait_back(_)
    runadjoint(__context__, t)
    return DoesNotExist()
  end
  return result, wait_back
end

@adjoint! function fetch(t::Task)
fetch(t), ȳ -> (runadjoint(__context__, t, ȳ); nothing)
# Pullback for `fetch(t)`: the forward pass fetches the task's result; the backward pass
# forwards the incoming differential `ȳ` to `runadjoint` and returns `DoesNotExist()`.
function _pullback(__context__::AContext, ::typeof(fetch), t::Task)
  result = fetch(t)
  function fetch_back(ȳ)
    runadjoint(__context__, t, ȳ)
    return DoesNotExist()
  end
  return result, fetch_back
end

@adjoint! function Base.sync_end(refs)
Base.sync_end(refs), _ -> foreach(t -> runadjoint(__context__, t), refs)
# Pullback for `Base.sync_end(refs)`: the backward pass ignores the incoming differential,
# calls `runadjoint` for every task in `refs`, and returns `DoesNotExist()`.
function _pullback(__context__::AContext, ::typeof(Base.sync_end), refs)
  result = Base.sync_end(refs)
  function sync_end_back(_)
    for task in refs
      runadjoint(__context__, task)
    end
    return DoesNotExist()
  end
  return result, sync_end_back
end

# Make @sync work
Expand Down
Loading