Skip to content

Commit

Permalink
Merge pull request #1328 from FluxML/bc/rm-getindex-adjoint
Browse files Browse the repository at this point in the history
Excise getindex adjoint
  • Loading branch information
ToucheSir authored Oct 25, 2023
2 parents f755127 + 08e0cd8 commit fbe8271
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 28 deletions.
6 changes: 6 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,9 @@ macro nograd(ex)
end
return blk
end

# Internal function used by some downstream packages.
# Removing this completely would require some tricky registry changes,
# but leaving it as a vestigial function is much easier.
# See https://github.com/FluxML/Zygote.jl/pull/1328 for more context.
function ∇getindex end
18 changes: 0 additions & 18 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,24 +41,6 @@ end
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)

@adjoint getindex(x::AbstractArray, inds...) = x[inds...], ∇getindex(x, inds)

@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)

∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
if inds isa NTuple{N,Int} && T <: Number
dx = OneElement(dy, inds, axes(x))
elseif inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
dx = _zero(x, eltype(dy))
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
return (_project(x, dx), map(_->nothing, inds)...)
end

"""
OneElement(val, ind, axes) <: AbstractArray
Expand Down
4 changes: 2 additions & 2 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,11 +174,11 @@ end

# Ensure that nothings work with numeric types.
_, back = Zygote.pullback(getindex, randn(4), [1])
@test back([nothing]) == (zeros(4), nothing)
@test back([nothing]) === nothing

# Ensure that nothings work with non-numeric types.
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
@test back([nothing]) == (nothing, nothing)
@test back([nothing]) === nothing
end

@testset "view" begin
Expand Down
10 changes: 2 additions & 8 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,8 @@ using ForwardDiff
using Zygote: hessian_dual, hessian_reverse

@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse]

if hess == hessian_dual
@test hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0]
@test hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0] # original docstring version
else
@test_broken hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0] # can't differentiate ∇getindex
@test_broken hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0]
end
@test hess(x -> x[1]*x[2], randn(2)) [0 1; 1 0]
@test hess(((x,y),) -> x*y, randn(2)) [0 1; 1 0] # original docstring version
@test hess(x -> sum(x.^3), [1 2; 3 4]) Diagonal([6, 18, 12, 24])
@test hess(sin, pi/2) -1

Expand Down

0 comments on commit fbe8271

Please sign in to comment.