Parameters In Tensorflow Keras RNN and CUDNN RNN

Introduction

Recurrent Neural Networks (RNNs) are widely used in AI applications such as handwriting recognition and speech recognition. An RNN essentially consists of a series of matrix-vector multiplications, and there are two popular gating mechanisms: GRU and LSTM. Tensorflow Keras provides high-level APIs for these operations that are simple and productive to use, because Keras handles the parameter creation, initialization, and other preprocessing before calling the underlying libraries; on GPUs, Keras adopts CUDNN as the backend. On the other hand, some people prefer to call CUDNN directly in their projects for better efficiency and tighter control of the data flow, which may require porting Keras code or at least using its outputs as a reference. However, Keras and CUDNN handle parameters differently, leading to different parameter layouts, and this often causes confusion for developers who work with both APIs.

This post examines the GRU and LSTM layers from Keras, focusing on how the parameters are organized in Keras and what transformations are needed to make them compatible with CUDNN. It assumes the reader has a sufficient background in RNNs. The equations below are borrowed from the NVIDIA CUDNN documentation.

GRU

GRU equations
i_t = σ(W_i x_t + R_i h_{t-1} + b_{Wi} + b_{Ri})
r_t = σ(W_r x_t + R_r h_{t-1} + b_{Wr} + b_{Rr})
h'_t = tanh(W_h x_t + r_t ◦ (R_h h_{t-1} + b_{Rh}) + b_{Wh})
h_t = (1 - i_t) ◦ h'_t + i_t ◦ h_{t-1}

In these equations, the parameters consist of three kernel weights/biases (W_i, W_r, W_h, b_{Wi}, b_{Wr}, b_{Wh}) and three recurrent weights/biases (R_i, R_r, R_h, b_{Ri}, b_{Rr}, b_{Rh}). The two sets of biases indicate that a bias is added both to the product with the layer input x_t and to the product with the recurrent input h_{t-1}. In other words, the equations above represent the "double bias" mode defined by CUDNN; there are two other bias modes that apply the bias addition only to the layer input or only to the recurrent input.

Note that x_t and h_{t-1} are column vectors of shapes (inputSize, 1) and (hiddenSize, 1) respectively, which makes the matrix-vector multiplications possible. Therefore, each W weight matrix has shape (hiddenSize, inputSize) and each R weight matrix has shape (hiddenSize, hiddenSize), while the biases are always (hiddenSize, 1).
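
To make the notation concrete, here is a minimal NumPy transcription of one GRU step in the "double bias" form above. This is my own sketch rather than code from CUDNN or Keras; W, R, bW, and bR are assumed to be dictionaries keyed by the gate names i, r, and h, with the shapes just described.

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def gru_step(x_t, h_prev, W, R, bW, bR):
    """One GRU step following the 'double bias' equations above.

    W[g]: (hiddenSize, inputSize), R[g]: (hiddenSize, hiddenSize),
    bW[g], bR[g]: (hiddenSize, 1), for g in {'i', 'r', 'h'}.
    """
    i_t = sigmoid(W['i'] @ x_t + R['i'] @ h_prev + bW['i'] + bR['i'])
    r_t = sigmoid(W['r'] @ x_t + R['r'] @ h_prev + bW['r'] + bR['r'])
    h_new = np.tanh(W['h'] @ x_t + r_t * (R['h'] @ h_prev + bR['h']) + bW['h'])
    return (1.0 - i_t) * h_new + i_t * h_prev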

The above discussion follows the CUDNN formulas; by contrast, Keras interprets the weight-input multiplication in a slightly different way, more like x_t^T W^T or h_{t-1}^T R^T. As a result, the weights in Keras are stored transposed relative to CUDNN: the W weights are (inputSize, hiddenSize) and the R weights are (hiddenSize, hiddenSize). As for the implementation under the hood, CUDNN stores all weights and biases together as a single flat array, while Keras uses three arrays:

  1. Kernel weights: a (inputSize, 3 x hiddenSize) matrix containing the concatenated three W weights.
  2. Recurrent weights: a (hiddenSize, 3 x hiddenSize) matrix containing the concatenated three R weights.
  3. Biases: a (2, 3 x hiddenSize) matrix where the first row holds the concatenated W biases and the second row the concatenated R biases.

The following Python code prints the parameters stored in a Keras GRU layer:

import tensorflow as tf
import numpy as np
from tensorflow.keras import layers
# Internal module that provides the _canonical_to_params helper used later.
from tensorflow.python.keras.layers import recurrent_v2

batch_size = 1
seq_len = 2
input_size = 2
hidden_size = 3

tf.random.set_seed(seed=1)
x = tf.random.uniform((seq_len, batch_size, input_size))
gru = layers.GRU(hidden_size, time_major=True,
                 return_sequences=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='random_uniform')
y = gru(x)

np.set_printoptions(
    formatter={'float': lambda x: "{0:0.6f}".format(x)})

print("Keras Kernel Weights:", gru.get_weights()[0])
print("Keras Recurrent Weights:", gru.get_weights()[1])
print("Keras Biases:", gru.get_weights()[2])

Here I don’t copy/paste the outputs but visualize them with three colors to represent different types of weights/biases:

  • Green for r weights/biases
  • Red for i weights/biases
  • Yellow for h weights/biases

Keras Kernel Weights:

0.014929 -0.083409 -0.135106 0.727459 0.278675 -0.227695 -0.094435 0.149277 -0.064070
0.373260 -0.460859 0.072019 0.072253 0.073156 -0.325117 -0.577610 0.193369 0.552166

Keras Recurrent Weights:

-0.176383 -0.344644 -0.688634 -0.260896 -0.076115 -0.322728 0.278958 0.004496 0.346469
-0.204532 0.104082 -0.313509 0.492178 0.236306 0.117206 0.519950 -0.085155 -0.509539
0.308245 0.050380 -0.253974 -0.538845 0.241279 0.437976 -0.030054 -0.501773 -0.211831

Keras Biases:

-0.026355 -0.026123 0.000363 0.027354 0.011077 0.037218 -0.022715 0.011832 -0.029748
0.037008 -0.000759 -0.000307 -0.046988 0.018576 0.013157 -0.029216 -0.006088 -0.031105

Moreover, I added some printfs in the Tensorflow backend to trace the inputs/outputs of the CUDNN calls. The parameters passed directly to CUDNN are shown below, tinted with the same three colors.

CUDNN Parameters (Weights and Biases):

0.727459 0.072253 0.278675 0.073156 -0.227695 -0.325117 0.014929 0.373260 -0.083409
-0.460859 -0.135106 0.072019 -0.094435 -0.577610 0.149277 0.193369 -0.064070 0.552166
-0.260896 0.492178 -0.538845 -0.076115 0.236306 0.241279 -0.322728 0.117206 0.437976
-0.176383 -0.204532 0.308245 -0.344644 0.104082 0.050380 -0.688634 -0.313509 -0.253974
0.278958 0.519950 -0.030054 0.004496 -0.085155 -0.501773 0.346469 -0.509539 -0.211831
0.027354 0.011077 0.037218 -0.026355 -0.026123 0.000363 -0.022715 0.011832 -0.029748
-0.046988 0.018576 0.013157 0.037008 -0.000759 -0.000307 -0.029216 -0.006088 -0.031105

By comparing the parameters in Keras and CUDNN, we can observe three differences:

  1. CUDNN uses a single flat array, while Keras uses three separate arrays.
  2. The weights are transposed between CUDNN and Keras.
  3. The i and r weights/biases are swapped between CUDNN and Keras.

So, if one is given weights/biases extracted from Keras and needs to use them with CUDNN, a series of preprocessing steps is required: slicing, permuting, transposing, concatenating, etc. Fortunately, Tensorflow Keras ships a "secret" internal helper that can partially automate this preprocessing:

params = recurrent_v2._canonical_to_params(
    weights=[
        # Kernel weights in r, i, h order: the i and r blocks are swapped
        # relative to the Keras storage order.
        gru.get_weights()[0][:, hidden_size:hidden_size * 2],
        gru.get_weights()[0][:, :hidden_size],
        gru.get_weights()[0][:, hidden_size * 2:],
        # Recurrent weights, with the same swap.
        gru.get_weights()[1][:, hidden_size:hidden_size * 2],
        gru.get_weights()[1][:, :hidden_size],
        gru.get_weights()[1][:, hidden_size * 2:],
    ],
    biases=[
        # Bias row 0 (W biases) and row 1 (R biases), each with the same swap.
        gru.get_weights()[2][0][hidden_size:hidden_size * 2],
        gru.get_weights()[2][0][:hidden_size],
        gru.get_weights()[2][0][hidden_size * 2:hidden_size * 3],
        gru.get_weights()[2][1][hidden_size:hidden_size * 2],
        gru.get_weights()[2][1][:hidden_size],
        gru.get_weights()[2][1][hidden_size * 2:hidden_size * 3],
    ],
    shape=tf.constant([-1]),
    transpose_weights=True)
print("CUDNN-equivalent Params:", params)

Note that even though this is a single function call, we still have to do most of the slicing ourselves, and we must not forget to swap the i and r (i.e., the 0th and 1st) weights/biases.
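
Alternatively, the whole conversion can be written by hand in plain NumPy. The sketch below is my own (the function name keras_gru_to_cudnn is made up for illustration); it just replays the three differences listed earlier: transpose each gate block, swap the i and r blocks, and flatten everything into a single array. With the gru layer above, its output should match the flat array traced from CUDNN.

import numpy as np

def keras_gru_to_cudnn(kernel, recurrent, bias, hidden_size):
    """Flatten the three Keras GRU arrays into the CUDNN parameter layout.

    Keras stores the gate blocks in [i, r, h] order; CUDNN, as seen in the
    dump above, expects [r, i, h]. Each weight block is transposed and
    flattened row-major; the two bias rows follow, gate-swapped the same way.
    """
    order = [1, 0, 2]  # swap the i and r blocks
    w_blocks = [kernel[:, g * hidden_size:(g + 1) * hidden_size].T for g in order]
    r_blocks = [recurrent[:, g * hidden_size:(g + 1) * hidden_size].T for g in order]
    b_blocks = [bias[row, g * hidden_size:(g + 1) * hidden_size]
                for row in (0, 1) for g in order]
    return np.concatenate([b.reshape(-1) for b in w_blocks + r_blocks + b_blocks])

# Should match the flat array printed by CUDNN above:
# print(keras_gru_to_cudnn(*gru.get_weights(), hidden_size))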

LSTM

LSTM equations
i_t = σ(W_i x_t + R_i h_{t-1} + b_{Wi} + b_{Ri})
f_t = σ(W_f x_t + R_f h_{t-1} + b_{Wf} + b_{Rf})
o_t = σ(W_o x_t + R_o h_{t-1} + b_{Wo} + b_{Ro})
c'_t = tanh(W_c x_t + R_c h_{t-1} + b_{Wc} + b_{Rc})
c_t = f_t ◦ c_{t-1} + i_t ◦ c'_t
h_t = o_t ◦ tanh(c_t)
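
To mirror the GRU sketch earlier, here is a minimal NumPy transcription of one LSTM step in the same "double bias" form (again my own sketch, not code from either library; the dictionaries are keyed by the gate names i, f, o, c and the shapes follow the GRU conventions above).

import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def lstm_step(x_t, h_prev, c_prev, W, R, bW, bR):
    """One LSTM step following the 'double bias' equations above."""
    i_t = sigmoid(W['i'] @ x_t + R['i'] @ h_prev + bW['i'] + bR['i'])
    f_t = sigmoid(W['f'] @ x_t + R['f'] @ h_prev + bW['f'] + bR['f'])
    o_t = sigmoid(W['o'] @ x_t + R['o'] @ h_prev + bW['o'] + bR['o'])
    c_new = np.tanh(W['c'] @ x_t + R['c'] @ h_prev + bW['c'] + bR['c'])
    c_t = f_t * c_prev + i_t * c_new
    h_t = o_t * np.tanh(c_t)
    return h_t, c_t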

The LSTM parameter format is much easier to understand once we have worked through the GRU case. The LSTM equations show that there are four kernel weights/biases (W_i, W_f, W_o, W_c, b_{Wi}, b_{Wf}, b_{Wo}, b_{Wc}) and four recurrent weights/biases (R_i, R_f, R_o, R_c, b_{Ri}, b_{Rf}, b_{Ro}, b_{Rc}). The shapes of the weights and biases are the same as in the GRU case.

Again, the two sets of biases in the equations correspond to the "double bias" mode used by CUDNN. In contrast, Keras adopts something closer to a "single bias" mode that performs the bias addition only once per equation. So, taking the transposed weights and the "single bias" mode into account, the three Keras LSTM arrays are:

  1. Kernel weights: a (inputSize, 4 x hiddenSize) matrix containing the concatenated four W weights.
  2. Recurrent weights: a (hiddenSize, 4 x hiddenSize) matrix containing the concatenated four R weights.
  3. Biases: a (4 x hiddenSize) vector containing the concatenated biases.

The following Python code prints the parameters stored in a Keras LSTM layer:

lstm = layers.LSTM(hidden_size, time_major=True,
                   return_sequences=True,
                   kernel_initializer='glorot_uniform',
                   recurrent_initializer='orthogonal',
                   bias_initializer='random_uniform')
y = lstm(x)
print("Keras Kernel Weights:", lstm.get_weights()[0])
print("Keras Recurrent Weights:", lstm.get_weights()[1])
print("Keras Biases:", lstm.get_weights()[2])

Here I visualize the outputs with four colors to represent different types of weights/biases:

  • Green for i weights/biases
  • Red for f weights/biases
  • Yellow for o weights/biases
  • Blue for c weights/biases

Keras Kernel Weights:

0.307402 -0.468454 -0.571665 -0.406933 0.390397 0.267421 -0.119232 0.018690 -0.560165 -0.202529 0.328128 -0.453909
-0.309438 0.163861 0.202521 -0.397582 0.334114 -0.077433 -0.450064 0.124535 0.564949 -0.374840 0.154384 -0.276332

Keras Recurrent Weights:

-0.338174 -0.019739 0.702717 0.173684 -0.237763 -0.398269 -0.122475 0.061238 0.148485 0.106563 0.249839 0.177616
-0.202324 -0.259554 0.264483 -0.176437 0.164398 0.278202 0.151397 0.039010 0.493140 -0.168453 -0.028650 -0.623991
-0.114759 -0.399628 -0.053830 0.166763 0.137982 -0.207373 0.150091 0.639458 -0.216613 0.321846 -0.380653 -0.086838

Keras Biases:

0.049217 0.048934 0.007049 1.000000 1.000000 1.000000 -0.020231 0.046288 -0.007113 -0.013948 -0.023413 -0.001040

(The three 1.000000 values are the forget-gate biases: Keras LSTM uses unit_forget_bias=True by default, which initializes the forget-gate bias with ones regardless of bias_initializer.)

Similarly, I tint the parameters taken in by CUDNN with the four colors.

CUDNN Parameters (Weights and Biases):

0.307402 -0.309438 -0.468454 0.163861 -0.571665 0.202521 -0.406933 -0.397582 0.390397
0.334114 0.267421 -0.077433 -0.119232 -0.450064 0.018690 0.124535 -0.560165 0.564949
-0.202529 -0.374840 0.328128 0.154384 -0.453909 -0.276332 -0.338174 -0.202324 -0.114759
-0.019739 -0.259554 -0.399628 0.702717 0.264483 -0.053830 0.173684 -0.176437 0.166763
-0.237763 0.164398 0.137982 -0.398269 0.278202 -0.207373 -0.122475 0.151397 0.150091
0.061238 0.039010 0.639458 0.148485 0.493140 -0.216613 0.106563 -0.168453 0.321846
0.249839 -0.028650 -0.380653 0.177616 -0.623991 -0.086838 0.000000 0.000000 0.000000
0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
0.049217 0.048934 0.007049 1.000000 1.000000 1.000000 -0.020231 0.046288 -0.007113
-0.013948 -0.023413 -0.001040

By comparing the parameters in Keras and CUDNN, we can still observe three major differences:

  1. CUDNN uses a single flat array, while Keras uses four separate arrays.
  2. The weights are transposed between CUDNN and Keras.
  3. The order of the weights/biases is the same between CUDNN and Keras, but the bias part is padded with 4 x hiddenSize zeros to bridge the gap between the "double bias" (CUDNN) and "single bias" (Keras) modes.

Now we use the same "secret" function to convert the Keras parameters into the CUDNN-compatible form:

params = recurrent_v2._canonical_to_params(
    weights=[
        # Kernel and recurrent weights keep the Keras gate order; no swap is needed.
        lstm.get_weights()[0][:, :hidden_size],
        lstm.get_weights()[0][:, hidden_size:hidden_size * 2],
        lstm.get_weights()[0][:, hidden_size * 2:hidden_size * 3],
        lstm.get_weights()[0][:, hidden_size * 3:],
        lstm.get_weights()[1][:, :hidden_size],
        lstm.get_weights()[1][:, hidden_size:hidden_size * 2],
        lstm.get_weights()[1][:, hidden_size * 2:hidden_size * 3],
        lstm.get_weights()[1][:, hidden_size * 3:],
    ],
    biases=[
        # Four zero vectors pad out the "double bias" layout expected by CUDNN.
        tf.zeros((hidden_size,)),
        tf.zeros((hidden_size,)),
        tf.zeros((hidden_size,)),
        tf.zeros((hidden_size,)),
        # The single Keras bias vector fills the remaining four bias slots.
        lstm.get_weights()[2][:hidden_size],
        lstm.get_weights()[2][hidden_size:hidden_size * 2],
        lstm.get_weights()[2][hidden_size * 2:hidden_size * 3],
        lstm.get_weights()[2][hidden_size * 3:hidden_size * 4],
    ],
    shape=tf.constant([-1]),
    transpose_weights=True)
print("CUDNN-equivalent Params:", params)

Note that for LSTM we don't need to reorder the four weights/biases, but we do need to pad one set of biases with zeros.
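
The equivalent hand-written NumPy sketch for LSTM (again, the function name keras_lstm_to_cudnn is made up for illustration) keeps the gate order, transposes and flattens each weight block, and prepends 4 x hidden_size zeros to the Keras bias vector; with the lstm layer above, its output should match the flat CUDNN array shown earlier.

import numpy as np

def keras_lstm_to_cudnn(kernel, recurrent, bias, hidden_size):
    """Flatten the three Keras LSTM arrays into the CUDNN parameter layout.

    No gate reordering is needed; each weight block is transposed and
    flattened row-major, and the single Keras bias vector is padded with
    4 * hidden_size zeros to emulate the 'double bias' layout.
    """
    w_blocks = [kernel[:, g * hidden_size:(g + 1) * hidden_size].T for g in range(4)]
    r_blocks = [recurrent[:, g * hidden_size:(g + 1) * hidden_size].T for g in range(4)]
    flat_weights = np.concatenate([b.reshape(-1) for b in w_blocks + r_blocks])
    flat_biases = np.concatenate([np.zeros_like(bias), bias])
    return np.concatenate([flat_weights, flat_biases])

# Should match the flat array printed by CUDNN above:
# print(keras_lstm_to_cudnn(*lstm.get_weights(), hidden_size))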

Reference

NVIDIA CUDNN Documentation: the GRU and LSTM equations referenced in this post.
