Flux new explicit API does not work but old implicit API works for a simple RNN #2341

Open · liuyxpp opened this issue Sep 27, 2023 · 4 comments · May be fixed by #2500

liuyxpp commented Sep 27, 2023

I am trying to reproduce the tutorial A Basic RNN with Flux.jl v0.14.6. Using the old implicit Flux API as in the tutorial, the model trains successfully. The code is:

using Flux

num_samples = 1000
num_epochs = 50

# Each sample is a 2×T Float32 matrix of random length; its label is the sum of its entries.
function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2, 1, (x -> x))

# Feed the sequence one timestep (one column) at a time; keep the last output.
function eval_model(x)
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(x, y) = abs(sum(eval_model(x) .- y))

evalcb() = @show(sum(loss.(test_data, test_labels)))

ps = Flux.params(model)

opt = Flux.ADAM(0.1)

for epoch in 1:num_epochs
    Flux.train!(loss, ps, zip(train_data, train_labels), opt, cb = Flux.throttle(evalcb, 1))
end

However, when I refactor the above code to use the new explicit API, Zygote complains:

ERROR: LoadError: MethodError: no method matching +(::@NamedTuple{cell::@NamedTuple{σ::Nothing, Wi::Matrix{Float32}, Wh::Matrix{Float32}, b::Vector{Float32}, state0::Nothing}, state::Matrix{Float32}}, ::Base.RefValue{Any})

The code is as follows:

using Flux
using Statistics

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2, 1, (x -> x))

function eval_model(model, x)
    # Commenting out the following line makes it run.
    # However, the Flux docs say this line is required.
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(model, x, y) = abs(sum(eval_model(model, x) .- y))

opt_state = Flux.setup(Flux.ADAM(0.1), model)

for epoch in 1:num_epochs
    for (x, y) in zip(train_data, train_labels)
        train_loss, grads = Flux.withgradient(model) do m
            loss(m, x, y)
        end
        Flux.update!(opt_state, model, grads[1])
    end
    test_loss = mean(loss.(Ref(model), test_data, test_labels))
    println("Epoch $epoch, loss = $test_loss")
end

# The following code also fails to run.
# for epoch in 1:num_epochs
#     Flux.train!(model, zip(train_data, train_labels), opt_state) do m, x, y
#         loss(m, x, y)
#     end
# end

Julia version info

julia> versioninfo()
Julia Version 1.10.0-beta2
Commit a468aa198d0 (2023-08-17 06:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × Intel(R) Xeon(R) Platinum 8362 CPU @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, icelake-server)
  Threads: 1 on 128 virtual cores
ToucheSir (Member) commented

If you're ok with the initial state being non-trainable, then using one of the functions under https://juliadiff.org/ChainRulesCore.jl/stable/api.html#Ignoring-gradients on the reset! line should work, e.g. @ignore_derivatives Flux.reset!(model). Moving the call to reset! outside of the loss function would also do the trick.
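
A minimal sketch of the first suggestion (assuming Flux v0.14; @ignore_derivatives is ChainRulesCore's macro, so ChainRulesCore must be in the environment):

using Flux
using ChainRulesCore: @ignore_derivatives

function eval_model(model, x)
    # Hide the state reset from Zygote, so no gradient is accumulated
    # into the (non-trainable) initial state.
    @ignore_derivatives Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

Alternatively, move reset! out of the differentiated code:

for (x, y) in zip(train_data, train_labels)
    Flux.reset!(model)   # reset before taking the gradient, not inside the loss
    train_loss, grads = Flux.withgradient(m -> loss(m, x, y), model)
    Flux.update!(opt_state, model, grads[1])
end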

liuyxpp (Author) commented Sep 27, 2023

Ah, thanks! Can you explain a bit more why this fails in explicit mode but not in implicit mode?

BTW, if I have extra data to train the initial state for each time sequence, how should I do that?

ToucheSir (Member) commented

I'm not sure why it fails. The RNN API is a weird one because it uses some of the implicit mode machinery even when you use explicit mode.

if I have extra data to train the initial state for each time sequence, how should I do that?

If you want to have separate initial states for each sample like you mentioned in #2185 (comment), the best bet would be to use the underlying RNN cell API (e.g. RNN -> RNNCell) and write your own loop over the timesteps. It'll be more manual work than using the Recur-based API, but it should just work and also avoid the MethodError shown above.
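
A minimal sketch of the cell-based loop (assuming Flux v0.14, where a cell is called as cell(h, x) and returns (new_state, output); eval_cell and h0 are illustrative names, not Flux API):

using Flux

cell = Flux.RNNCell(2 => 1, identity)   # RNN -> RNNCell, as suggested above

# Run the cell over the columns (timesteps) of one sample,
# starting from a caller-supplied initial state h0.
function eval_cell(cell, h0, x)
    h, y = h0, nothing
    for t in axes(x, 2)
        h, y = cell(h, view(x, :, t))
    end
    first(y)
end

# h0 can now itself be trainable, e.g. one state per sample:
h0 = zeros(Float32, 1)   # state length matches the cell's output size
eval_cell(cell, h0, rand(Float32, 2, 4))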

liuyxpp (Author) commented Sep 27, 2023

Got it. I will report back once I figure it out. Many thanks!

@mcabbott mcabbott added the RNN label Oct 2, 2023
@CarloLucibello CarloLucibello added this to the v0.15 milestone Oct 14, 2024