Intel® oneAPI Deep Neural Network Developer Guide and Reference
                    
                        ID
                        768875
                    
                
                
                    Date
                    2/28/2024
                
                
                    Public
                
            A newer version of this document is available. Customers should click here to go to the newest version.
                                                
                                                
                                                    
                                                    
                                                        Abs
                                                    
                                                    
                                                
                                                    
                                                    
                                                        AbsBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Add
                                                    
                                                    
                                                
                                                    
                                                    
                                                        AvgPool
                                                    
                                                    
                                                
                                                    
                                                    
                                                        AvgPoolBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        BatchNormForwardTraining
                                                    
                                                    
                                                
                                                    
                                                    
                                                        BatchNormInference
                                                    
                                                    
                                                
                                                    
                                                    
                                                        BatchNormTrainingBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        BiasAdd
                                                    
                                                    
                                                
                                                    
                                                    
                                                        BiasAddBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Clamp
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ClampBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Concat
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Convolution
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ConvolutionBackwardData
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ConvolutionBackwardWeights
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ConvTranspose
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ConvTransposeBackwardData
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ConvTransposeBackwardWeights
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Dequantize
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Divide
                                                    
                                                    
                                                
                                                    
                                                    
                                                        DynamicDequantize
                                                    
                                                    
                                                
                                                    
                                                    
                                                        DynamicQuantize
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Elu
                                                    
                                                    
                                                
                                                    
                                                    
                                                        EluBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        End
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Exp
                                                    
                                                    
                                                
                                                    
                                                    
                                                        GELU
                                                    
                                                    
                                                
                                                    
                                                    
                                                        GELUBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        HardSigmoid
                                                    
                                                    
                                                
                                                    
                                                    
                                                        HardSigmoidBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        HardSwish
                                                    
                                                    
                                                
                                                    
                                                    
                                                        HardSwishBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Interpolate
                                                    
                                                    
                                                
                                                    
                                                    
                                                        InterpolateBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        LayerNorm
                                                    
                                                    
                                                
                                                    
                                                    
                                                        LayerNormBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        LeakyReLU
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Log
                                                    
                                                    
                                                
                                                    
                                                    
                                                        LogSoftmax
                                                    
                                                    
                                                
                                                    
                                                    
                                                        LogSoftmaxBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        MatMul
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Maximum
                                                    
                                                    
                                                
                                                    
                                                    
                                                        MaxPool
                                                    
                                                    
                                                
                                                    
                                                    
                                                        MaxPoolBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Minimum
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Mish
                                                    
                                                    
                                                
                                                    
                                                    
                                                        MishBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Multiply
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Pow
                                                    
                                                    
                                                
                                                    
                                                    
                                                        PReLU
                                                    
                                                    
                                                
                                                    
                                                    
                                                        PReLUBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Quantize
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Reciprocal
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceL1
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceL2
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceMax
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceMean
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceMin
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceProd
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReduceSum
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReLU
                                                    
                                                    
                                                
                                                    
                                                    
                                                        ReLUBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Reorder
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Round
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Select
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Sigmoid
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SigmoidBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SoftMax
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SoftMaxBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SoftPlus
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SoftPlusBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Sqrt
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SqrtBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Square
                                                    
                                                    
                                                
                                                    
                                                    
                                                        SquaredDifference
                                                    
                                                    
                                                
                                                    
                                                    
                                                        StaticReshape
                                                    
                                                    
                                                
                                                    
                                                    
                                                        StaticTranspose
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Subtract
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Tanh
                                                    
                                                    
                                                
                                                    
                                                    
                                                        TanhBackward
                                                    
                                                    
                                                
                                                    
                                                    
                                                        TypeCast
                                                    
                                                    
                                                
                                                    
                                                    
                                                        Wildcard
                                                    
                                                    
                                                
                                            
                                        
                                                            
                                                            
                                                                
                                                                
                                                                    enum dnnl_alg_kind_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl_normalization_flags_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl_primitive_kind_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl_prop_kind_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl_query_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl::normalization_flags
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl::query
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    struct dnnl_exec_arg_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    struct dnnl_primitive
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    struct dnnl_primitive_desc
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::primitive
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    struct dnnl::primitive_desc
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    struct dnnl::primitive_desc_base
                                                                
                                                                
                                                            
                                                        
                                                    
                                                            
                                                            
                                                                
                                                                
                                                                    enum dnnl_rnn_direction_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl_rnn_flags_t
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl::rnn_direction
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    enum dnnl::rnn_flags
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::augru_backward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::augru_forward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::gru_backward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::gru_forward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::lbr_augru_backward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::lbr_augru_forward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::lbr_gru_backward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::lbr_gru_forward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::lstm_backward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::lstm_forward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                
                                                                    struct dnnl::rnn_primitive_desc_base
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::vanilla_rnn_backward
                                                                
                                                                
                                                                
                                                            
                                                                
                                                                    struct dnnl::vanilla_rnn_forward
                                                                
                                                                
                                                                
                                                            
                                                        
cnn_inference_f32.c
This C API example demonstrates how to build an AlexNet neural network topology for forward-pass inference.
/*******************************************************************************
* Copyright 2016-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
// Required for posix_memalign
#define _POSIX_C_SOURCE 200112L
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "oneapi/dnnl/dnnl.h"
#include "example_utils.h"
// Problem sizes for the conv -> relu -> lrn -> pool chain built in
// simple_net(). Tensors there are declared in {N, C, H, W} order.

// Minibatch size: first dimension of the net's src and dst tensors.
#define BATCH 8
// Input channels of the convolution.
#define IC 3
// Output channels of the convolution (and channel count of every later stage).
#define OC 96
// Spatial size of the convolution input.
#define CONV_IH 227
#define CONV_IW 227
// Spatial size of the convolution output. With the 11x11 kernel used in
// simple_net(), stride 4 and padding 0: (227 - 11) / 4 + 1 = 55.
#define CONV_OH 55
#define CONV_OW 55
// Convolution stride and padding, applied to both spatial dimensions.
#define CONV_STRIDE 4
#define CONV_PAD 0
// Spatial size of the pooling output. With the 3x3 kernel used in
// simple_net(), stride 2 and padding 0: (55 - 3) / 2 + 1 = 27.
#define POOL_OH 27
#define POOL_OW 27
// Pooling stride and padding, applied to both spatial dimensions.
#define POOL_STRIDE 2
#define POOL_PAD 0
// Multiplies together the first `size` entries of `arr` and returns the
// result as a size_t. An empty range (size == 0) yields 1.
static size_t product(dnnl_dim_t *arr, size_t size) {
    size_t total = 1;
    // Walk the array front to back, consuming one entry per remaining count.
    for (size_t remaining = size; remaining > 0; --remaining, ++arr) {
        total *= *arr;
    }
    return total;
}
// Fills `data` with a deterministic pattern: element at flat position p
// receives (float)(p % 1637).
//
// dim  - number of dimensions described by `dims`; only 1 and 4 are handled,
//        any other value leaves `data` untouched.
// dims - extents of the tensor; for dim == 4 the elements are addressed in
//        dense row-major order over dims[0..3].
static void init_net_data(float *data, uint32_t dim, const dnnl_dim_t *dims) {
    switch (dim) {
        case 1:
            for (dnnl_dim_t pos = 0; pos < dims[0]; ++pos)
                data[pos] = (float)(pos % 1637);
            break;
        case 4: {
            // The four nested loops visit elements in increasing row-major
            // order, so a running counter equals the flattened index
            // ((d0 * dims[1] + d1) * dims[2] + d2) * dims[3] + d3.
            dnnl_dim_t flat = 0;
            for (dnnl_dim_t d0 = 0; d0 < dims[0]; ++d0) {
                for (dnnl_dim_t d1 = 0; d1 < dims[1]; ++d1) {
                    for (dnnl_dim_t d2 = 0; d2 < dims[2]; ++d2) {
                        for (dnnl_dim_t d3 = 0; d3 < dims[3]; ++d3) {
                            data[flat] = (float)(flat % 1637);
                            ++flat;
                        }
                    }
                }
            }
            break;
        }
        default:
            // Other ranks are not initialized.
            break;
    }
}
// Argument list for one primitive of the net. simple_net() keeps one args_t
// per entry of its primitive array and hands nargs/args to
// dnnl_primitive_execute() for the primitive at the same index.
typedef struct {
    // Number of elements in `args`.
    int nargs;
    // Heap array allocated by prepare_arg_node() and released by
    // free_arg_node().
    dnnl_exec_arg_t *args;
} args_t;
// Allocates room for `nargs` execution arguments in `node` and records the
// count. The entries themselves are left uninitialized; callers fill them
// with set_arg().
//
// A non-positive `nargs` yields an empty node (args == NULL, nargs == 0).
// If the allocation fails the example cannot continue, so a diagnostic is
// printed and the process exits; callers index node->args right after this
// call and would otherwise dereference NULL.
static void prepare_arg_node(args_t *node, int nargs) {
    node->args = NULL;
    node->nargs = 0;
    if (nargs <= 0) return;

    // Convert before multiplying so the size is computed in size_t.
    node->args = malloc((size_t)nargs * sizeof *node->args);
    if (node->args == NULL) {
        fprintf(stderr, "prepare_arg_node: failed to allocate %d arguments\n",
                nargs);
        exit(EXIT_FAILURE);
    }
    node->nargs = nargs;
}
// Releases the argument array owned by `node` and resets the node to the
// empty state, so a stale pointer/count is never left behind and calling
// this twice on the same node is harmless (free(NULL) is a no-op).
static void free_arg_node(args_t *node) {
    free(node->args);
    node->args = NULL;
    node->nargs = 0;
}
// Fills one execution-argument slot: binds the memory object `memory` to the
// argument index `arg_idx` (one of the DNNL_ARG_* values used by callers).
static void set_arg(dnnl_exec_arg_t *arg, int arg_idx, dnnl_memory_t memory) {
    // The two fields are independent; both are always overwritten.
    arg->memory = memory;
    arg->arg = arg_idx;
}
// Creates a library-allocated memory object for user data and loads `data`
// into it.
//
// dim, dims - rank and extents of the tensor.
// user_tag  - format tag describing how the user data is laid out.
// engine    - engine the memory object is created on.
// data      - source buffer handed to write_to_dnnl_memory().
// memory    - out: the newly created memory object; the caller destroys it.
//
// NOTE(review): CHECK and write_to_dnnl_memory are defined outside this chunk
// (presumably in example_utils.h); CHECK is assumed to stop the example when
// a call does not return success - confirm there.
static void init_data_memory(uint32_t dim, const dnnl_dim_t *dims,
        dnnl_format_tag_t user_tag, dnnl_engine_t engine, float *data,
        dnnl_memory_t *memory) {
    // Temporary f32 descriptor with the user's layout; only needed to create
    // the memory object and destroyed right after.
    dnnl_memory_desc_t user_md;
    CHECK(dnnl_memory_desc_create_with_tag(
            &user_md, dim, dims, dnnl_f32, user_tag));
    // DNNL_MEMORY_ALLOCATE: the buffer is requested from the library rather
    // than supplied by the caller.
    CHECK(dnnl_memory_create(memory, user_md, engine, DNNL_MEMORY_ALLOCATE));
    CHECK(dnnl_memory_desc_destroy(user_md));
    // Transfer the user's values into the new memory object.
    write_to_dnnl_memory(data, *memory);
}
// Prepares a reorder between a user memory object and the memory format a
// primitive expects, if the two descriptors differ.
//
// When the descriptors are equal nothing is created: *prim_memory and
// *reorder are set to NULL and *net_index is left unchanged.
//
// Otherwise a memory object with the primitive's descriptor is allocated, a
// reorder primitive is created (user -> prim when dir_is_user_to_prim is
// non-zero, prim -> user otherwise), the primitive and its two arguments are
// stored at net[*net_index] / net_args[*net_index], and *net_index is
// incremented. The caller owns *prim_memory, *reorder and the argument node.
//
// Returns dnnl_success. Failures of the underlying calls are handled by
// CHECK. NOTE(review): CHECK is defined outside this chunk and is assumed to
// stop the example on a non-success status - confirm in example_utils.h.
dnnl_status_t prepare_reorder(dnnl_memory_t *user_memory, // in
        const_dnnl_memory_desc_t prim_memory_md, // in
        dnnl_engine_t prim_engine, // in: primitive's engine
        int dir_is_user_to_prim, // in: user -> prim or prim -> user
        dnnl_memory_t *prim_memory, // out: primitive's memory created
        dnnl_primitive_t *reorder, // out: reorder primitive created
        uint32_t *net_index, // primitive index in net (inc if reorder created)
        dnnl_primitive_t *net, args_t *net_args) { // net params
    // Both queries are checked: if either failed, the descriptor/engine below
    // would be read uninitialized.
    const_dnnl_memory_desc_t user_memory_md;
    CHECK(dnnl_memory_get_memory_desc(*user_memory, &user_memory_md));
    dnnl_engine_t user_mem_engine;
    CHECK(dnnl_memory_get_engine(*user_memory, &user_mem_engine));
    if (!dnnl_memory_desc_equal(user_memory_md, prim_memory_md)) {
        CHECK(dnnl_memory_create(prim_memory, prim_memory_md, prim_engine,
                DNNL_MEMORY_ALLOCATE));
        dnnl_primitive_desc_t reorder_pd;
        if (dir_is_user_to_prim) {
            CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd,
                    user_memory_md, user_mem_engine, prim_memory_md,
                    prim_engine, NULL));
        } else {
            CHECK(dnnl_reorder_primitive_desc_create(&reorder_pd,
                    prim_memory_md, prim_engine, user_memory_md,
                    user_mem_engine, NULL));
        }
        CHECK(dnnl_primitive_create(reorder, reorder_pd));
        // The primitive descriptor is only needed to create the primitive.
        CHECK(dnnl_primitive_desc_destroy(reorder_pd));
        // Queue the reorder in the net together with its FROM/TO arguments;
        // which memory is the source depends on the requested direction.
        net[*net_index] = *reorder;
        prepare_arg_node(&net_args[*net_index], 2);
        set_arg(&net_args[*net_index].args[0], DNNL_ARG_FROM,
                dir_is_user_to_prim ? *user_memory : *prim_memory);
        set_arg(&net_args[*net_index].args[1], DNNL_ARG_TO,
                dir_is_user_to_prim ? *prim_memory : *user_memory);
        (*net_index)++;
    } else {
        // Formats already match: signal "no reorder" to the caller.
        *prim_memory = NULL;
        *reorder = NULL;
    }
    return dnnl_success;
}
// Allocates a zero-filled buffer of `count` floats for simple_net().
// On allocation failure a diagnostic is printed and the process exits, so the
// returned pointer can be written to without a further check.
static float *alloc_float_buffer(size_t count) {
    float *buf = calloc(count, sizeof *buf);
    if (buf == NULL && count != 0) {
        fprintf(stderr, "simple_net: failed to allocate %zu floats\n", count);
        exit(EXIT_FAILURE);
    }
    return buf;
}

// Builds and runs one forward pass of the first AlexNet block on an engine of
// kind `engine_kind`:
//
//   [reorder src] [reorder weights] conv -> relu -> lrn -> pool [reorder dst]
//
// The bracketed reorders are only added when the format chosen by the
// primitive differs from the user's nchw/oihw layout. Primitives and their
// arguments are collected in `net` / `net_args` in execution order, executed
// on a single stream, and then every object created here is released.
//
// NOTE(review): CHECK is defined outside this chunk and is assumed to stop the
// example when a call does not return success - confirm in example_utils.h.
void simple_net(dnnl_engine_kind_t engine_kind) {
    dnnl_engine_t engine;
    CHECK(dnnl_engine_create(&engine, engine_kind, 0));
    // build a simple net
    // n counts the primitives queued so far; at most 7 are queued below
    // (3 optional reorders + conv, relu, lrn, pool), within the capacity of 10.
    uint32_t n = 0;
    dnnl_primitive_t net[10];
    args_t net_args[10];
    const int ndims = 4;
    dnnl_dims_t net_src_sizes = {BATCH, IC, CONV_IH, CONV_IW};
    dnnl_dims_t net_dst_sizes = {BATCH, OC, POOL_OH, POOL_OW};
    float *net_src = alloc_float_buffer(product(net_src_sizes, ndims));
    // The destination buffer starts out zeroed (calloc-backed allocation).
    float *net_dst = alloc_float_buffer(product(net_dst_sizes, ndims));
    init_net_data(net_src, ndims, net_src_sizes);
    // AlexNet: conv
    // {BATCH, IC, CONV_IH, CONV_IW} (x) {OC, IC, 11, 11} ->
    // {BATCH, OC, CONV_OH, CONV_OW}
    // strides: {CONV_STRIDE, CONV_STRIDE}
    dnnl_dims_t conv_user_src_sizes;
    for (int i = 0; i < ndims; i++)
        conv_user_src_sizes[i] = net_src_sizes[i];
    dnnl_dims_t conv_user_weights_sizes = {OC, IC, 11, 11};
    dnnl_dims_t conv_bias_sizes = {OC};
    dnnl_dims_t conv_user_dst_sizes = {BATCH, OC, CONV_OH, CONV_OW};
    dnnl_dims_t conv_strides = {CONV_STRIDE, CONV_STRIDE};
    dnnl_dims_t conv_dilation = {0, 0};
    dnnl_dims_t conv_padding = {CONV_PAD, CONV_PAD};
    float *conv_src = net_src;
    float *conv_weights
            = alloc_float_buffer(product(conv_user_weights_sizes, ndims));
    float *conv_bias = alloc_float_buffer(product(conv_bias_sizes, 1));
    init_net_data(conv_weights, ndims, conv_user_weights_sizes);
    init_net_data(conv_bias, 1, conv_bias_sizes);
    // create memory for user data
    dnnl_memory_t conv_user_src_memory, conv_user_weights_memory,
            conv_user_bias_memory;
    init_data_memory(ndims, conv_user_src_sizes, dnnl_nchw, engine, conv_src,
            &conv_user_src_memory);
    init_data_memory(ndims, conv_user_weights_sizes, dnnl_oihw, engine,
            conv_weights, &conv_user_weights_memory);
    init_data_memory(1, conv_bias_sizes, dnnl_x, engine, conv_bias,
            &conv_user_bias_memory);
    // create data descriptors for convolution w/ no specified format
    // (dnnl_format_tag_any leaves the layout choice to the primitive)
    dnnl_memory_desc_t conv_src_md, conv_weights_md, conv_bias_md, conv_dst_md;
    CHECK(dnnl_memory_desc_create_with_tag(&conv_src_md, ndims,
            conv_user_src_sizes, dnnl_f32, dnnl_format_tag_any));
    CHECK(dnnl_memory_desc_create_with_tag(&conv_weights_md, ndims,
            conv_user_weights_sizes, dnnl_f32, dnnl_format_tag_any));
    CHECK(dnnl_memory_desc_create_with_tag(
            &conv_bias_md, 1, conv_bias_sizes, dnnl_f32, dnnl_x));
    CHECK(dnnl_memory_desc_create_with_tag(&conv_dst_md, ndims,
            conv_user_dst_sizes, dnnl_f32, dnnl_format_tag_any));
    // create a convolution
    // (the same padding array is passed for both padding arguments)
    dnnl_primitive_desc_t conv_pd;
    CHECK(dnnl_convolution_forward_primitive_desc_create(&conv_pd, engine,
            dnnl_forward, dnnl_convolution_direct, conv_src_md, conv_weights_md,
            conv_bias_md, conv_dst_md, conv_strides, conv_dilation,
            conv_padding, conv_padding, NULL));
    dnnl_memory_t conv_internal_src_memory, conv_internal_weights_memory,
            conv_internal_dst_memory;
    // create memory for dst data, we don't need to reorder it to user data
    const_dnnl_memory_desc_t dst_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
    CHECK(dnnl_memory_create(
            &conv_internal_dst_memory, dst_md, engine, DNNL_MEMORY_ALLOCATE));
    // create reorder primitives between user data and convolution srcs
    // if required
    dnnl_primitive_t conv_reorder_src, conv_reorder_weights;
    const_dnnl_memory_desc_t src_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_src_md, 0);
    CHECK(prepare_reorder(&conv_user_src_memory, src_md, engine, 1,
            &conv_internal_src_memory, &conv_reorder_src, &n, net, net_args));
    const_dnnl_memory_desc_t weights_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_weights_md, 0);
    CHECK(prepare_reorder(&conv_user_weights_memory, weights_md, engine, 1,
            &conv_internal_weights_memory, &conv_reorder_weights, &n, net,
            net_args));
    // prepare_reorder() leaves the internal memory NULL when no reorder was
    // needed; in that case the convolution reads the user memory directly.
    dnnl_memory_t conv_src_memory = conv_internal_src_memory
            ? conv_internal_src_memory
            : conv_user_src_memory;
    dnnl_memory_t conv_weights_memory = conv_internal_weights_memory
            ? conv_internal_weights_memory
            : conv_user_weights_memory;
    // finally create a convolution primitive
    dnnl_primitive_t conv;
    CHECK(dnnl_primitive_create(&conv, conv_pd));
    net[n] = conv;
    prepare_arg_node(&net_args[n], 4);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, conv_src_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_WEIGHTS, conv_weights_memory);
    set_arg(&net_args[n].args[2], DNNL_ARG_BIAS, conv_user_bias_memory);
    set_arg(&net_args[n].args[3], DNNL_ARG_DST, conv_internal_dst_memory);
    n++;
    // AlexNet: relu
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
    float negative_slope = 0.0f;
    // create relu memory descriptor on dst memory descriptor
    // from previous primitive
    const_dnnl_memory_desc_t relu_src_md
            = dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
    const_dnnl_memory_desc_t relu_dst_md = relu_src_md;
    // create a relu
    dnnl_primitive_desc_t relu_pd;
    CHECK(dnnl_eltwise_forward_primitive_desc_create(&relu_pd, engine,
            dnnl_forward, dnnl_eltwise_relu, relu_src_md, relu_dst_md,
            negative_slope, 0, NULL));
    dnnl_memory_t relu_dst_memory;
    CHECK(dnnl_memory_create(
            &relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));
    // finally create a relu primitive
    dnnl_primitive_t relu;
    CHECK(dnnl_primitive_create(&relu, relu_pd));
    net[n] = relu;
    prepare_arg_node(&net_args[n], 2);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, conv_internal_dst_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_DST, relu_dst_memory);
    n++;
    // AlexNet: lrn
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, CONV_OH, CONV_OW}
    // local size: 5
    // alpha: 0.0001
    // beta: 0.75
    // k: 1.0
    uint32_t local_size = 5;
    float alpha = 0.0001f;
    float beta = 0.75f;
    float k = 1.0f;
    // create lrn src memory descriptor using dst memory descriptor
    //  from previous primitive
    const_dnnl_memory_desc_t lrn_src_md = relu_dst_md;
    const_dnnl_memory_desc_t lrn_dst_md = lrn_src_md;
    // create a lrn primitive descriptor
    dnnl_primitive_desc_t lrn_pd;
    CHECK(dnnl_lrn_forward_primitive_desc_create(&lrn_pd, engine, dnnl_forward,
            dnnl_lrn_across_channels, lrn_src_md, lrn_dst_md, local_size, alpha,
            beta, k, NULL));
    // create memory for lrn dst and workspace
    dnnl_memory_t lrn_dst_memory;
    CHECK(dnnl_memory_create(
            &lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));
    dnnl_memory_t lrn_ws_memory;
    const_dnnl_memory_desc_t lrn_ws_md
            = dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));
    // finally create a lrn primitive
    dnnl_primitive_t lrn;
    CHECK(dnnl_primitive_create(&lrn, lrn_pd));
    net[n] = lrn;
    prepare_arg_node(&net_args[n], 3);
    set_arg(&net_args[n].args[0], DNNL_ARG_SRC, relu_dst_memory);
    set_arg(&net_args[n].args[1], DNNL_ARG_DST, lrn_dst_memory);
    set_arg(&net_args[n].args[2], DNNL_ARG_WORKSPACE, lrn_ws_memory);
    n++;
    // AlexNet: pool
    // {BATCH, OC, CONV_OH, CONV_OW} -> {BATCH, OC, POOL_OH, POOL_OW}
    // kernel: {3, 3}
    // strides: {POOL_STRIDE, POOL_STRIDE}
    // dilation: {0, 0}
    dnnl_dims_t pool_dst_sizes;
    for (int i = 0; i < ndims; i++)
        pool_dst_sizes[i] = net_dst_sizes[i];
    dnnl_dims_t pool_kernel = {3, 3};
    dnnl_dims_t pool_strides = {POOL_STRIDE, POOL_STRIDE};
    dnnl_dims_t pool_padding = {POOL_PAD, POOL_PAD};
    dnnl_dims_t pool_dilation = {0, 0};
    // create pooling memory descriptor on dst descriptor
    //  from previous primitive
    const_dnnl_memory_desc_t pool_src_md = lrn_dst_md;
    // create descriptors for dst pooling data
    dnnl_memory_desc_t pool_dst_any_md;
    CHECK(dnnl_memory_desc_create_with_tag(&pool_dst_any_md, ndims,
            pool_dst_sizes, dnnl_f32, dnnl_format_tag_any));
    // create memory for user data
    dnnl_memory_t pool_user_dst_memory;
    init_data_memory(ndims, pool_dst_sizes, dnnl_nchw, engine, net_dst,
            &pool_user_dst_memory);
    // create a pooling
    dnnl_primitive_desc_t pool_pd;
    CHECK(dnnl_pooling_forward_primitive_desc_create(&pool_pd, engine,
            dnnl_forward, dnnl_pooling_max, pool_src_md, pool_dst_any_md,
            pool_strides, pool_kernel, pool_dilation, pool_padding,
            pool_padding, NULL));
    // create memory for workspace
    dnnl_memory_t pool_ws_memory;
    const_dnnl_memory_desc_t pool_ws_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_workspace_md, 0);
    CHECK(dnnl_memory_create(
            &pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));
    dnnl_memory_t pool_dst_memory;
    // create reorder primitives between user data and pooling dsts
    // if required
    dnnl_primitive_t pool_reorder_dst;
    dnnl_memory_t pool_internal_dst_memory;
    const_dnnl_memory_desc_t pool_dst_md
            = dnnl_primitive_desc_query_md(pool_pd, dnnl_query_dst_md, 0);
    // Pooling must execute before the optional dst reorder, but the reorder
    // has to be prepared first because it decides which memory pooling writes
    // to. Reserve pooling's slot now; prepare_reorder() then places the
    // reorder (if any) in the slot right after it and advances n past it.
    const uint32_t pool_idx = n++;
    CHECK(prepare_reorder(&pool_user_dst_memory, pool_dst_md, engine, 0,
            &pool_internal_dst_memory, &pool_reorder_dst, &n, net, net_args));
    pool_dst_memory = pool_internal_dst_memory ? pool_internal_dst_memory
                                               : pool_user_dst_memory;
    // finally create a pooling primitive in the reserved slot
    dnnl_primitive_t pool;
    CHECK(dnnl_primitive_create(&pool, pool_pd));
    net[pool_idx] = pool;
    prepare_arg_node(&net_args[pool_idx], 3);
    set_arg(&net_args[pool_idx].args[0], DNNL_ARG_SRC, lrn_dst_memory);
    set_arg(&net_args[pool_idx].args[1], DNNL_ARG_DST, pool_dst_memory);
    set_arg(&net_args[pool_idx].args[2], DNNL_ARG_WORKSPACE, pool_ws_memory);
    // execute the queued primitives in order on one stream
    dnnl_stream_t stream;
    CHECK(dnnl_stream_create(&stream, engine, dnnl_stream_default_flags));
    for (uint32_t i = 0; i < n; ++i) {
        CHECK(dnnl_primitive_execute(
                net[i], stream, net_args[i].nargs, net_args[i].args));
    }
    CHECK(dnnl_stream_wait(stream));
    // clean-up
    for (uint32_t i = 0; i < n; ++i)
        free_arg_node(&net_args[i]);
    CHECK(dnnl_primitive_desc_destroy(conv_pd));
    CHECK(dnnl_primitive_desc_destroy(relu_pd));
    CHECK(dnnl_primitive_desc_destroy(lrn_pd));
    CHECK(dnnl_primitive_desc_destroy(pool_pd));
    dnnl_stream_destroy(stream);
    free(net_src);
    free(net_dst);
    dnnl_memory_desc_destroy(conv_src_md);
    dnnl_memory_desc_destroy(conv_weights_md);
    dnnl_memory_desc_destroy(conv_bias_md);
    dnnl_memory_desc_destroy(conv_dst_md);
    dnnl_memory_desc_destroy(pool_dst_any_md);
    dnnl_memory_destroy(conv_user_src_memory);
    dnnl_memory_destroy(conv_user_weights_memory);
    dnnl_memory_destroy(conv_user_bias_memory);
    // The internal memories and reorder primitives below may be NULL when no
    // reorder was needed; they are passed to the destroy calls regardless.
    dnnl_memory_destroy(conv_internal_src_memory);
    dnnl_memory_destroy(conv_internal_weights_memory);
    dnnl_memory_destroy(conv_internal_dst_memory);
    dnnl_primitive_destroy(conv_reorder_src);
    dnnl_primitive_destroy(conv_reorder_weights);
    dnnl_primitive_destroy(conv);
    free(conv_weights);
    free(conv_bias);
    dnnl_memory_destroy(relu_dst_memory);
    dnnl_primitive_destroy(relu);
    dnnl_memory_destroy(lrn_ws_memory);
    dnnl_memory_destroy(lrn_dst_memory);
    dnnl_primitive_destroy(lrn);
    dnnl_memory_destroy(pool_user_dst_memory);
    dnnl_memory_destroy(pool_internal_dst_memory);
    dnnl_memory_destroy(pool_ws_memory);
    dnnl_primitive_destroy(pool_reorder_dst);
    dnnl_primitive_destroy(pool);
    dnnl_engine_destroy(engine);
}
// Entry point: selects the engine kind from the command-line arguments, runs
// the network once on it, and reports success on standard output.
int main(int argc, char **argv) {
    const dnnl_engine_kind_t kind = parse_engine_kind(argc, argv);

    simple_net(kind);

    // Reaching this line means simple_net() returned normally.
    printf("Example passed on %s.\n", engine_kind2str_upper(kind));
    return 0;
}