Developer Guide and Reference

  • 2022.1
  • 04/11/2022

Layer Normalization

General

The layer normalization primitive performs a forward or backward layer normalization operation on a 2-5D data tensor.
Forward
The layer normalization operation performs normalization over the last logical axis of the data tensor and is defined by the following formulas. We show formulas only for 3D data, which are straightforward to generalize to cases of higher dimensions. Variable names follow the standard Naming Conventions.
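$$\mathrm{dst}(t, n, c) = \gamma(c) \cdot \frac{\mathrm{src}(t, n, c) - \mu(t, n)}{\sqrt{\sigma^2(t, n) + \varepsilon}} + \beta(c),$$
where
  • $\gamma(c)$ and $\beta(c)$ are the optional scale and shift for a channel (see the dnnl_use_scaleshift, dnnl_use_scale, and dnnl_use_shift flags),
  • $\mu(t, n)$ and $\sigma^2(t, n)$ are the mean and variance (see the dnnl_use_global_stats flag), and
  • $\varepsilon$ is a constant that improves numerical stability.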
Mean and variance are computed at runtime or provided by a user. When mean and variance are computed at runtime, the following formulas are used:
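  • $\mu(t, n) = \frac{1}{C} \sum_{c} \mathrm{src}(t, n, c)$,
  • $\sigma^2(t, n) = \frac{1}{C} \sum_{c} \left(\mathrm{src}(t, n, c) - \mu(t, n)\right)^2$.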
The $\gamma(c)$ and $\beta(c)$ tensors are considered learnable.
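The forward formulas can be checked against a plain reference loop. What follows is a minimal, unoptimized C++ sketch, not the library's implementation. The function name layer_norm_forward and all its parameters are hypothetical. The sketch makes four assumptions:
  • It works on a dense f32 tensor whose last logical axis is innermost (for example, tnc), viewed as rows = T · N rows of C contiguous values.
  • The boolean use_global_stats stands in for the dnnl_use_global_stats flag: when it is true, mean and variance are read as inputs.
  • An empty gamma or beta vector stands for "no scale" or "no shift".
  • eps plays the role of $\varepsilon$.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference (unoptimized) forward layer normalization over the
// last axis. Data is viewed as `rows` rows of C contiguous values, e.g.
// rows = T * N for a dense tnc tensor. An empty gamma/beta stands for
// gamma(c) = 1 / beta(c) = 0. If use_global_stats is true, mean/variance are
// inputs; otherwise they are computed per row and written as outputs.
void layer_norm_forward(const std::vector<float> &src, std::vector<float> &dst,
        std::vector<float> &mean, std::vector<float> &variance,
        const std::vector<float> &gamma, const std::vector<float> &beta,
        std::size_t rows, std::size_t C, float eps, bool use_global_stats) {
    assert(src.size() == rows * C);
    if (use_global_stats) {
        assert(mean.size() == rows && variance.size() == rows);
    } else {
        mean.assign(rows, 0.f);
        variance.assign(rows, 0.f);
    }
    dst.resize(rows * C);
    for (std::size_t r = 0; r < rows; ++r) {
        const float *x = src.data() + r * C;
        if (!use_global_stats) {
            // mu(t, n) = 1/C * sum_c src(t, n, c)
            float mu = 0.f;
            for (std::size_t c = 0; c < C; ++c) mu += x[c];
            mu /= static_cast<float>(C);
            // sigma^2(t, n) = 1/C * sum_c (src(t, n, c) - mu(t, n))^2
            float var = 0.f;
            for (std::size_t c = 0; c < C; ++c) var += (x[c] - mu) * (x[c] - mu);
            mean[r] = mu;
            variance[r] = var / static_cast<float>(C);
        }
        const float inv_std = 1.f / std::sqrt(variance[r] + eps);
        for (std::size_t c = 0; c < C; ++c) {
            const float g = gamma.empty() ? 1.f : gamma[c];
            const float b = beta.empty() ? 0.f : beta[c];
            // Each element is read before it is overwritten, so dst may alias src.
            dst[r * C + c] = g * (x[c] - mean[r]) * inv_std + b;
        }
    }
}
```

Each row's statistics are finished before that row is written. As a result, the same vector can be passed as both src and dst, which mirrors the in-place mode described under General Notes.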
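Difference Between Forward Training and Forward Inference
If mean and variance are computed at runtime (i.e., dnnl_use_global_stats is not set), they are outputs of the primitive for the dnnl_forward_training propagation kind, because backward propagation needs them. For dnnl_forward_inference, they are not exposed.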
Backward
The backward propagation computes $\mathrm{diff\_src}(t, n, c)$, $\mathrm{diff\_}\gamma(c)^*$, and $\mathrm{diff\_}\beta(c)^*$ based on $\mathrm{diff\_dst}(t, n, c)$, $\mathrm{src}(t, n, c)$, $\mu(t, n)$, $\sigma^2(t, n)$, $\gamma(c)^*$, and $\beta(c)^*$.
The tensors marked with an asterisk are used only when the primitive is configured to use $\gamma(c)$ and $\beta(c)$ (i.e., dnnl_use_scaleshift, dnnl_use_scale, or dnnl_use_shift is set).
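This page does not spell out the backward formulas. A standard chain-rule derivation from the forward formulas gives the expressions below. It assumes that $\mu$ and $\sigma^2$ were computed from src, i.e., dnnl_use_global_stats was not set. These expressions are not taken from the library.

Define $\hat{x} = (\mathrm{src} - \mu)/\sqrt{\sigma^2 + \varepsilon}$ and $g(c) = \gamma(c) \cdot \mathrm{diff\_dst}(t, n, c)$. Then:
  • $\mathrm{diff\_}\beta(c) = \sum_{t,n} \mathrm{diff\_dst}(t, n, c)$,
  • $\mathrm{diff\_}\gamma(c) = \sum_{t,n} \mathrm{diff\_dst}(t, n, c)\,\hat{x}(t, n, c)$,
  • $\mathrm{diff\_src}(t, n, c) = \left(g - \frac{1}{C}\sum_{c'} g - \hat{x}\,\frac{1}{C}\sum_{c'} g\,\hat{x}\right) / \sqrt{\sigma^2 + \varepsilon}$, where the sums over $c'$ are taken per $(t, n)$.

The hypothetical function below is a reference sketch of those formulas, using the same row view as a dense tnc tensor.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical reference (unoptimized) backward layer normalization over rows
// of C values. Assumes mean/variance were computed from src in the forward
// pass (dnnl_use_global_stats not set). An empty gamma stands for gamma(c) = 1.
void layer_norm_backward(const std::vector<float> &diff_dst,
        const std::vector<float> &src, const std::vector<float> &mean,
        const std::vector<float> &variance, const std::vector<float> &gamma,
        std::vector<float> &diff_src, std::vector<float> &diff_gamma,
        std::vector<float> &diff_beta, std::size_t rows, std::size_t C,
        float eps) {
    assert(src.size() == rows * C && diff_dst.size() == rows * C);
    assert(mean.size() == rows && variance.size() == rows);
    diff_src.resize(rows * C); // resize (not assign) so diff_src may alias diff_dst
    diff_gamma.assign(C, 0.f);
    diff_beta.assign(C, 0.f);
    for (std::size_t r = 0; r < rows; ++r) {
        const float inv_std = 1.f / std::sqrt(variance[r] + eps);
        // Per-row means of g = gamma * diff_dst and of g * x_hat.
        float mean_g = 0.f, mean_g_xhat = 0.f;
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t i = r * C + c;
            const float x_hat = (src[i] - mean[r]) * inv_std;
            const float g = (gamma.empty() ? 1.f : gamma[c]) * diff_dst[i];
            mean_g += g;
            mean_g_xhat += g * x_hat;
            // Scale/shift gradients accumulate over all non-feature positions.
            diff_gamma[c] += diff_dst[i] * x_hat;
            diff_beta[c] += diff_dst[i];
        }
        mean_g /= static_cast<float>(C);
        mean_g_xhat /= static_cast<float>(C);
        for (std::size_t c = 0; c < C; ++c) {
            const std::size_t i = r * C + c;
            const float x_hat = (src[i] - mean[r]) * inv_std;
            const float g = (gamma.empty() ? 1.f : gamma[c]) * diff_dst[i];
            // Gradient through x_hat, mu, and sigma^2 combined.
            diff_src[i] = inv_std * (g - mean_g - x_hat * mean_g_xhat);
        }
    }
}
```

A quick sanity check follows from the formula: because $\hat{x}$ sums to zero over each row, each row of diff_src also sums to zero.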

Execution Arguments

Depending on the flags and propagation kind, the layer normalization primitive requires different inputs and outputs. For clarity, a summary is shown below.
dnnl_normalization_flags_none
  • dnnl_forward_inference: Inputs: src. Outputs: dst.
  • dnnl_forward_training: Inputs: src. Outputs: dst, $\mu$, $\sigma^2$.
  • dnnl_backward: Inputs: diff_dst, src, $\mu$, $\sigma^2$. Outputs: diff_src.
  • dnnl_backward_data: Same as for dnnl_backward.
dnnl_use_global_stats
  • dnnl_forward_inference: Inputs: src, $\mu$, $\sigma^2$. Outputs: dst.
  • dnnl_forward_training: Inputs: src, $\mu$, $\sigma^2$. Outputs: dst.
  • dnnl_backward: Inputs: diff_dst, src, $\mu$, $\sigma^2$. Outputs: diff_src.
  • dnnl_backward_data: Same as for dnnl_backward.
dnnl_use_scaleshift
  • dnnl_forward_inference: Inputs: src, $\gamma$, $\beta$. Outputs: dst.
  • dnnl_forward_training: Inputs: src, $\gamma$, $\beta$. Outputs: dst, $\mu$, $\sigma^2$.
  • dnnl_backward: Inputs: diff_dst, src, $\mu$, $\sigma^2$, $\gamma$, $\beta$. Outputs: diff_src, $\mathrm{diff\_}\gamma$, $\mathrm{diff\_}\beta$.
  • dnnl_backward_data: Not supported.
dnnl_use_scale
  • dnnl_forward_inference: Inputs: src, $\gamma$. Outputs: dst.
  • dnnl_forward_training: Inputs: src, $\gamma$. Outputs: dst, $\mu$, $\sigma^2$.
  • dnnl_backward: Inputs: diff_dst, src, $\mu$, $\sigma^2$, $\gamma$. Outputs: diff_src, $\mathrm{diff\_}\gamma$.
  • dnnl_backward_data: Not supported.
dnnl_use_shift
  • dnnl_forward_inference: Inputs: src, $\beta$. Outputs: dst.
  • dnnl_forward_training: Inputs: src, $\beta$. Outputs: dst, $\mu$, $\sigma^2$.
  • dnnl_backward: Inputs: diff_dst, src, $\mu$, $\sigma^2$, $\beta$. Outputs: diff_src, $\mathrm{diff\_}\beta$.
  • dnnl_backward_data: Not supported.
dnnl_use_global_stats combined with scale and shift
  • dnnl_forward_inference: Inputs: src, $\mu$, $\sigma^2$, $\gamma$, $\beta$. Outputs: dst.
  • dnnl_forward_training: Inputs: src, $\mu$, $\sigma^2$, $\gamma$, $\beta$. Outputs: dst.
  • dnnl_backward: Inputs: diff_dst, src, $\mu$, $\sigma^2$, $\gamma$, $\beta$. Outputs: diff_src, $\mathrm{diff\_}\gamma$, $\mathrm{diff\_}\beta$.
  • dnnl_backward_data: Not supported.
When executed, the inputs and outputs should be mapped to an execution argument index as specified by the following table.
Primitive input/output → Execution argument index
  • src → DNNL_ARG_SRC
  • $\gamma$, $\beta$ → DNNL_ARG_SCALE_SHIFT
  • $\gamma$ → DNNL_ARG_SCALE
  • $\beta$ → DNNL_ARG_SHIFT
  • mean ($\mu$) → DNNL_ARG_MEAN
  • variance ($\sigma^2$) → DNNL_ARG_VARIANCE
  • dst → DNNL_ARG_DST
  • diff_dst → DNNL_ARG_DIFF_DST
  • diff_src → DNNL_ARG_DIFF_SRC
  • $\mathrm{diff\_}\gamma$, $\mathrm{diff\_}\beta$ → DNNL_ARG_DIFF_SCALE_SHIFT
  • $\mathrm{diff\_}\gamma$ → DNNL_ARG_DIFF_SCALE
  • $\mathrm{diff\_}\beta$ → DNNL_ARG_DIFF_SHIFT
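The input/output summary and the argument index table can be combined into a small lookup. Here it is sketched in self-contained C++, with several deliberate substitutions:
  • Prop, Flag, ArgList, and required_args are hypothetical names invented for illustration.
  • The flag bit values are NOT the real dnnl_* constants.
  • Argument names are returned as strings rather than the DNNL_ARG_* macros, so the sketch compiles without the library.
  • With dnnl_use_scaleshift, $\gamma$ and $\beta$ appear as the single combined DNNL_ARG_SCALE_SHIFT argument.

The sketch encodes these rules from this page:
  • dnnl_use_scaleshift cannot be combined with dnnl_use_scale or dnnl_use_shift.
  • dnnl_backward_data does not support the scale/shift flags.
  • Statistics are inputs with dnnl_use_global_stats. Otherwise they are outputs only for dnnl_forward_training, and they are always inputs for backward propagation.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for the propagation kinds and flags. The bit values
// are illustrative only and are NOT the real dnnl_* constants.
enum class Prop { forward_inference, forward_training, backward, backward_data };
enum Flag : unsigned {
    flags_none = 0u,
    use_global_stats = 1u << 0,
    use_scaleshift = 1u << 1,
    use_scale = 1u << 2,
    use_shift = 1u << 3,
};

struct ArgList {
    std::vector<std::string> inputs, outputs;
};

// Fills `args` with the execution-argument names listed in the summary for
// `prop` and `flags`; returns false for invalid or unsupported combinations.
bool required_args(Prop prop, unsigned flags, ArgList &args) {
    const bool global_stats = (flags & use_global_stats) != 0;
    const bool scaleshift = (flags & use_scaleshift) != 0;
    const bool scale = (flags & use_scale) != 0;
    const bool shift = (flags & use_shift) != 0;
    // dnnl_use_scaleshift cannot be mixed with dnnl_use_scale or dnnl_use_shift.
    assert(!(scaleshift && (scale || shift)) || true); // checked below, not fatal
    if (scaleshift && (scale || shift)) return false;
    // dnnl_backward_data supports no scale/shift flags.
    if (prop == Prop::backward_data && (scaleshift || scale || shift)) return false;

    args = ArgList{};
    if (prop == Prop::backward || prop == Prop::backward_data) {
        // Backward propagation always takes mean and variance as inputs.
        args.inputs = {"DNNL_ARG_DIFF_DST", "DNNL_ARG_SRC", "DNNL_ARG_MEAN",
                "DNNL_ARG_VARIANCE"};
        args.outputs = {"DNNL_ARG_DIFF_SRC"};
    } else {
        args.inputs = {"DNNL_ARG_SRC"};
        args.outputs = {"DNNL_ARG_DST"};
        if (global_stats) {
            // User-provided statistics are inputs.
            args.inputs.push_back("DNNL_ARG_MEAN");
            args.inputs.push_back("DNNL_ARG_VARIANCE");
        } else if (prop == Prop::forward_training) {
            // Runtime statistics are exposed only for forward training.
            args.outputs.push_back("DNNL_ARG_MEAN");
            args.outputs.push_back("DNNL_ARG_VARIANCE");
        }
    }
    // Scale/shift are inputs; dnnl_backward also produces their gradients.
    const bool with_diff = prop == Prop::backward;
    if (scaleshift) {
        args.inputs.push_back("DNNL_ARG_SCALE_SHIFT");
        if (with_diff) args.outputs.push_back("DNNL_ARG_DIFF_SCALE_SHIFT");
    }
    if (scale) {
        args.inputs.push_back("DNNL_ARG_SCALE");
        if (with_diff) args.outputs.push_back("DNNL_ARG_DIFF_SCALE");
    }
    if (shift) {
        args.inputs.push_back("DNNL_ARG_SHIFT");
        if (with_diff) args.outputs.push_back("DNNL_ARG_DIFF_SHIFT");
    }
    return true;
}
```

For example, calling required_args(Prop::backward, use_scale | use_shift, args) yields six inputs and three outputs: diff_src plus the two scale/shift gradients.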

Implementation Details

General Notes
  1. The different flavors of the primitive are partially controlled by the flags parameter that is passed to the operation descriptor initialization function (e.g., dnnl::layer_normalization_forward::desc::desc()). Multiple flags can be set using the bitwise OR operator (|). Flag dnnl_use_scaleshift cannot be mixed with dnnl_use_scale or dnnl_use_shift.
  2. For forward propagation, the mean and variance can either be computed at runtime (in which case they are outputs of the primitive for dnnl_forward_training) or provided by a user (in which case they are inputs). In the latter case, a user must set the dnnl_use_global_stats flag. For backward propagation, the mean and variance are always input parameters.
  3. The memory format and data type for src and dst are assumed to be the same, and in the API they are typically referred to as data (e.g., see data_desc in dnnl::layer_normalization_forward::desc::desc()). The same is true for diff_src and diff_dst. The corresponding memory descriptors are referred to as diff_data_desc.
  4. Both forward and backward propagation support in-place operations, meaning that src can be used as input and output for forward propagation, and diff_dst can be used as input and output for backward propagation. In an in-place operation, the original data is overwritten. Note, however, that backward propagation requires the original src, so the corresponding forward propagation should not be performed in-place.
Data Type Support
The operation supports the following combinations of data types:
  • forward / backward: source/destination f32 or bf16; mean/variance/ScaleShift f32.
  • forward: source/destination f16; mean/variance/ScaleShift f32.
Data Representation
Mean and Variance
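The mean ($\mu$) and variance ($\sigma^2$) are separate tensors whose number of dimensions is $(\mathrm{data\_ndims} - 1)$ and whose size is $(\mathrm{data\_dim}[0], \ldots, \mathrm{data\_dim}[\mathrm{data\_ndims} - 2])$, i.e., the data dimensions without the last (normalized) one.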
The corresponding memory object can have an arbitrary memory format. Unless mean and variance are computed at runtime and not exposed (i.e., propagation kind is dnnl_forward_inference and dnnl_use_global_stats is not set), the user should provide a memory descriptor for statistics when initializing the layer normalization descriptor. For best performance, it is advised to use the memory format that follows the data memory format; i.e., if the data format is dnnl_tnc, the best performance can be expected for statistics with the dnnl_tn format and suboptimal for statistics with the dnnl_nt format.
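The advice to match the statistics layout to the data layout can be illustrated with a hypothetical helper, stat_offsets, which is not part of the library. It assumes dense tnc data processed row by row in physical order (t outer, n inner). For each row, it lists where that row's statistic lives in a tn statistics tensor versus an nt one. This is one plausible reading of the performance advice, not a claim about the library's internals: with tn the statistic accesses form a sequential walk, while with nt they jump by T elements.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: for dense tnc data walked in physical order
// (t outer, n inner), return the offset of each (t, n) statistic when the
// mean/variance tensor is stored as tn (tn_layout = true) or nt (false).
std::vector<std::size_t> stat_offsets(std::size_t T, std::size_t N, bool tn_layout) {
    std::vector<std::size_t> offsets;
    offsets.reserve(T * N);
    for (std::size_t t = 0; t < T; ++t) {
        for (std::size_t n = 0; n < N; ++n) {
            // tn: (t, n) -> t * N + n, so consecutive rows hit consecutive elements.
            // nt: (t, n) -> n * T + t, so consecutive rows jump by T elements.
            offsets.push_back(tn_layout ? t * N + n : n * T + t);
            assert(offsets.back() < T * N);
        }
    }
    return offsets;
}
```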
Scale and Shift
If dnnl_use_scaleshift is used, the scale ($\gamma$) and shift ($\beta$) are combined in a single 2D tensor of shape $2 \times C$.
If dnnl_use_scale or dnnl_use_shift is used, the scale ($\gamma$) and shift ($\beta$) are separate 1D tensors of shape $C$.
The format of the corresponding memory object must be dnnl_nc (dnnl_ab).
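A user who keeps $\gamma$ and $\beta$ as separate arrays but wants dnnl_use_scaleshift has to pack them into the $2 \times C$ buffer. Below is a minimal sketch of that packing in the row-major nc (ab) order. It assumes, following the order "scale and shift", that row 0 holds $\gamma$ and row 1 holds $\beta$. The function name pack_scale_shift is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: pack separate gamma and beta (each of length C) into
// one 2 x C buffer in nc (ab) order, assuming row 0 = scale, row 1 = shift.
std::vector<float> pack_scale_shift(const std::vector<float> &gamma,
        const std::vector<float> &beta) {
    assert(gamma.size() == beta.size());
    const std::size_t C = gamma.size();
    std::vector<float> packed(2 * C);
    for (std::size_t c = 0; c < C; ++c) {
        packed[c] = gamma[c];    // element (0, c)
        packed[C + c] = beta[c]; // element (1, c)
    }
    return packed;
}
```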
Source, Destination, and Their Gradients
The layer normalization primitive works with an arbitrary data tensor; however, it was designed for RNN data tensors (i.e., dnnl_nc, dnnl_tnc, dnnl_ldnc). Unlike CNN data tensors, RNN data tensors have a single feature dimension. Layer normalization normalizes over the last logical dimension (the feature dimension for RNN tensors) and computes separate statistics for every point of the non-feature dimensions.
The layer normalization primitive is optimized for the following memory formats:
Logical tensor → Implementations optimized for memory formats
  • NC → dnnl_nc
  • TNC → dnnl_tnc
  • LDNC → dnnl_ldnc

Performance Tips

  1. For data tensors (src, dst, diff_src, diff_dst), use memory formats for which the last logical axis is the last in the physical memory layout.
  2. For mean/variance, use the memory format that follows the data memory format; i.e., if the data format is dnnl_tnc, the best performance can be expected for statistics with dnnl_tn and suboptimal for statistics with the dnnl_nt format.
  3. For backward propagation, use the same memory format for src, diff_dst, and diff_src (the formats of diff_dst and diff_src are always the same because of the API). Different formats are functionally supported but lead to highly suboptimal performance.
  4. Use in-place operations whenever possible (see caveats in General Notes).

Example

This C++ API example demonstrates how to create and execute a layer normalization primitive in forward propagation mode.
Key optimizations included in this example:
  • In-place primitive execution;
  • Creation of memory objects using the primitive descriptor.

Product and Performance Information


Performance varies by use, configuration and other factors. Learn more at www.Intel.com/PerformanceIndex.