
RNNs redesign #2500

Open · wants to merge 9 commits into master
Conversation

@CarloLucibello (Member) commented Oct 14, 2024

A complete rework of our recurrent layers, making them more similar to their PyTorch counterparts.
This is in line with the proposal in #1365 and should allow us to hook into the cuDNN machinery (future PR).
Hopefully, this puts an end to the endless trouble the recurrent layers have been.

  • Recur is no more. Mutating its internal state was a source of problems for AD (explicit differentiation for RNN gives wrong results #2185).
  • RNNCell is now exported and takes care of the minimal recursion step, i.e. a single time step (see the sketch after this list):
    • forward pass: cell(x, h)
    • x can be of size in or in x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out or out x batch_size
  • RNN instead takes a (batched) sequence and a (batched) hidden state and returns the hidden states for the whole sequence:
    • forward pass: rnn(x, h)
    • x can be of size in x len or in x len x batch_size
    • h can be of size out or out x batch_size
    • returns hnew of size out x len or out x len x batch_size
  • LSTM and GRU are changed in the same way.
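
A minimal sketch of the interface described above, assuming the redesigned cells and layers keep Flux's usual `in => out` constructor convention; the dimension names and sizes are only illustrative:

```julia
using Flux

in_dim, out_dim, len, batch = 3, 5, 7, 16     # illustrative sizes

# Single step: the cell maps (x, h) to the new hidden state.
cell = RNNCell(in_dim => out_dim)
x = rand(Float32, in_dim, batch)              # in × batch_size
h = zeros(Float32, out_dim, batch)            # out × batch_size
hnew = cell(x, h)                             # out × batch_size

# Whole sequence: RNN maps (x, h) to the hidden states of every time step.
rnn = RNN(in_dim => out_dim)
xs = rand(Float32, in_dim, len, batch)        # in × len × batch_size
hs = rnn(xs, h)                               # out × len × batch_size
```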

Closes #2185, closes #2341, closes #2258, closes #1547, closes #807, closes #1329

Related to #1678

PR Checklist

  • cpu tests
  • gpu tests
  • if the hidden state is not given as input, assume it to be zero
  • port LSTM and GRU
  • Entry in NEWS.md
  • Remove reset!
  • Docstrings
  • add an option in constructors to have trainable initial state
  • Benchmarks
  • use cuDNN (future PR)
  • implement the num_layers argument for stacked RNNs (future PR)
  • revisit whole documentation (future PR)
  • add dropout (future PR)

@darsnack (Member) commented Oct 18, 2024

Fully agree with updating the design to be non-mutating. There are two options we've discussed in the past (both sketched below):

  1. y, h = cell(x, h) like here (I guess this PR removes y as a return value, which is fine)
  2. y, cell = cell(x) / y, cell = Flux.apply(cell, x)

Option 1 is outlined in this PR, so I won't say anything about it.

Option 2 is a more drastic redesign to make all layers (not just recurrent ones) non-mutating. Why?

  • It covers stateful layers in general (e.g. norm layers), not just recurrent cells.
  • It keeps a nice feature of Flux's current design, namely that the model contains all the info: parameters, state, flags, etc.
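
To make the contrast concrete, here is a rough sketch of the two call patterns. Option 1 follows the interface in this PR; the Flux.apply in Option 2 is only a proposal and does not exist in this form, and all names are illustrative:

```julia
# Option 1: the caller threads the hidden state explicitly through the loop.
function run_option1(cell, xs, h)              # xs is in × len × batch_size
    for t in 1:size(xs, 2)
        h = cell(xs[:, t, :], h)               # state goes in, new state comes out
    end
    return h
end

# Option 2 (hypothetical Flux.apply): the state lives inside the layer;
# each call returns the output together with a layer carrying the new state.
function run_option2(cell, xs)
    local y
    for t in 1:size(xs, 2)
        y, cell = Flux.apply(cell, xs[:, t, :])
    end
    return y, cell
end
```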

@CarloLucibello (Member, Author)

I thought about Option 2. On the upside, it seems a nice intermediate spot between current Flux and Lux. The downside is that the interface would seem a bit exotic to Flux and PyTorch users. Moreover, it would be problematic for normalization layers.

Also, we need to distinguish between normalization layers and recurrent layers.

  • Normalization layers at training time update some internal buffers within a stopgrad barrier. The buffer update has no influence on the output of the layer or on the final loss. You typically apply the layer only once during the forward pass, and normalization layers are typically part of larger models (chains or custom structs). Therefore, for normalization layers: 1) we haven't had the gradient-computation problems we had with recurrent layers; 2) you want the layer with the updated buffer to be inserted back into your model, but this would require either a mutating operation or returning a new model.

  • For recurrent layers, Option 2 would be sensible, but is it worth it? Once you adopt the perspective that a cell takes two inputs, x and h, and returns a new state, hnew, all the problems disappear (see the sketch below). I think trying to keep the state internal adds complexity for no gain.
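
With that perspective, the sequence-level layer reduces to a plain non-mutating loop over the cell. A rough sketch of what RNN could do internally (not the actual implementation; requires Julia ≥ 1.9 for Base.stack):

```julia
# Unroll a cell over a sequence without mutating anything.
function unroll(cell, xs::AbstractArray{<:Any,3}, h)    # xs is in × len × batch_size
    hs = map(axes(xs, 2)) do t
        h = cell(view(xs, :, t, :), h)                  # feed each step the previous state
    end
    return stack(hs; dims = 2)                          # out × len × batch_size
end
```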

@ToucheSir (Member)

The main benefit of keeping the state "internal", or of having it be part of a unified interface like apply, would be that Chain works with RNNs again. Whether that's worth the extra complexity is the question. Given our priorities, I think it's best left as future work.
