I am trying to reproduce the tutorial A Basic RNN using Flux.jl v0.14.6. Using the old Flux API, as in the tutorial, the model trains successfully. The code is:
using Flux

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)
    test_data = 2 .* train_data
    test_labels = 2 .* train_labels
    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2, 1, (x -> x))

function eval_model(x)
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(x, y) = abs(sum(eval_model(x) .- y))
evalcb() = @show(sum(loss.(test_data, test_labels)))

ps = Flux.params(model)
opt = Flux.ADAM(0.1)
for epoch in 1:num_epochs
    Flux.train!(loss, ps, zip(train_data, train_labels), opt, cb = Flux.throttle(evalcb, 1))
end
However, after refactoring the above code to use the new explicit API, Zygote complains:
using Flux
using Statistics

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)
    test_data = 2 .* train_data
    test_labels = 2 .* train_labels
    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2, 1, (x -> x))

function eval_model(model, x)
    # Comment out the following line to make it run.
    # However, the Flux docs say the following line is required.
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(model, x, y) = abs(sum(eval_model(model, x) .- y))

opt_state = Flux.setup(Flux.ADAM(0.1), model)
for epoch in 1:num_epochs
    for (x, y) in zip(train_data, train_labels)
        train_loss, grads = Flux.withgradient(model) do m
            loss(m, x, y)
        end
        Flux.update!(opt_state, model, grads[1])
    end
    test_loss = mean(loss.(Ref(model), test_data, test_labels))
    println("Epoch $epoch, loss = $test_loss")
end

# The following code also fails to run:
# for epoch in 1:num_epochs
#     Flux.train!(model, zip(train_data, train_labels), opt_state) do m, x, y
#         loss(m, x, y)
#     end
# end
Julia version info
julia> versioninfo()
Julia Version 1.10.0-beta2
Commit a468aa198d0 (2023-08-17 06:27 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 128 × Intel(R) Xeon(R) Platinum 8362 CPU @ 2.80GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, icelake-server)
Threads: 1 on 128 virtual cores
If you're ok with the initial state being non-trainable, then using one of the functions under https://juliadiff.org/ChainRulesCore.jl/stable/api.html#Ignoring-gradients on the reset! line should work. e.g. @ignore_derivatives Flux.reset!(model). Moving the call to reset! outside of the loss function would also do the trick.
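For concreteness, a minimal sketch of that workaround, assuming `@ignore_derivatives` is imported from ChainRulesCore (which Zygote, Flux's AD backend, respects); only the changed `eval_model` is shown:

```julia
using Flux
using ChainRulesCore: @ignore_derivatives

function eval_model(model, x)
    # Hide the state mutation in reset! from Zygote. The initial
    # state is then treated as a constant and gets no gradient.
    @ignore_derivatives Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end
```

The rest of the explicit-API training loop stays unchanged; only the call to `reset!` is fenced off from differentiation.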
I'm not sure why it fails. The RNN API is a weird one because it uses some of the implicit mode machinery even when you use explicit mode.
If I have extra data to train the initial state for each time sequence, how should I do that?
If you want to have separate initial states for each sample like you mentioned in #2185 (comment), the best bet would be to use the underlying RNN cell API (e.g. RNN -> RNNCell) and write your own loop over the timesteps. It'll be more manual work than using the Recur-based API, but it should just work and also avoid the MethodError shown above.
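A sketch of that cell-level loop, assuming Flux v0.14's `RNNCell` calling convention `(h, y) = cell(h, x)`; the names `eval_cell` and `h0` are illustrative, not part of the Flux API:

```julia
using Flux

# One cell, hidden state threaded by hand instead of via Recur.
cell = Flux.RNNCell(2 => 1, identity)

# h0 can be supplied per sample, e.g. computed from extra data,
# and can itself be trained since it flows through the loop.
function eval_cell(cell, h0, x)
    h = h0
    local y
    for t in axes(x, 2)
        h, y = cell(h, view(x, :, t))
    end
    first(y)
end
```

Because there is no hidden `reset!` and no shared mutable state, this differentiates cleanly with `Flux.withgradient` in explicit mode, and gradients with respect to `h0` come for free.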