Inspecting GRU gates using TensorFlow 2.0 and Keras

In the past weeks I've been working on some neural network models for a regression task. Given the temporal nature of my data, a recurrent model is a good fit, so I decided to use a GRU layer in my architecture.

If you're reading this, you likely already know what a Gated Recurrent Unit is; if not, this and this will give you some context.

Long story short, a GRU is a recurrent neural network defined by the following equations:

\[ z_{t} = σ_{g}(W_{z}x_{t}+U_{z}h_{t-1}+b_{z}) \]

\[ r_{t} = σ_{g}(W_{r}x_{t}+U_{r}h_{t-1}+b_{r}) \]

\[ h_{t} = z_{t}\odot h_{t-1}+(1-z_{t})\odot \phi_{h}(W_{h}x_{t}+U_{h}(r_{t}\odot h_{t-1})+b_{h})\]

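Here $z_{t}$ and $r_{t}$ are the update and reset gates, $σ_{g}$ is the gate activation (usually a sigmoid) and $\phi_{h}$ is the candidate activation (usually tanh). The reset gate $r_{t}$ controls how much of the previous state $h_{t-1}$ flows into the candidate state, while the update gate $z_{t}$ controls how much of that candidate gets mixed with the previous state before the result is output.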
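To make the roles of the two gates concrete, here is a minimal pure-Python sketch of one GRU step with a single unit and a single input feature, so every weight matrix collapses to a scalar. It assumes $σ_{g}$ is the logistic sigmoid and $\phi_{h}$ is tanh; the function name `gru_step` and all parameter values are hypothetical, chosen only for illustration.

```python
import math


def sigmoid(a):
    """Logistic function, assumed here as the gate activation sigma_g."""
    return 1.0 / (1.0 + math.exp(-a))


def gru_step(x_t, h_prev, p):
    """One step of a single-unit GRU following the equations above.

    p is a dict of hypothetical scalar weights. Returns the new state
    together with the reset and update gate activations: (h, r, z).
    """
    z = sigmoid(p["W_z"] * x_t + p["U_z"] * h_prev + p["b_z"])  # update gate
    r = sigmoid(p["W_r"] * x_t + p["U_r"] * h_prev + p["b_r"])  # reset gate
    # candidate state: the reset gate scales h_{t-1} before U_h is applied
    h_cand = math.tanh(p["W_h"] * x_t + p["U_h"] * (r * h_prev) + p["b_h"])
    # the update gate mixes the previous state with the candidate
    h = z * h_prev + (1.0 - z) * h_cand
    return h, r, z


# Hypothetical parameters, for illustration only
params = {"W_z": 0.5, "U_z": -0.3, "b_z": 0.0,
          "W_r": 0.8, "U_r": 0.2, "b_r": 0.1,
          "W_h": 1.2, "U_h": 0.7, "b_h": 0.0}

h = 0.0
for x in [1.0, 0.5, -0.2]:
    h, r, z = gru_step(x, h, params)
    print(f"h={h:.4f}  r={r:.4f}  z={z:.4f}")
```

Pushing $b_{z}$ to a large positive value drives $z_{t}$ towards 1, and the unit then simply copies $h_{t-1}$: this is the kind of behaviour that inspecting the gate activations makes visible.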

For debugging purposes, besides the output sequence and the last state, I also wanted to look at the gates' activations. Unfortunately, contrary to my expectations, that wasn't straightforward, so I'm writing this to help others who may have the same need.

I tried the naive approach first: modifying TensorFlow's recurrent.py to save the tensors. Of course, it didn't work.

Next try: create a custom Layer with two class variables to hold the two lists of tensors. No luck: you can't access those variables later using a getter method (a.k.a. if you don't read all the code, you can't expect things to work).

After these two attempts I followed TensorFlow's docs and properly implemented a subclass of GRUCell.

I ended up with this code:

import tensorflow as tf
from tensorflow.keras import backend as K

class CustomGRUCell(tf.keras.layers.GRUCell):
    """GRUCell that also outputs its reset (r) and update (z) gate activations."""

    def __init__(self, units, **kwargs):
        super(CustomGRUCell, self).__init__(units, **kwargs)

    def build(self, input_shape):
        super(CustomGRUCell, self).build(input_shape)

    def call(self, inputs, states, training=None):
        h_tm1 = states[0]  # previous memory
        # Same GRUCell code, omitted for brevity
        # ....

        z = self.recurrent_activation(x_z + recurrent_z)  # update gate
        r = self.recurrent_activation(x_r + recurrent_r)  # reset gate

        # reset gate applied after/before the recurrent matrix multiplication
        if self.reset_after:
            recurrent_h = r * recurrent_h
        else:
            recurrent_h = K.dot(r * h_tm1,
                                self.recurrent_kernel[:, 2 * self.units:])

        hh = self.activation(x_h + recurrent_h)  # candidate state
        # previous and candidate state mixed by update gate
        h = z * h_tm1 + (1 - z) * hh

        # the output is the tuple (h, r, z); the next state is still just [h]
        return (h, r, z), [h]
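Note that the `if self.reset_after` branch in the code above does not compute exactly the candidate state written in the $h_{t}$ equation at the top of this post. When `reset_after` is true, Keras applies the reset gate after the recurrent matrix multiplication and keeps separate input and recurrent biases. Writing those biases as $b_{h}^{(x)}$ and $b_{h}^{(h)}$ (notation introduced here for illustration), the candidate and the new state become:

\[ \tilde{h}_{t} = \phi_{h}(W_{h}x_{t}+b_{h}^{(x)}+r_{t}\odot(U_{h}h_{t-1}+b_{h}^{(h)})) \]

\[ h_{t} = z_{t}\odot h_{t-1}+(1-z_{t})\odot \tilde{h}_{t} \]

The `else` branch corresponds to the original formulation, where $r_{t}$ multiplies $h_{t-1}$ before $U_{h}$ is applied. Either way $r_{t}$ and $z_{t}$ are computed in the same way, so the returned gate activations have the same meaning.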

To access the gate values, we need to build a model that outputs them:


input_tensor = tf.keras.layers.Input(shape=(timesteps, features),
                                     name="input")
cell = CustomGRUCell(256)
s_gru, states = tf.keras.layers.RNN(cell,
                                    return_sequences=True,
                                    return_state=True)(input_tensor)
out = tf.keras.layers.Dense(1, activation='linear', name="out")(s_gru[0])
model = tf.keras.models.Model(inputs=input_tensor,
                              outputs=[out, s_gru[1], s_gru[2]])
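Why does `s_gru` end up holding three sequences? The toy loop below is a hypothetical, pure-Python stand-in for what a recurrent layer with `return_sequences=True` and `return_state=True` does conceptually; it is not how `tf.keras.layers.RNN` is actually implemented. At every timestep the cell returns an output and the next state, and when the output is a tuple each component is collected over time separately. `unroll` and `toy_cell` are made-up names, and the constant gate values in `toy_cell` only keep the arithmetic easy to follow.

```python
def unroll(cell, inputs, initial_state):
    """Hypothetical RNN loop: the cell maps (x_t, states) to (output, states).

    If output is a tuple, each of its components gets its own sequence,
    mirroring how s_gru[0], s_gru[1] and s_gru[2] are separate sequences.
    """
    state = initial_state
    sequences = None
    for x_t in inputs:
        output, state = cell(x_t, state)
        if sequences is None:
            sequences = tuple([] for _ in output)  # one list per component
        for seq, value in zip(sequences, output):
            seq.append(value)
    return sequences, state


def toy_cell(x_t, states):
    """Hypothetical cell with fixed gates; returns (h, r, z) and [h]."""
    h_prev = states[0]
    r, z = 0.5, 0.25  # constant "gate activations", for illustration only
    h = z * h_prev + (1 - z) * x_t
    return (h, r, z), [h]


(h_seq, r_seq, z_seq), last_state = unroll(toy_cell, [1.0, 2.0], [0.0])
print(h_seq, r_seq, z_seq, last_state)
```

Here `h_seq`, `r_seq` and `z_seq` play the roles of `s_gru[0]`, `s_gru[1]` and `s_gru[2]`, while the state threaded through the loop stays the single-element list `[h]`, mirroring the `[h]` returned by `CustomGRUCell.call`. In Keras the sequences are tensors stacked along the time axis rather than Python lists.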

Since we are using a multi-output model, we must remember to train only on the dense output, without adding the losses computed for s_gru[1] and s_gru[2] to the total loss:

model.compile(optimizer=optimizer, loss='mae', loss_weights=[1., 0., 0.])
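Conceptually, Keras combines the per-output losses of a multi-output model into a single scalar as a weighted sum (plus any regularization losses). The short pure-Python sketch below, with a hypothetical `total_loss` helper and made-up loss values, shows why `loss_weights=[1., 0., 0.]` leaves only the dense output's loss in the total, so the gate outputs contribute nothing to it or to its gradients.

```python
def total_loss(per_output_losses, loss_weights):
    """Hypothetical helper: weighted sum of per-output losses
    (regularization terms ignored)."""
    if len(per_output_losses) != len(loss_weights):
        raise ValueError("need one weight per output")
    return sum(w * l for w, l in zip(loss_weights, per_output_losses))


# Made-up per-output MAE values for [out, r sequence, z sequence]
losses = [0.42, 0.37, 0.55]
print(total_loss(losses, [1.0, 0.0, 0.0]))  # prints 0.42
```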

TensorFlow 1.x

Originally I was trying to achieve this with TF 1.x, which is why I said it wasn't straightforward. Take a look at these Stack Overflow answers: one and two.

I had problems defining the output_size and state_size properties of the cell: output_size was a read-only property, and I was getting errors when using the @property decorator in the subclass.

The key to sorting things out is specifying state_size correctly; note that it can be a tuple.

sudo shutdown -h now
Mick Hardins
Artificial Intelligence student

I’m currently living. Chances are that I’ll probably die in the future.