Visible to Intel only — GUID: GUID-5B3C86F8-1D5C-41DA-8ED7-C5454A8B6CAB
Visible to Intel only — GUID: GUID-5B3C86F8-1D5C-41DA-8ED7-C5454A8B6CAB
RNN
Overview
A primitive to compute recurrent neural network layers. More…
// enums

enum dnnl_rnn_direction_t;
enum dnnl_rnn_flags_t;
enum dnnl::rnn_direction;
enum dnnl::rnn_flags;

// structs

struct dnnl::augru_backward;
struct dnnl::augru_forward;
struct dnnl::gru_backward;
struct dnnl::gru_forward;
struct dnnl::lbr_augru_backward;
struct dnnl::lbr_augru_forward;
struct dnnl::lbr_gru_backward;
struct dnnl::lbr_gru_forward;
struct dnnl::lstm_backward;
struct dnnl::lstm_forward;
struct dnnl::rnn_primitive_desc_base;
struct dnnl::vanilla_rnn_backward;
struct dnnl::vanilla_rnn_forward;

// global functions

dnnl_rnn_flags_t dnnl::convert_to_c(rnn_flags flags);
dnnl_rnn_direction_t dnnl::convert_to_c(rnn_direction dir);

dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const dnnl_alg_kind_t activation,
    const dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    float alpha,
    float beta,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    const dnnl_alg_kind_t activation,
    const dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    float alpha,
    float beta,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t src_iter_c_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t weights_peephole_desc,
    const_dnnl_memory_desc_t weights_projection_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t dst_iter_c_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t src_iter_c_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t weights_peephole_desc,
    const_dnnl_memory_desc_t weights_projection_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t dst_iter_c_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_src_iter_c_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_weights_peephole_desc,
    const_dnnl_memory_desc_t diff_weights_projection_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    const_dnnl_memory_desc_t diff_dst_iter_c_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_attention_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_attr_t attr
    );

dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
    dnnl_primitive_desc_t* primitive_desc,
    dnnl_engine_t engine,
    dnnl_prop_kind_t prop_kind,
    dnnl_rnn_direction_t direction,
    const_dnnl_memory_desc_t src_layer_desc,
    const_dnnl_memory_desc_t src_iter_desc,
    const_dnnl_memory_desc_t attention_desc,
    const_dnnl_memory_desc_t weights_layer_desc,
    const_dnnl_memory_desc_t weights_iter_desc,
    const_dnnl_memory_desc_t bias_desc,
    const_dnnl_memory_desc_t dst_layer_desc,
    const_dnnl_memory_desc_t dst_iter_desc,
    const_dnnl_memory_desc_t diff_src_layer_desc,
    const_dnnl_memory_desc_t diff_src_iter_desc,
    const_dnnl_memory_desc_t diff_attention_desc,
    const_dnnl_memory_desc_t diff_weights_layer_desc,
    const_dnnl_memory_desc_t diff_weights_iter_desc,
    const_dnnl_memory_desc_t diff_bias_desc,
    const_dnnl_memory_desc_t diff_dst_layer_desc,
    const_dnnl_memory_desc_t diff_dst_iter_desc,
    unsigned flags,
    const_dnnl_primitive_desc_t hint_fwd_pd,
    const_dnnl_primitive_attr_t attr
    );
Detailed Documentation
A primitive to compute recurrent neural network layers.
See also:
RNN in developer guide
Global Functions
dnnl_rnn_flags_t dnnl::convert_to_c(rnn_flags flags)
Converts an RNN cell flags enum value from the C++ API to the C API type.
Parameters:
flags |
C++ API RNN cell flags enum value. |
Returns:
Corresponding C API RNN cell flags enum value.
dnnl_rnn_direction_t dnnl::convert_to_c(rnn_direction dir)
Converts an RNN direction enum value from the C++ API to the C API type.
Parameters:
dir |
C++ API RNN direction enum value. |
Returns:
Corresponding C API RNN direction enum value.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create( dnnl_primitive_desc_t* primitive_desc, dnnl_engine_t engine, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const_dnnl_memory_desc_t src_layer_desc, const_dnnl_memory_desc_t src_iter_desc, const_dnnl_memory_desc_t weights_layer_desc, const_dnnl_memory_desc_t weights_iter_desc, const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_layer_desc, const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha, float beta, const_dnnl_primitive_attr_t attr )
Creates a primitive descriptor for a vanilla RNN forward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc,
bias_desc,
dst_iter_desc.
This would then indicate that the RNN forward propagation primitive should not use them and should default to zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Possible values are dnnl_forward_training and dnnl_forward_inference. |
activation |
Activation kind. Possible values are dnnl_eltwise_relu, dnnl_eltwise_tanh or dnnl_eltwise_logistic. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
flags |
Unused. |
alpha |
Negative slope if activation is dnnl_eltwise_relu. |
beta |
Unused. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create( dnnl_primitive_desc_t* primitive_desc, dnnl_engine_t engine, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const_dnnl_memory_desc_t src_layer_desc, const_dnnl_memory_desc_t src_iter_desc, const_dnnl_memory_desc_t weights_layer_desc, const_dnnl_memory_desc_t weights_iter_desc, const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_layer_desc, const_dnnl_memory_desc_t dst_iter_desc, const_dnnl_memory_desc_t diff_src_layer_desc, const_dnnl_memory_desc_t diff_src_iter_desc, const_dnnl_memory_desc_t diff_weights_layer_desc, const_dnnl_memory_desc_t diff_weights_iter_desc, const_dnnl_memory_desc_t diff_bias_desc, const_dnnl_memory_desc_t diff_dst_layer_desc, const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags, float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd, const_dnnl_primitive_attr_t attr )
Creates a primitive descriptor for a vanilla RNN backward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with diff_src_iter_desc,
bias_desc together with diff_bias_desc,
dst_iter_desc together with diff_dst_iter_desc.
This would then indicate that the RNN backward propagation primitive should not use the respective data and should use zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Must be dnnl_backward. |
activation |
Activation kind. Possible values are dnnl_eltwise_relu, dnnl_eltwise_tanh or dnnl_eltwise_logistic. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
diff_src_layer_desc |
Memory descriptor for the diff of input vector. |
diff_src_iter_desc |
Memory descriptor for the diff of input recurrent hidden state vector. |
diff_weights_layer_desc |
Memory descriptor for the diff of weights applied to the layer input. |
diff_weights_iter_desc |
Memory descriptor for the diff of weights applied to the recurrent input. |
diff_bias_desc |
Diff bias memory descriptor. |
diff_dst_layer_desc |
Memory descriptor for the diff of output vector. |
diff_dst_iter_desc |
Memory descriptor for the diff of output recurrent hidden state vector. |
flags |
Unused. |
alpha |
Negative slope if activation is dnnl_eltwise_relu. |
beta |
Unused. |
hint_fwd_pd |
Primitive descriptor for the corresponding forward propagation primitive. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create( dnnl_primitive_desc_t* primitive_desc, dnnl_engine_t engine, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const_dnnl_memory_desc_t src_layer_desc, const_dnnl_memory_desc_t src_iter_desc, const_dnnl_memory_desc_t src_iter_c_desc, const_dnnl_memory_desc_t weights_layer_desc, const_dnnl_memory_desc_t weights_iter_desc, const_dnnl_memory_desc_t weights_peephole_desc, const_dnnl_memory_desc_t weights_projection_desc, const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_layer_desc, const_dnnl_memory_desc_t dst_iter_desc, const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags, const_dnnl_primitive_attr_t attr )
Creates a primitive descriptor for an LSTM forward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with src_iter_c_desc,
weights_peephole_desc,
bias_desc,
dst_iter_desc together with dst_iter_c_desc.
This would then indicate that the LSTM forward propagation primitive should not use them and should default to zero values instead.
The weights_projection_desc argument may either be NULL or point to a zero memory descriptor. This would then indicate that the LSTM does not have a recurrent projection layer.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Possible values are dnnl_forward_training and dnnl_forward_inference. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
src_iter_c_desc |
Memory descriptor for the input recurrent cell state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
weights_peephole_desc |
Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula). |
weights_projection_desc |
Memory descriptor for the weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula). |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
dst_iter_c_desc |
Memory descriptor for the output recurrent cell state vector. |
flags |
Unused. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create( dnnl_primitive_desc_t* primitive_desc, dnnl_engine_t engine, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const_dnnl_memory_desc_t src_layer_desc, const_dnnl_memory_desc_t src_iter_desc, const_dnnl_memory_desc_t src_iter_c_desc, const_dnnl_memory_desc_t weights_layer_desc, const_dnnl_memory_desc_t weights_iter_desc, const_dnnl_memory_desc_t weights_peephole_desc, const_dnnl_memory_desc_t weights_projection_desc, const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_layer_desc, const_dnnl_memory_desc_t dst_iter_desc, const_dnnl_memory_desc_t dst_iter_c_desc, const_dnnl_memory_desc_t diff_src_layer_desc, const_dnnl_memory_desc_t diff_src_iter_desc, const_dnnl_memory_desc_t diff_src_iter_c_desc, const_dnnl_memory_desc_t diff_weights_layer_desc, const_dnnl_memory_desc_t diff_weights_iter_desc, const_dnnl_memory_desc_t diff_weights_peephole_desc, const_dnnl_memory_desc_t diff_weights_projection_desc, const_dnnl_memory_desc_t diff_bias_desc, const_dnnl_memory_desc_t diff_dst_layer_desc, const_dnnl_memory_desc_t diff_dst_iter_desc, const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd, const_dnnl_primitive_attr_t attr )
Creates a primitive descriptor for an LSTM backward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with src_iter_c_desc, diff_src_iter_desc, and diff_src_iter_c_desc,
weights_peephole_desc together with diff_weights_peephole_desc,
bias_desc together with diff_bias_desc,
dst_iter_desc together with dst_iter_c_desc, diff_dst_iter_desc, and diff_dst_iter_c_desc.
This would then indicate that the LSTM backward propagation primitive should not use them and should default to zero values instead.
The weights_projection_desc argument together with diff_weights_projection_desc may either be NULL or point to a zero memory descriptor. This would then indicate that the LSTM does not have a recurrent projection layer.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Must be dnnl_backward. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
src_iter_c_desc |
Memory descriptor for the input recurrent cell state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
weights_peephole_desc |
Memory descriptor for the weights applied to the cell states (according to the Peephole LSTM formula). |
weights_projection_desc |
Memory descriptor for the weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula). |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
dst_iter_c_desc |
Memory descriptor for the output recurrent cell state vector. |
diff_src_layer_desc |
Memory descriptor for the diff of input vector. |
diff_src_iter_desc |
Memory descriptor for the diff of input recurrent hidden state vector. |
diff_src_iter_c_desc |
Memory descriptor for the diff of input recurrent cell state vector. |
diff_weights_layer_desc |
Memory descriptor for the diff of weights applied to the layer input. |
diff_weights_iter_desc |
Memory descriptor for the diff of weights applied to the recurrent input. |
diff_weights_peephole_desc |
Memory descriptor for the diff of weights applied to the cell states (according to the Peephole LSTM formula). |
diff_weights_projection_desc |
Memory descriptor for the diff of weights applied to the hidden states to get the recurrent projection (according to the Projection LSTM formula). |
diff_bias_desc |
Diff bias memory descriptor. |
diff_dst_layer_desc |
Memory descriptor for the diff of output vector. |
diff_dst_iter_desc |
Memory descriptor for the diff of output recurrent hidden state vector. |
diff_dst_iter_c_desc |
Memory descriptor for the diff of output recurrent cell state vector. |
flags |
Unused. |
hint_fwd_pd |
Primitive descriptor for the corresponding forward propagation primitive. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create( dnnl_primitive_desc_t* primitive_desc, dnnl_engine_t engine, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const_dnnl_memory_desc_t src_layer_desc, const_dnnl_memory_desc_t src_iter_desc, const_dnnl_memory_desc_t weights_layer_desc, const_dnnl_memory_desc_t weights_iter_desc, const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_layer_desc, const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, const_dnnl_primitive_attr_t attr )
Creates a primitive descriptor for a GRU forward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc,
bias_desc,
dst_iter_desc.
This would then indicate that the GRU forward propagation primitive should not use them and should default to zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |
Engine to use. |
prop_kind |
Propagation kind. Possible values are dnnl_forward_training and dnnl_forward_inference. |
direction |
RNN direction. See dnnl_rnn_direction_t for more info. |
src_layer_desc |
Memory descriptor for the input vector. |
src_iter_desc |
Memory descriptor for the input recurrent hidden state vector. |
weights_layer_desc |
Memory descriptor for the weights applied to the layer input. |
weights_iter_desc |
Memory descriptor for the weights applied to the recurrent input. |
bias_desc |
Bias memory descriptor. |
dst_layer_desc |
Memory descriptor for the output vector. |
dst_iter_desc |
Memory descriptor for the output recurrent hidden state vector. |
flags |
Unused. |
attr |
Primitive attributes (can be NULL). |
Returns:
dnnl_success on success and a status describing the error otherwise.
dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create( dnnl_primitive_desc_t* primitive_desc, dnnl_engine_t engine, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const_dnnl_memory_desc_t src_layer_desc, const_dnnl_memory_desc_t src_iter_desc, const_dnnl_memory_desc_t weights_layer_desc, const_dnnl_memory_desc_t weights_iter_desc, const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_layer_desc, const_dnnl_memory_desc_t dst_iter_desc, const_dnnl_memory_desc_t diff_src_layer_desc, const_dnnl_memory_desc_t diff_src_iter_desc, const_dnnl_memory_desc_t diff_weights_layer_desc, const_dnnl_memory_desc_t diff_weights_iter_desc, const_dnnl_memory_desc_t diff_bias_desc, const_dnnl_memory_desc_t diff_dst_layer_desc, const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd, const_dnnl_primitive_attr_t attr )
Creates a primitive descriptor for a GRU backward propagation primitive.
The following arguments may either be NULL or point to a zero memory descriptor:
src_iter_desc together with diff_src_iter_desc,
bias_desc together with diff_bias_desc,
dst_iter_desc together with diff_dst_iter_desc.
This would then indicate that the GRU backward propagation primitive should not use them and should default to zero values instead.
Parameters:
primitive_desc |
Output primitive descriptor. |
engine |