How To Implement An RNN #51

Answered by danieldjohnson
BeeGass asked this question in Q&A
Thanks for the question!

I haven't done much with RNNs in Penzai yet, so I'm not sure what the most idiomatic approach would be. But here are my first thoughts.

An RNN cell seems different from most Penzai layers because

  • it needs to take two inputs and produce two outputs, with those inputs/outputs being handled differently, and
  • it needs to initialize its carry.

So it might make sense for LSTMCell to be a different kind of pz.Struct, and not be a subclass of pz.nn.Layer (because a pz.nn.Layer always takes one primary input, not two). Perhaps something like this:

import abc
from penzai.experimental.v2 import pz

class RNNCell(pz.Struct, abc.ABC):
  """Abstract base class for RNN cells."""
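
To make the two points above concrete (two inputs and two outputs, plus a way to initialize the carry), here is a minimal standalone sketch in plain Python using only the standard library. It deliberately does not use Penzai, and every name in it (`RNNCellSketch`, `initial_carry`, `step`, `ElmanCellSketch`, `run_rnn`) is hypothetical and chosen for illustration; none of them come from Penzai or from the snippet above, which is cut off after the docstring.

```python
import abc
import dataclasses
import math


class RNNCellSketch(abc.ABC):
  """Hypothetical interface: a cell maps (carry, x) to (new_carry, y)."""

  @abc.abstractmethod
  def initial_carry(self) -> list[float]:
    """Returns the carry to use before the first time step."""

  @abc.abstractmethod
  def step(self, carry: list[float], x: float) -> tuple[list[float], float]:
    """Consumes one input and returns (new_carry, output)."""


@dataclasses.dataclass(frozen=True)
class ElmanCellSketch(RNNCellSketch):
  """Toy Elman-style cell: scalar input, one weight pair per hidden unit."""
  w_in: tuple[float, ...]   # input-to-hidden weights (assumed shape: one per unit)
  w_rec: tuple[float, ...]  # per-unit recurrent weights (diagonal, for brevity)

  def initial_carry(self) -> list[float]:
    # The cell owns its carry initialization: all-zero hidden state.
    return [0.0] * len(self.w_in)

  def step(self, carry, x):
    # Two inputs (carry, x) in, two outputs (new_carry, y) out.
    new_carry = [
        math.tanh(wi * x + wr * h)
        for wi, wr, h in zip(self.w_in, self.w_rec, carry)
    ]
    return new_carry, sum(new_carry)


def run_rnn(cell: RNNCellSketch, inputs):
  """Threads the carry through the sequence and collects per-step outputs."""
  carry = cell.initial_carry()
  outputs = []
  for x in inputs:
    carry, y = cell.step(carry, x)
    outputs.append(y)
  return carry, outputs
```

For example, `run_rnn(ElmanCellSketch(w_in=(1.0, 0.5), w_rec=(0.1, 0.1)), [1.0, 0.0])` returns the final carry together with a list of two outputs. In a JAX setting the Python loop in `run_rnn` would typically be replaced by `jax.lax.scan`, whose step function has the same `(carry, x) -> (carry, y)` shape.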

Replies: 2 comments 3 replies

Answer selected by BeeGass