Developer Guide and Reference

  • 2022.1
  • 04/11/2022
  • Public Content

RNN

General

The RNN primitive computes a stack of unrolled recurrent cells, as depicted in Figure 1. bias, src_iter, and dst_iter are optional parameters (the variable names follow the standard Naming Conventions). If not provided, bias and src_iter will default to 0.
Figure 1: Example of stacked recurrent cells unrolled over the time dimension and executed with the `left2right` direction. Dashed lines represent optional parameters.
The RNN primitive supports four modes for the evaluation direction (a sketch mapping them to the C++ API enum values follows this list):
  • left2right will process the input data timestamps by increasing order,
  • right2left will process the input data timestamps by decreasing order,
  • bidirectional_concat will process all the stacked layers from left2right and from right2left independently, and will concatenate the two outputs in dst_layer over the channel dimension,
  • bidirectional_sum will process all the stacked layers from left2right and from right2left independently, and will sum the two outputs into dst_layer.
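For reference, a minimal sketch of how these four modes map to the dnnl::rnn_direction enum values in the C++ API; the helper function itself is illustrative, not part of the API.
#include "dnnl.hpp"

// Illustrative helper only: selects the dnnl::rnn_direction enum value
// corresponding to each of the four evaluation direction modes.
inline dnnl::rnn_direction pick_direction(
        bool bidirectional, bool concat_outputs, bool right_to_left) {
    if (!bidirectional)
        return right_to_left
                ? dnnl::rnn_direction::unidirectional_right2left
                : dnnl::rnn_direction::unidirectional_left2right;
    return concat_outputs ? dnnl::rnn_direction::bidirectional_concat
                          : dnnl::rnn_direction::bidirectional_sum;
}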
Even though the RNN primitive supports passing a different number of channels for src_layer, src_iter, dst_layer, and dst_iter, we always require the following conditions in order for the dimensions to be consistent:
  • channels(dst_layer) = channels(dst_iter),
  • when T > 1, channels(src_iter) = channels(dst_iter),
  • when L > 1, channels(src_layer) = channels(dst_layer),
  • when using the bidirectional_concat direction, channels(dst_layer) = 2 * channels(dst_iter).
The general formula for the execution of a stack of unrolled recurrent cells depends on the current iteration of the previous layer (h_{t,l-1} and c_{t,l-1}) and the previous iteration of the current layer (h_{t-1,l}). Here is the exact equation for non-LSTM cells:

h_{t,l} = Cell(h_{t,l-1}, h_{t-1,l})

where t and l are the indices of the timestamp and the layer of the cell being executed.
And here is the equation for LSTM cells:

(h_{t,l}, c_{t,l}) = Cell(h_{t,l-1}, h_{t-1,l}, c_{t-1,l})

where t and l are the indices of the timestamp and the layer of the cell being executed.
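To make the recurrence concrete, the following is a conceptual sketch in plain C++ (not oneDNN API code); the container layout and the Cell callback are illustrative assumptions.
#include <vector>

// Conceptual sketch of the unrolled stack executed left2right:
// h[t][l] is the hidden state at timestamp t of layer l. Column l = 0
// holds the src_layer inputs; row t = 0 holds src_iter (zeros when
// src_iter is not provided).
template <typename State, typename CellFn>
void unrolled_stack(std::vector<std::vector<State>> &h, CellFn Cell) {
    const std::size_t T = h.size() - 1;    // number of timestamps
    const std::size_t L = h[0].size() - 1; // number of layers
    for (std::size_t t = 1; t <= T; ++t)
        for (std::size_t l = 1; l <= L; ++l)
            // current iteration of the previous layer and previous
            // iteration of the current layer, as in the formula above
            h[t][l] = Cell(h[t][l - 1], h[t - 1][l]);
    // dst_layer is h[1..T][L]; dst_iter is h[T][1..L].
}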

Cell Functions

The RNN API provides four cell functions:
  • Vanilla RNN, a single-gate recurrent cell,
  • LSTM, a four-gate long short-term memory cell,
  • GRU, a three-gate gated recurrent unit cell,
  • Linear-before-reset GRU, a three-gate recurrent unit cell with the linear layer before the reset gate.
Vanilla RNN
A single-gate recurrent cell initialized with dnnl::vanilla_rnn_forward::desc::desc() or dnnl::vanilla_rnn_backward::desc::desc() as in the following example.
auto vanilla_rnn_desc = dnnl::vanilla_rnn_forward::desc(
        aprop, activation, direction, src_layer_desc, src_iter_desc,
        weights_layer_desc, weights_iter_desc, bias_desc,
        dst_layer_desc, dst_iter_desc);
The Vanilla RNN cell supports the ReLU, Tanh, and Sigmoid activation functions. The following equations define the mathematical operation performed by the Vanilla RNN cell for the forward pass:

a_t = W \cdot h_{t,l-1} + U \cdot h_{t-1,l} + B
h_t = activation(a_t)
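Putting the descriptor into context, below is a minimal sketch of creating a Vanilla RNN forward primitive on CPU; all sizes (T, N, C, L, D, G) are illustrative assumptions, not values mandated by the API.
#include "dnnl.hpp"
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);

    // Illustrative sizes: T timestamps, N batch, C channels,
    // L layers, D directions, G gates (1 for Vanilla RNN).
    const memory::dim T = 4, N = 2, C = 8, L = 1, D = 1, G = 1;
    using tag = memory::format_tag;
    const auto dt = memory::data_type::f32;

    memory::desc src_layer_desc({T, N, C}, dt, tag::tnc);
    memory::desc src_iter_desc({L, D, N, C}, dt, tag::ldnc);
    memory::desc weights_layer_desc({L, D, C, G, C}, dt, tag::ldigo);
    memory::desc weights_iter_desc({L, D, C, G, C}, dt, tag::ldigo);
    memory::desc bias_desc({L, D, G, C}, dt, tag::ldgo);
    memory::desc dst_layer_desc({T, N, C}, dt, tag::tnc);
    memory::desc dst_iter_desc({L, D, N, C}, dt, tag::ldnc);

    auto vanilla_rnn_desc = vanilla_rnn_forward::desc(
            prop_kind::forward_inference, algorithm::eltwise_tanh,
            rnn_direction::unidirectional_left2right,
            src_layer_desc, src_iter_desc, weights_layer_desc,
            weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc);
    auto vanilla_rnn_pd
            = vanilla_rnn_forward::primitive_desc(vanilla_rnn_desc, eng);
    auto vanilla_rnn = vanilla_rnn_forward(vanilla_rnn_pd);
    (void)vanilla_rnn; // execution follows as in the Execution Arguments section
    return 0;
}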
LSTM
LSTM (or Vanilla LSTM)
A four-gate long short-term memory recurrent cell initialized with dnnl::lstm_forward::desc::desc() or dnnl::lstm_backward::desc::desc() as in the following example.
auto lstm_desc = dnnl::lstm_forward::desc(
        aprop, direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
        weights_layer_desc, weights_iter_desc, bias_desc,
        dst_layer_desc, dst_iter_h_desc, dst_iter_c_desc);
Note that for all tensors with a dimension depending on the gate number, we implicitly require the order of these gates to be i, f, \tilde{c}, and o. The following equations give the mathematical description of these gates and output for the forward pass:

i_t = \sigma(W_i h_{t,l-1} + U_i h_{t-1,l} + B_i)
f_t = \sigma(W_f h_{t,l-1} + U_f h_{t-1,l} + B_f)
\tilde{c}_t = \tanh(W_{\tilde{c}} h_{t,l-1} + U_{\tilde{c}} h_{t-1,l} + B_{\tilde{c}})
c_t = f_t * c_{t-1} + i_t * \tilde{c}_t
o_t = \sigma(W_o h_{t,l-1} + U_o h_{t-1,l} + B_o)
h_t = \tanh(c_t) * o_t

where W_* are stored in weights_layer, U_* are stored in weights_iter, and B_* are stored in bias.
In order for the dimensions to be consistent, we require channels(src_iter) = channels(src_iter_c) = channels(dst_iter) = channels(dst_iter_c).
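As an illustration of these constraints, here is a sketch of consistent LSTM memory descriptors; the sizes are assumptions chosen only to show the gate dimension G = 4 in the implicit i, f, \tilde{c}, o gate order.
#include "dnnl.hpp"
using namespace dnnl;

void make_lstm_memory_descs() {
    using tag = memory::format_tag;
    const auto dt = memory::data_type::f32;
    // Illustrative sizes; the channels of src_iter, src_iter_c, dst_iter,
    // and dst_iter_c are all C, as required above.
    const memory::dim T = 4, N = 2, C = 16, L = 1, D = 1, G = 4;

    memory::desc src_layer_desc({T, N, C}, dt, tag::tnc);
    memory::desc src_iter_h_desc({L, D, N, C}, dt, tag::ldnc);
    memory::desc src_iter_c_desc({L, D, N, C}, dt, tag::ldnc); // cell state
    memory::desc weights_layer_desc({L, D, C, G, C}, dt, tag::ldigo);
    memory::desc weights_iter_desc({L, D, C, G, C}, dt, tag::ldigo);
    memory::desc bias_desc({L, D, G, C}, dt, tag::ldgo);
    memory::desc dst_layer_desc({T, N, C}, dt, tag::tnc);
    memory::desc dst_iter_h_desc({L, D, N, C}, dt, tag::ldnc);
    memory::desc dst_iter_c_desc({L, D, N, C}, dt, tag::ldnc);
}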
LSTM with Peephole
A four-gate long short-term memory recurrent cell with peephole initialized with dnnl::lstm_forward::desc::desc() or dnnl::lstm_backward::desc::desc() as in the following example.
auto lstm_desc = dnnl::lstm_forward::desc(
        aprop, direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
        weights_layer_desc, weights_iter_desc, weights_peephole_desc,
        bias_desc, dst_layer_desc, dst_iter_h_desc, dst_iter_c_desc);
Similarly to vanilla LSTM, we implicitly require the order of the gates to be i, f, \tilde{c}, and o for all tensors with a dimension depending on the gates. For the peephole weights, the gate order is i, f, o. The following equations give the mathematical description of these gates and output for the forward pass:

i_t = \sigma(W_i h_{t,l-1} + U_i h_{t-1,l} + P_i c_{t-1} + B_i)
f_t = \sigma(W_f h_{t,l-1} + U_f h_{t-1,l} + P_f c_{t-1} + B_f)
\tilde{c}_t = \tanh(W_{\tilde{c}} h_{t,l-1} + U_{\tilde{c}} h_{t-1,l} + B_{\tilde{c}})
c_t = f_t * c_{t-1} + i_t * \tilde{c}_t
o_t = \sigma(W_o h_{t,l-1} + U_o h_{t-1,l} + P_o c_t + B_o)
h_t = \tanh(c_t) * o_t

where P_* are stored in weights_peephole, and the other parameters are the same as in vanilla LSTM.
If the weights_peephole_desc passed to the operation descriptor constructor is a zero memory descriptor, the primitive will behave the same as the LSTM primitive without peephole.
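A minimal sketch of this: a default-constructed dnnl::memory::desc serves as the zero memory descriptor (the wrapper function below is illustrative, not part of the API).
#include "dnnl.hpp"
using namespace dnnl;

// Illustrative helper: builds an LSTM descriptor with the peephole
// weights disabled by passing a zero (default-constructed) memory::desc.
lstm_forward::desc make_lstm_desc_without_peephole(prop_kind aprop,
        rnn_direction direction, const memory::desc &src_layer_desc,
        const memory::desc &src_iter_h_desc, const memory::desc &src_iter_c_desc,
        const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc,
        const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
        const memory::desc &dst_iter_h_desc, const memory::desc &dst_iter_c_desc) {
    memory::desc zero_md; // zero memory descriptor: peephole disabled
    return lstm_forward::desc(aprop, direction, src_layer_desc,
            src_iter_h_desc, src_iter_c_desc, weights_layer_desc,
            weights_iter_desc, zero_md /* weights_peephole_desc */,
            bias_desc, dst_layer_desc, dst_iter_h_desc, dst_iter_c_desc);
}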
LSTM with Projection (or LSTMP)
A four-gate long short-term memory recurrent cell with projection initialized with dnnl::lstm_forward::desc::desc() or dnnl::lstm_backward::desc::desc() as in the following example.
auto lstm_desc = dnnl::lstm_forward::desc(
        aprop, direction, src_layer_desc, src_iter_h_desc, src_iter_c_desc,
        weights_layer_desc, weights_iter_desc, weights_peephole_desc,
        weights_projection_desc, bias_desc, dst_layer_desc,
        dst_iter_h_desc, dst_iter_c_desc);
Similarly to vanilla LSTM, we implicitly require the order of the gates to be i, f, \tilde{c}, and o for all tensors with a dimension depending on the gates. The following equations give the mathematical description of these gates and output for the forward pass (for simplicity, LSTM without peephole is shown):

i_t = \sigma(W_i h_{t,l-1} + U_i h_{t-1,l} + B_i)
f_t = \sigma(W_f h_{t,l-1} + U_f h_{t-1,l} + B_f)
\tilde{c}_t = \tanh(W_{\tilde{c}} h_{t,l-1} + U_{\tilde{c}} h_{t-1,l} + B_{\tilde{c}})
c_t = f_t * c_{t-1} + i_t * \tilde{c}_t
o_t = \sigma(W_o h_{t,l-1} + U_o h_{t-1,l} + B_o)
h_t = R \cdot (\tanh(c_t) * o_t)

where R is stored in weights_projection, and the other parameters are the same as in vanilla LSTM.
If the weights_projection_desc passed to the operation descriptor constructor is a zero memory descriptor, the primitive will behave the same as the LSTM primitive without projection.
GRU
A three-gate gated recurrent unit cell initialized with dnnl::gru_forward::desc::desc() or dnnl::gru_backward::desc::desc() as in the following example.
auto gru_desc = dnnl::gru_forward::desc(
        aprop, direction, src_layer_desc, src_iter_desc,
        weights_layer_desc, weights_iter_desc, bias_desc,
        dst_layer_desc, dst_iter_desc);
Note that for all tensors with a dimension depending on the gate number, we implicitly require the order of these gates to be u, r, and o. The following equations give the mathematical definition of these gates:

u_t = \sigma(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)
r_t = \sigma(W_r h_{t,l-1} + U_r h_{t-1,l} + B_r)
o_t = \tanh(W_o h_{t,l-1} + U_o (r_t * h_{t-1,l}) + B_o)
h_t = u_t * h_{t-1,l} + (1 - u_t) * o_t

where W_u, W_r, W_o are in weights_layer, U_u, U_r, U_o are in weights_iter, and B_u, B_r, B_o are stored in bias.
If you need to replace u_t by (1 - u_t) when computing h_t, you can achieve this by multiplying W_u, U_u, and B_u by -1. This is possible as \sigma(x) = 1 - \sigma(-x), and thus 1 - u_t = \sigma(-(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)).
Linear-Before-Reset GRU
A three-gate gated recurrent unit cell with linear layer applied before the reset gate, initialized with dnnl::lbr_gru_forward::desc::desc() or dnnl::lbr_gru_backward::desc::desc() as in the following example.
auto lbr_gru_desc = dnnl::lbr_gru_forward::desc(
        aprop, direction, src_layer_desc, src_iter_desc,
        weights_layer_desc, weights_iter_desc, bias_desc,
        dst_layer_desc, dst_iter_desc);
The following equations describe the mathematical behavior of the Linear-Before-Reset GRU cell:

u_t = \sigma(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)
r_t = \sigma(W_r h_{t,l-1} + U_r h_{t-1,l} + B_r)
o_t = \tanh(W_o h_{t,l-1} + r_t * (U_o h_{t-1,l} + B_{u'}) + B_o)
h_t = u_t * h_{t-1,l} + (1 - u_t) * o_t

Note that for all tensors with a dimension depending on the gate number, except the bias, we implicitly require the order of these gates to be u, r, and o. For the bias tensor, we implicitly require the order of the gates to be u, r, o, and u'.
If you need to replace u_t by (1 - u_t) when computing h_t, you can achieve this by multiplying W_u, U_u, and B_u by -1. This is possible as \sigma(x) = 1 - \sigma(-x), and thus 1 - u_t = \sigma(-(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)).
AUGRU
A three-gate gated recurrent unit cell with attention update gate, initialized with dnnl::augru_forward::desc::desc() or dnnl::augru_backward::desc::desc() as in the following example.
auto augru_desc = dnnl::augru_forward::desc(
        aprop, direction, src_layer_desc, src_iter_desc, attention_desc,
        weights_layer_desc, weights_iter_desc, bias_desc,
        dst_layer_desc, dst_iter_desc);
Note that for all tensors with a dimension depending on the gate number, we implicitly require the order of these gates to be u, r, and o. The following equations give the mathematical definition of these gates:

u_t = \sigma(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)
r_t = \sigma(W_r h_{t,l-1} + U_r h_{t-1,l} + B_r)
o_t = \tanh(W_o h_{t,l-1} + U_o (r_t * h_{t-1,l}) + B_o)
\tilde{u}_t = (1 - a_t) * u_t
h_t = \tilde{u}_t * h_{t-1,l} + (1 - \tilde{u}_t) * o_t

where a_t is the attention provided through src_layer_attention, W_u, W_r, W_o are in weights_layer, U_u, U_r, U_o are in weights_iter, and B_u, B_r, B_o are stored in bias.
If you need to replace u_t by (1 - u_t) when computing h_t, you can achieve this by multiplying W_u, U_u, and B_u by -1. This is possible as \sigma(x) = 1 - \sigma(-x), and thus 1 - u_t = \sigma(-(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)).
Linear-Before-Reset AUGRU
A three-gate gated recurrent unit cell with attention update gate and linear layer applied before the reset gate, initialized with dnnl::lbr_augru_forward::desc::desc() or dnnl::lbr_augru_backward::desc::desc() as in the following example.
auto lbr_augru_desc = dnnl::lbr_augru_forward::desc(
        aprop, direction, src_layer_desc, src_iter_desc, attention_desc,
        weights_layer_desc, weights_iter_desc, bias_desc,
        dst_layer_desc, dst_iter_desc);
The following equations describe the mathematical behavior of the Linear-Before-Reset AUGRU cell:

u_t = \sigma(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)
r_t = \sigma(W_r h_{t,l-1} + U_r h_{t-1,l} + B_r)
o_t = \tanh(W_o h_{t,l-1} + r_t * (U_o h_{t-1,l} + B_{u'}) + B_o)
\tilde{u}_t = (1 - a_t) * u_t
h_t = \tilde{u}_t * h_{t-1,l} + (1 - \tilde{u}_t) * o_t

Note that for all tensors with a dimension depending on the gate number, except the bias, we implicitly require the order of these gates to be u, r, and o. For the bias tensor, we implicitly require the order of the gates to be u, r, o, and u'.
If you need to replace u_t by (1 - u_t) when computing h_t, you can achieve this by multiplying W_u, U_u, and B_u by -1. This is possible as \sigma(x) = 1 - \sigma(-x), and thus 1 - u_t = \sigma(-(W_u h_{t,l-1} + U_u h_{t-1,l} + B_u)).

Considerations for Training

When using the RNN API for training, the forward pass should use the forward_training propagation kind, and a workspace should be passed to both the forward pass and the backward pass. Note that after executing the backward pass, the workspace is no longer valid and should be populated once again by another forward pass.
The RNN primitive backward pass accumulates gradients into its weight outputs (namely diff_weights_layer, diff_weights_iter, diff_weights_peephole, diff_weights_projection, and diff_bias). Hence, these tensors should be properly initialized to zero before their first use, and can be reused across calls to accumulate gradients if need be, as the sketch below illustrates.
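The following sketch shows one training step under these rules; the primitives and memory objects are assumed to have been created beforehand, and a CPU engine is assumed so the diff weights can be zeroed with a direct memset.
#include <cstring>
#include <unordered_map>
#include "dnnl.hpp"
using namespace dnnl;

// Illustrative training step: the workspace produced by the forward pass
// is consumed by the backward pass and must be regenerated by another
// forward pass before the next backward pass.
void lstm_training_step(stream &strm, engine &eng, lstm_forward &fwd,
        const lstm_forward::primitive_desc &fwd_pd, lstm_backward &bwd,
        std::unordered_map<int, memory> fwd_args,
        std::unordered_map<int, memory> bwd_args,
        memory &diff_weights_layer) {
    // Gradients are accumulated, so zero the diff weights before their
    // first use (CPU engine assumed for the direct memset).
    std::memset(diff_weights_layer.get_data_handle(), 0,
            diff_weights_layer.get_desc().get_size());

    memory workspace(fwd_pd.workspace_desc(), eng);
    fwd_args[DNNL_ARG_WORKSPACE] = workspace;
    bwd_args[DNNL_ARG_WORKSPACE] = workspace;

    fwd.execute(strm, fwd_args); // populates the workspace
    bwd.execute(strm, bwd_args); // invalidates the workspace
    strm.wait();
}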

Execution Arguments

When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.

Primitive input/output      Execution argument index
src_layer                   DNNL_ARG_SRC_LAYER
src_layer_attention         DNNL_ARG_SRC_LAYER_ATTENTION
src_iter                    DNNL_ARG_SRC_ITER
src_iter_c                  DNNL_ARG_SRC_ITER_C
weights_layer               DNNL_ARG_WEIGHTS_LAYER
weights_iter                DNNL_ARG_WEIGHTS_ITER
weights_peephole            DNNL_ARG_WEIGHTS_PEEPHOLE
weights_projection          DNNL_ARG_WEIGHTS_PROJECTION
bias                        DNNL_ARG_BIAS
dst_layer                   DNNL_ARG_DST_LAYER
dst_iter                    DNNL_ARG_DST_ITER
dst_iter_c                  DNNL_ARG_DST_ITER_C
workspace                   DNNL_ARG_WORKSPACE
diff_src_layer              DNNL_ARG_DIFF_SRC_LAYER
diff_src_layer_attention    DNNL_ARG_DIFF_SRC_LAYER_ATTENTION
diff_src_iter               DNNL_ARG_DIFF_SRC_ITER
diff_src_iter_c             DNNL_ARG_DIFF_SRC_ITER_C
diff_weights_layer          DNNL_ARG_DIFF_WEIGHTS_LAYER
diff_weights_iter           DNNL_ARG_DIFF_WEIGHTS_ITER
diff_weights_peephole       DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE
diff_weights_projection     DNNL_ARG_DIFF_WEIGHTS_PROJECTION
diff_bias                   DNNL_ARG_DIFF_BIAS
diff_dst_layer              DNNL_ARG_DIFF_DST_LAYER
diff_dst_iter               DNNL_ARG_DIFF_DST_ITER
diff_dst_iter_c             DNNL_ARG_DIFF_DST_ITER_C
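As a sketch, the table above translates into an execution call like the following for an LSTM forward pass; all memory objects are assumed to have been created elsewhere to match the primitive descriptor.
#include <unordered_map>
#include "dnnl.hpp"
using namespace dnnl;

void execute_lstm_forward(stream &strm, lstm_forward &lstm_fwd,
        const memory &src_layer, const memory &src_iter_h,
        const memory &src_iter_c, const memory &weights_layer,
        const memory &weights_iter, const memory &bias,
        const memory &dst_layer, const memory &dst_iter_h,
        const memory &dst_iter_c, const memory &workspace) {
    // Each memory object is bound to its execution argument index
    // from the table above.
    std::unordered_map<int, memory> args = {
            {DNNL_ARG_SRC_LAYER, src_layer},
            {DNNL_ARG_SRC_ITER, src_iter_h},
            {DNNL_ARG_SRC_ITER_C, src_iter_c},
            {DNNL_ARG_WEIGHTS_LAYER, weights_layer},
            {DNNL_ARG_WEIGHTS_ITER, weights_iter},
            {DNNL_ARG_BIAS, bias},
            {DNNL_ARG_DST_LAYER, dst_layer},
            {DNNL_ARG_DST_ITER, dst_iter_h},
            {DNNL_ARG_DST_ITER_C, dst_iter_c},
            {DNNL_ARG_WORKSPACE, workspace}};
    lstm_fwd.execute(strm, args);
}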

Implementation Details

Data Type Support
The following table lists the combinations of data types supported by the RNN primitive for each input and output memory object.

Propagation              Cell Function               Input Data   Recurrent Data (1)   Weights   Bias   Output Data
Forward / Backward       All                         f32          f32                  f32       f32    f32
Forward / Backward (2)   All (3)                     bf16         bf16                 bf16      f32    bf16
Forward                  All (3)                     f16          f16                  f16       f16    f16
Forward inference        Vanilla LSTM, LSTMP, GRU    u8           u8                   s8        f32    u8, f32
Forward inference        Vanilla LSTM, LSTMP         s8           s8                   s8        f32    s8, f32
  1. With LSTM and Peephole LSTM cells, the cell state data type is f32, except for the f16 configuration.
  2. In backward propagation, all diff_* tensors are in f32.
  3. Projection LSTM is not supported.
There might be hardware- or implementation-specific restrictions; check the Implementation Limitations section below.
Data Representation
In the oneDNN programming model, the RNN primitive is one of the few that support the placeholder memory format dnnl::memory::format_tag::any (shortened to any from now on) and can define the data and weight memory object formats based on the primitive parameters.
The following table summarizes the data layouts supported by the RNN primitive.

Propagation          Input/Output Data   Recurrent Data   Layer and Iteration Weights   Peephole Weights and Bias   Projection LSTM Weights
Forward / Backward   any                 any              any                           ldgo                        any
Forward              ntc, tnc            ldnc             ldigo                         ldgo                        ldio
Backward             ntc, tnc            ldnc             ldgoi                         ldgo                        ldoi
While an RNN primitive can be created with memory formats specified explicitly, the performance is likely to be suboptimal. When using any, it is necessary to first create an RNN primitive descriptor and then query it for the actual data and weight memory object formats, as in the sketch below.
The RNN primitive supports padded tensors and views, so even if two memory descriptors share the same data layout, they might still be different.
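A minimal sketch of the query-and-reorder flow for the layer weights (the function name is illustrative):
#include "dnnl.hpp"
using namespace dnnl;

// Illustrative helper: when the RNN primitive was created with
// format_tag::any, query the chosen weights format and reorder the
// user's weights into it if the layouts differ.
memory prepare_weights_layer(engine &eng, stream &strm,
        const lstm_forward::primitive_desc &lstm_pd,
        memory user_weights_layer) {
    if (lstm_pd.weights_layer_desc() == user_weights_layer.get_desc())
        return user_weights_layer; // already in the chosen layout
    memory opt_weights(lstm_pd.weights_layer_desc(), eng);
    reorder(user_weights_layer, opt_weights)
            .execute(strm, user_weights_layer, opt_weights);
    strm.wait();
    return opt_weights;
}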
Post-Ops and Attributes
Currently, post-ops and attributes are only used by the int8 variants of LSTM and GRU. See the RNN int8 inference example for more details on how to use and set these quantization parameters, and the sketch below for the attribute calls involved.
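A minimal sketch of setting these attributes for int8 LSTM inference; the scale and shift values below are illustrative placeholders that would normally come from calibration.
#include <vector>
#include "dnnl.hpp"
using namespace dnnl;

primitive_attr make_int8_rnn_attr() {
    primitive_attr attr;
    // Common scale and shift applied to the int8 data tensors
    // (illustrative values).
    attr.set_rnn_data_qparams(/* scale = */ 63.f, /* shift = */ 64.f);
    // Single (mask = 0) scale for the s8 weights (illustrative value).
    attr.set_rnn_weights_qparams(/* mask = */ 0, std::vector<float> {127.f});
    return attr;
}
The attribute is then passed when creating the primitive descriptor, for example lstm_forward::primitive_desc(lstm_desc, attr, eng).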

Implementation Limitations

  1. Refer to Data Types for limitations related to data type support.
  2. CPU
    • Bias must always be present (that is, the corresponding memory descriptor argument cannot be a zero memory descriptor when the RNN operation descriptor is initialized).
    • oneDNN supports s8 as input data only on systems with Advanced Matrix Extensions (AMX) support.
  3. GPU
    • No support for AUGRU.
    • No support for Peephole LSTM and Projection LSTM.
    • Bias must always be present (that is, the corresponding memory descriptor argument cannot be a zero memory descriptor when the RNN operation descriptor is initialized).

Example

This C++ API example demonstrates how to create and execute an LSTM RNN primitive in forward training propagation mode.
Key optimizations included in this example:
  • Creation of memory objects with optimized formats queried from the primitive descriptor.
