Skip to content

Commit

Permalink
Merge pull request #25 from FluxML/bc/cr-zeroes
Browse files Browse the repository at this point in the history
Pass through ChainRules zeros in @adjoint too
  • Loading branch information
ToucheSir authored Mar 21, 2023
2 parents da51da7 + 369160f commit 19414a4
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Manifest.toml
21 changes: 0 additions & 21 deletions Manifest.toml

This file was deleted.

1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

[compat]
Expand Down
11 changes: 9 additions & 2 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using MacroTools
using MacroTools: @q, combinedef
using ChainRulesCore: AbstractZero

function named(arg)
if isexpr(arg, :(::)) && length(arg.args) == 1
Expand Down Expand Up @@ -63,13 +64,19 @@ function gradm(ex, mut = false, keepthunks = false)
$adj
@inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...)
y, _back = adjoint(__context__, $f, $(argnames...))
$(mut ? nothing : :(back(::Nothing) = nothing))
$(mut ? nothing : quote
back(::Nothing) = nothing
back::AbstractZero) = $gradtuple(ntuple(_ -> Δ, $(length(args))))
end)
back(Δ) = $gradtuple(_back($maybe_unthunked_Δ))
return y, back
end
@inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...)
y, _back = adjoint(__context__, $f, $(argnames...); kw...)
$(mut ? nothing : :(back(::Nothing) = nothing))
$(mut ? nothing : quote
back(::Nothing) = nothing
back::AbstractZero) = $gradtuplekw(ntuple(_ -> Δ, $(length(args))))
end)
back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ))
return y, back
end
Expand Down

0 comments on commit 19414a4

Please sign in to comment.