Visible to Intel only — GUID: GUID-14AB4BCA-3048-403E-8281-ACEFADDE5F17
Visible to Intel only — GUID: GUID-14AB4BCA-3048-403E-8281-ACEFADDE5F17
RNN
Overview
A primitive to compute recurrent neural network layers. More…
// enums
enum dnnl_rnn_direction_t;
enum dnnl_rnn_flags_t;
enum dnnl::rnn_direction;
enum dnnl::rnn_flags;
// structs
struct dnnl::augru_backward;
struct dnnl::augru_forward;
struct dnnl::gru_backward;
struct dnnl::gru_forward;
struct dnnl::lbr_augru_backward;
struct dnnl::lbr_augru_forward;
struct dnnl::lbr_gru_backward;
struct dnnl::lbr_gru_forward;
struct dnnl::lstm_backward;
struct dnnl::lstm_forward;
struct dnnl::rnn_primitive_desc_base;
struct dnnl::vanilla_rnn_backward;
struct dnnl::vanilla_rnn_forward;
// global functions
dnnl_rnn_flags_t dnnl::convert_to_c(rnn_flags flags);
dnnl_rnn_direction_t dnnl::convert_to_c(rnn_direction dir);
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
const dnnl_alg_kind_t activation,
const dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
float alpha,
float beta,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
const dnnl_alg_kind_t activation,
const dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
float alpha,
float beta,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t weights_projection_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t weights_projection_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_src_iter_c_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_weights_peephole_desc,
const_dnnl_memory_desc_t diff_weights_projection_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
const_dnnl_memory_desc_t diff_dst_iter_c_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_attention_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
);
dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_attention_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
);
Detailed Documentation
A primitive to compute recurrent neural network layers.
See also:
RNN in developer guide
Global Functions
dnnl_rnn_flags_t dnnl::convert_to_c(rnn_flags flags)
Converts RNN cell flags enum value from C++ API to C API type.
Parameters:
flags |
C++ API RNN cell flags enum value. |
Returns:
Corresponding C API RNN cell flags enum value.
dnnl_rnn_direction_t dnnl::convert_to_c(rnn_direction dir)
Converts RNN direction enum value from C++ API to C API type.
Parameters:
dir |
C++ API RNN direction enum value. |
Returns:
Corresponding C API RNN direction enum value.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
const dnnl_alg_kind_t activation,
const dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
float alpha,
float beta,
const_dnnl_primitive_attr_t attr
)
Creates a primitive descriptor for a vanilla RNN forward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc,
bias_desc,
dst_iter_desc.
This would then indicate that the RNN forward propagation primitive should not use them and should default to zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Possible values are dnnl_forward_training and dnnl_forward_inference. |
activation |
Activation kind. Possible values are dnnl_eltwise_relu, dnnl_eltwise_tanh or dnnl_eltwise_logistic. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
flags |
Unused. |
alpha |
Negative slope if activation is dnnl_eltwise_relu. |
beta |
Unused. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
const dnnl_alg_kind_t activation,
const dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
float alpha,
float beta,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
)
Creates a primitive descriptor for a vanilla RNN backward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with diff_src_iter_desc,
bias_desc together with diff_bias_desc,
dst_iter_desc together with diff_dst_iter_desc.
This would then indicate that the RNN backward propagation primitive should not use the respective data and should use zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Must be dnnl_backward. |
activation |
Activation kind. Possible values are dnnl_eltwise_relu, dnnl_eltwise_tanh or dnnl_eltwise_logistic. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
diff_src_layer_desc |
Memory descriptor for the diff of input vector. |
diff_src_iter_desc |
Memory descriptor for the diff of input recurrent hidden state vector. |
diff_weights_layer_desc |
Memory descriptor for the diff of weights applied to the layer input. |
diff_weights_iter_desc |
Memory descriptor for the diff of weights applied to the recurrent input. |
diff_bias_desc |
Diff bias memory descriptor. |
diff_dst_layer_desc |
Memory descriptor for the diff of output vector. |
diff_dst_iter_desc |
Memory descriptor for the diff of output recurrent hidden state vector. |
flags |
Unused. |
alpha |
Negative slope if activation is dnnl_eltwise_relu. |
beta |
Unused. |
hint_fwd_pd |
Primitive descriptor for a respective forward propagation primitive. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t weights_projection_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
)
Creates a primitive descriptor for an LSTM forward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with src_iter_c_desc,
weights_peephole_desc,
bias_desc,
dst_iter_desc together with dst_iter_c_desc.
This would then indicate that the LSTM forward propagation primitive should not use them and should default to zero values instead.
The weights_projection_desc may either be NULL or point to a zero memory descriptor. This would then indicate that the LSTM does not have a recurrent projection layer.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Possible values are dnnl_forward_training and dnnl_forward_inference. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
src_iter_c_desc |
Memory descriptor for the input recurrent cell state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
weights_peephole_desc |
Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula). |
weights_projection_desc |
Memory descriptor for the weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula). |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
dst_iter_c_desc |
Memory descriptor for the output recurrent cell state vector. |
flags |
Unused. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t weights_projection_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_src_iter_c_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_weights_peephole_desc,
const_dnnl_memory_desc_t diff_weights_projection_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
const_dnnl_memory_desc_t diff_dst_iter_c_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
)
Creates a primitive descriptor for an LSTM backward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with src_iter_c_desc, diff_src_iter_desc, and diff_src_iter_c_desc,
weights_peephole_desc together with diff_weights_peephole_desc,
bias_desc together with diff_bias_desc,
dst_iter_desc together with dst_iter_c_desc, diff_dst_iter_desc, and diff_dst_iter_c_desc.
This would then indicate that the LSTM backward propagation primitive should not use them and should default to zero values instead.
The weights_projection_desc together with diff_weights_projection_desc may either be NULL or point to a zero memory descriptor. This would then indicate that the LSTM does not have a recurrent projection layer.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Must be dnnl_backward. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
src_iter_c_desc |
Memory descriptor for the input recurrent cell state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
weights_peephole_desc |
Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula). |
weights_projection_desc |
Memory descriptor for the weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula). |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
dst_iter_c_desc |
Memory descriptor for the output recurrent cell state vector. |
diff_src_layer_desc |
Memory descriptor for the diff of input vector. |
diff_src_iter_desc |
Memory descriptor for the diff of input recurrent hidden state vector. |
diff_src_iter_c_desc |
Memory descriptor for the diff of input recurrent cell state vector. |
diff_weights_layer_desc |
Memory descriptor for the diff of weights applied to the layer input. |
diff_weights_iter_desc |
Memory descriptor for the diff of weights applied to the recurrent input. |
diff_weights_peephole_desc |
Memory descriptor for the diff of weights applied to the cell states (according to the Peephole LSTM formula). |
diff_weights_projection_desc |
Memory descriptor for the diff of weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula). |
diff_bias_desc |
Diff bias memory descriptor. |
diff_dst_layer_desc |
Memory descriptor for the diff of output vector. |
diff_dst_iter_desc |
Memory descriptor for the diff of output recurrent hidden state vector. |
diff_dst_iter_c_desc |
Memory descriptor for the diff of output recurrent cell state vector. |
flags |
Unused. |
hint_fwd_pd |
Primitive descriptor for a respective forward propagation primitive. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
unsigned flags,
const_dnnl_primitive_attr_t attr
)
Creates a primitive descriptor for a GRU forward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc,
bias_desc,
dst_iter_desc.
This would then indicate that the GRU forward propagation primitive should not use them and should default to zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Possible values are dnnl_forward_training and dnnl_forward_inference. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
flags |
Unused. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
dnnl_primitive_desc_t* primitive_desc,
dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind,
dnnl_rnn_direction_t direction,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr
)
Creates a primitive descriptor for a GRU backward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with diff_src_iter_desc,
bias_desc together with diff_bias_desc,
dst_iter_desc together with diff_dst_iter_desc.
This would then indicate that the GRU backward propagation primitive should not use them and should default to zero values instead.
Parameters: