Skip to content

Commit

Permalink
Fix JAX RNN backend issue. (#924)
Browse files Browse the repository at this point in the history
  • Loading branch information
qlzh727 authored Sep 19, 2023
1 parent c64de55 commit f2c3766
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 3 deletions.
11 changes: 9 additions & 2 deletions keras_core/backend/jax/rnn.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import contextlib

import tree
from jax import lax
from jax import numpy as jnp

from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common import stateless_scope
from keras_core.utils.nest import pack_sequence_as


Expand Down Expand Up @@ -181,7 +183,12 @@ def _step(states, current_input):

scan_xs = inputs

with StatelessScope():
if stateless_scope.in_stateless_scope():
# Reuse the existing parent stateless scope.
scope = contextlib.nullcontext()
else:
scope = stateless_scope.StatelessScope()
with scope:
# We must use a stateless scope because `scan` will involve
# JAX tracing -- any variable update at this stage would
# be a leak.
Expand Down
3 changes: 3 additions & 0 deletions keras_core/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,12 @@
from keras_core.layers.rnn.conv_lstm2d import ConvLSTM2D
from keras_core.layers.rnn.conv_lstm3d import ConvLSTM3D
from keras_core.layers.rnn.gru import GRU
from keras_core.layers.rnn.gru import GRUCell
from keras_core.layers.rnn.lstm import LSTM
from keras_core.layers.rnn.lstm import LSTMCell
from keras_core.layers.rnn.rnn import RNN
from keras_core.layers.rnn.simple_rnn import SimpleRNN
from keras_core.layers.rnn.simple_rnn import SimpleRNNCell
from keras_core.layers.rnn.stacked_rnn_cells import StackedRNNCells
from keras_core.layers.rnn.time_distributed import TimeDistributed
from keras_core.saving import serialization_lib
Expand Down
18 changes: 18 additions & 0 deletions keras_core/layers/rnn/rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,10 @@ def call(
initial_state,
)

# Prepopulate the dropout state so that the inner_loop is stateless
# this is particularly important for JAX backend.
self._maybe_config_dropout_masks(self.cell, sequences, initial_state)

last_output, outputs, states = self.inner_loop(
sequences=sequences,
initial_state=initial_state,
Expand Down Expand Up @@ -421,6 +425,20 @@ def call(
return output, *states
return output

def _maybe_config_dropout_masks(self, cell, input_sequence, input_state):
step_input = input_sequence[:, 0, :]
state = (
input_state[0]
if isinstance(input_state, (list, tuple))
else input_state
)
if isinstance(cell, DropoutRNNCell):
cell.get_dropout_mask(step_input)
cell.get_recurrent_dropout_mask(state)
if isinstance(cell, StackedRNNCells):
for c, s in zip(cell.cells, input_state):
self._maybe_config_dropout_masks(c, input_sequence, s)

def _maybe_reset_dropout_masks(self, cell):
if isinstance(cell, DropoutRNNCell):
cell.reset_dropout_mask()
Expand Down
3 changes: 2 additions & 1 deletion keras_core/layers/rnn/stacked_rnn_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,15 @@ def call(self, inputs, states, training=False, **kwargs):
# Call the cells in order and store the returned states.
new_states = []
for cell, states in zip(self.cells, states):
state_is_list = tree.is_nested(states)
states = list(states) if tree.is_nested(states) else [states]
if isinstance(cell, Layer) and cell._call_has_training_arg:
kwargs["training"] = training
else:
kwargs.pop("training", None)
cell_call_fn = cell.__call__ if callable(cell) else cell.call
inputs, states = cell_call_fn(inputs, states, **kwargs)
if len(states) == 1:
if len(states) == 1 and not state_is_list:
states = states[0]
new_states.append(states)

Expand Down
51 changes: 51 additions & 0 deletions keras_core/layers/rnn/stacked_rnn_cells_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,57 @@ def test_basics(self):
supports_masking=True,
custom_objects={"TwoStatesRNNCell": TwoStatesRNNCell},
)
self.run_layer_test(
layers.RNN,
init_kwargs={
"cell": [
layers.SimpleRNNCell(3, dropout=0.1, recurrent_dropout=0.1),
layers.SimpleRNNCell(4, dropout=0.1, recurrent_dropout=0.1),
layers.SimpleRNNCell(5, dropout=0.1, recurrent_dropout=0.1),
],
"return_sequences": True,
},
input_shape=(2, 3, 4),
expected_output_shape=(2, 3, 5),
expected_num_trainable_weights=9,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
supports_masking=True,
)
self.run_layer_test(
layers.RNN,
init_kwargs={
"cell": [
layers.GRUCell(3, dropout=0.1, recurrent_dropout=0.1),
layers.GRUCell(4, dropout=0.1, recurrent_dropout=0.1),
layers.GRUCell(5, dropout=0.1, recurrent_dropout=0.1),
],
"return_sequences": True,
},
input_shape=(2, 3, 4),
expected_output_shape=(2, 3, 5),
expected_num_trainable_weights=9,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
supports_masking=True,
)
self.run_layer_test(
layers.RNN,
init_kwargs={
"cell": [
layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),
layers.LSTMCell(4, dropout=0.1, recurrent_dropout=0.1),
layers.LSTMCell(5, dropout=0.1, recurrent_dropout=0.1),
],
"return_sequences": True,
},
input_shape=(2, 3, 4),
expected_output_shape=(2, 3, 5),
expected_num_trainable_weights=9,
expected_num_non_trainable_weights=0,
expected_num_seed_generators=0,
supports_masking=True,
)

def test_correctness_single_state_stack(self):
sequence = np.arange(24).reshape((2, 3, 4)).astype("float32")
Expand Down

0 comments on commit f2c3766

Please sign in to comment.