Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Un-collapse nothings in gradient #1495

Merged
merged 2 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ Requires = "1.1"
SpecialFunctions = "1.6, 2"
Statistics = "1"
Tracker = "0.2"
ZygoteRules = "0.2.4"
ZygoteRules = "0.2.5"
julia = "1.6"

[extras]
Expand Down
66 changes: 60 additions & 6 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,52 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
# Return every element of a tuple after the first one;
# a `nothing` input is handed back unchanged.
function tailmemaybe(grads::Tuple)
    return Base.tail(grads)
end
tailmemaybe(::Nothing) = nothing

"""
pullback(f, args...)
pullback(f, ::Params)

Returns the value of the function `f` and a back-propagator function,
which can be called to obtain a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or gradient.

```julia
y, back = pullback(f, args...)
∇ = back(seed)
```

`back` must be called with a start value `seed` matching the output of `f(args...)`.
If `f(args...)` returns a number, `seed` should be a number.
If `f(args...)` returns an array, `seed` should be an equally-sized array.

See also [`withgradient`](@ref) to obtain the value and gradients in one call,
and [`gradient`](@ref) for obtaining just the gradients.

```jldoctest; setup=:(using Zygote)
julia> y, back = pullback(*, 2.0, 3.0, 5.0);

julia> y
30.0

julia> back(1.0)
(15.0, 10.0, 6.0)

julia> back(2.0)
(30.0, 20.0, 12.0)

julia> y, back = pullback(x -> [x, x], 1.0);

julia> y
2-element Vector{Float64}:
1.0
1.0

julia> back([1.0, 1.0])
(2.0,)

julia> back([2.0, nothing])
(2.0,)
```
"""
@inline pullback(f, args...) = pullback(f, Context(), args...)
function pullback(f, cx::AContext, args...)
y, back = _pullback(cx, f, args...)
Expand Down Expand Up @@ -67,11 +113,16 @@ sensitivity(y::Complex) = error("Output is complex, so the gradient is not defin
# Array outputs are rejected with a hint towards `jacobian`.
function sensitivity(y::AbstractArray)
    return error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
end

# Fallback for any output not handled by a more specific method:
# report the offending value in the error message.
function sensitivity(y)
    return error("Output should be scalar; gradients are not defined for output $(repr(y))")
end

# Preserves output as a tuple (one entry per argument) when the gradients have
# been collapsed into a single `nothing`.
# The method accepts any `Tuple`: `NTuple{N}` only matches tuples whose elements
# all share one concrete type, so mixed-type arguments such as `(1, 2.0)` would
# otherwise raise a `MethodError`.
_project_all(x::Tuple, ::Nothing) = map(_ -> nothing, x)
# When a tuple of gradients is available, apply `_project` to each
# (argument, gradient) pair.
_project_all(x::Tuple, dx::Tuple) = map(_project, x, dx)

"""
gradient(f, args...)

Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar `x`) or the gradient.
If no gradient is defined, `∂f/∂x` will be `nothing`.

`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.

Expand All @@ -95,7 +146,7 @@ julia> gradient([7, 11], 0, 1) do x, y, d
function gradient(f, args...)
    # Forward pass: the value of `f(args...)` plus its back-propagator.
    y, back = pullback(f, args...)
    # `sensitivity(y)` provides the seed passed to `back`; its array and
    # fallback methods raise an error instead of returning a seed.
    grad = back(sensitivity(y))
    # `_project_all` returns one entry per argument, including the case where
    # `back` collapsed everything into a single `nothing`.
    return _project_all(args, grad)
end

# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
Expand All @@ -109,7 +160,7 @@ end
withgradient(f, ::Params)

Returns both the value of the function and the [`gradient`](@ref),
as a named tuple.
as a named tuple.

```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
Expand Down Expand Up @@ -161,7 +212,7 @@ function withgradient(f, args...)
else
back(sensitivity(y))
end
results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
results = _project_all(args, grad)
(val=y, grad=results)
end

Expand Down Expand Up @@ -304,7 +355,7 @@ end
Grads(...)

Dictionary-like container returned when taking gradients with
respect to implicit parameters. For an array `W`, appearing
respect to implicit parameters. For an array `W`, appearing
within `Params([W, A, B...])`, the gradient is `g[W]`.
"""
struct Grads
Expand All @@ -321,7 +372,7 @@ const ADictOrGrads = Union{AbstractDict, Grads}

# Dictionary interface.
# Don't use the IdDict directly since it may contain some spurious pairs.
# Membership and keys are answered from `gs.params`; values are looked up in
# `gs.grads` for exactly those params, in the same order.
Base.haskey(gs::Grads, x) = x ∈ gs.params
Base.keys(gs::Grads) = gs.params
Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)

Expand Down Expand Up @@ -381,7 +432,7 @@ broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
# Broadcast `f` over a `Grads` on the left and a `Numeric` on the right:
# each element `g` visited by `map` over `gs` is combined as `f(g, a)`.
function broadcasted(f, gs::Grads, a::Numeric)
    return map(gs) do g
        f(g, a)
    end
end

function materialize!(gs1::Grads, gs2::Grads)
issetequal(gs1.params, gs2.params) ||
issetequal(gs1.params, gs2.params) ||
throw(ArgumentError("Expected Grads objects with the same Params."))
for p in gs1.params
gs1[p] = gs2[p]
Expand Down Expand Up @@ -421,6 +472,9 @@ function pullback(f, ps::Params)
end
end

# A `Grads` container needs no conversion: it is returned unchanged,
# whatever the first argument is.
function _project_all(::Any, grads::Grads)
    return grads
end

# Code Reflection

function code_ir(f, T)
Expand Down
4 changes: 2 additions & 2 deletions test/lib/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
@test gradient(floor, 1) === (0.0,)
@test gradient(ceil, 1) === (0.0,)
@test gradient(round, 1) === (0.0,)
@test gradient(hash, 1) === nothing
@test gradient(div, 1, 2) === nothing
@test gradient(hash, 1) === (nothing,)
@test gradient(div, 1, 2) === (nothing, nothing)
end

@testset "basics" begin
Expand Down
2 changes: 1 addition & 1 deletion test/structures.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ end
end

m, b = Zygote._pullback(Zygote.Context(), nameof, M)
@test b(m) == (nothing, nothing)
@test b(m) === nothing
end
Loading