Migrate Deformable Convolution Networks from CUDA* to SYCL* with Intel® Extension for PyTorch*

ID 812331
Updated 5/6/2024
Version 1.0
Public

Contributors

From Intel Corporation: Huiyan Cao, Jie Lin, Jing Xu, Ying Hu, and Chao Yu

Introduction

This guide describes how to migrate Deformable Convolutional Networks V2 (DCNv2)[1], which CenterNet[2] uses, from CUDA* to SYCL* with the Intel® DPC++ Compatibility Tool[3], and how to run CenterNet on an Intel GPU with Intel® Extension for PyTorch*[4].

CenterNet & Deformable Convolutional Networks

The CenterNet used here models an object as a single point: the center point of its bounding box. The detector uses keypoint estimation to find center points and regresses all other object properties, such as size, 3D location, orientation, and even pose. According to its authors, this center point-based approach is end-to-end differentiable, simpler, faster, and more accurate than corresponding bounding box-based detectors.
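The center-point idea can be made concrete with a short example of reading peaks from a heatmap. The pure-Python sketch below assumes a single-class heatmap given as a list of lists of scores. It keeps a cell when the cell is the maximum of its 3x3 neighborhood, which is the max-pooling style peak extraction described in the CenterNet paper, and then returns the top-k peaks. The function name `extract_centers` is hypothetical and is not part of the CenterNet code base.

```python
from typing import List, Tuple


def extract_centers(heatmap: List[List[float]], k: int = 100,
                    threshold: float = 0.0) -> List[Tuple[float, int, int]]:
    """Return up to k (score, row, col) peaks from a 2-D heatmap.

    A cell counts as a peak when it is >= every neighbor in its 3x3
    window, which mirrors a 3x3 max-pooling based non-maximum suppression.
    """
    rows, cols = len(heatmap), len(heatmap[0])
    peaks = []
    for y in range(rows):
        for x in range(cols):
            score = heatmap[y][x]
            if score <= threshold:
                continue
            # Compare against the 3x3 neighborhood, clipped at the borders.
            window = [heatmap[j][i]
                      for j in range(max(0, y - 1), min(rows, y + 2))
                      for i in range(max(0, x - 1), min(cols, x + 2))]
            if score >= max(window):
                peaks.append((score, y, x))
    # Highest-scoring centers first, keeping at most k of them.
    peaks.sort(key=lambda p: p[0], reverse=True)
    return peaks[:k]


if __name__ == "__main__":
    hm = [
        [0.1, 0.2, 0.1, 0.0, 0.0],
        [0.2, 0.9, 0.3, 0.0, 0.1],
        [0.1, 0.3, 0.2, 0.4, 0.7],
        [0.0, 0.0, 0.1, 0.3, 0.5],
    ]
    print(extract_centers(hm, k=2))
```

In the real model, each peak would then index into the regressed size and offset maps to produce a box. That decoding step is omitted here.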

In this CenterNet implementation, the ResNet and DLA backbones are modified to use deformable convolution layers, and some of the operators these layers rely on are implemented in CUDA*.
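The operator being migrated is DCNv2's modulated deformable convolution. For each kernel tap k, it samples the input at the regular grid position plus a learned offset (dy_k, dx_k) using bilinear interpolation. It then scales that sample by a learned mask value m_k (a sigmoid output in [0, 1] in DCNv2) before taking the weighted sum: y(p) = sum_k w_k * x(p + p_k + dp_k) * m_k. The pure-Python sketch below illustrates this for a single channel, a single output position, and no deformable groups. It is a readability aid under those simplifying assumptions, not the CUDA* or SYCL* kernel, and both function names are hypothetical.

```python
import math
from typing import List, Tuple


def bilinear(img: List[List[float]], y: float, x: float) -> float:
    """Bilinearly sample img at (y, x); neighbors outside the image read as 0."""
    h, w = len(img), len(img[0])
    y0, x0 = math.floor(y), math.floor(x)
    value = 0.0
    # Accumulate the four integer neighbors weighted by their proximity.
    for yy in (y0, y0 + 1):
        for xx in (x0, x0 + 1):
            if 0 <= yy < h and 0 <= xx < w:
                wy = 1.0 - abs(y - yy)
                wx = 1.0 - abs(x - xx)
                value += wy * wx * img[yy][xx]
    return value


def modulated_deform_conv_at(img: List[List[float]], weight: List[List[float]],
                             offsets: List[Tuple[float, float]], mask: List[float],
                             py: int, px: int, stride: int = 1, pad: int = 0,
                             dilation: int = 1) -> float:
    """Compute one output value of a single-channel modulated deformable conv.

    weight:  kh x kw kernel
    offsets: kh*kw (dy, dx) pairs, one per kernel tap, in row-major order
    mask:    kh*kw modulation scalars, one per kernel tap
    """
    kh, kw = len(weight), len(weight[0])
    out = 0.0
    for i in range(kh):
        for j in range(kw):
            k = i * kw + j
            dy, dx = offsets[k]
            # Regular sampling location plus the learned offset for this tap.
            y = py * stride - pad + i * dilation + dy
            x = px * stride - pad + j * dilation + dx
            out += weight[i][j] * bilinear(img, y, x) * mask[k]
    return out
```

With all offsets set to zero and all mask values set to one, this reduces to an ordinary convolution. That property makes a convenient sanity check when comparing the CUDA* and SYCL* outputs.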

SYCL*

SYCL* offers a powerful and flexible programming model for implementing customized algorithms as extensions to AI frameworks. It simplifies the development process, provides abstraction from hardware details, and enables efficient use of heterogeneous computing resources:

  • Heterogeneous Computing: SYCL* is designed for heterogeneous computing, allowing developers to leverage the computational power of diverse accelerators. This is crucial in the field of AI, where different tasks may benefit from specialized hardware such as GPUs or FPGAs.

  • Portability: SYCL* is a portable programming model that lets developers write code that runs across a variety of devices without modification. This portability is particularly valuable when working with the different hardware architectures and accelerators commonly found in AI applications.

  • Parallelism and Performance: SYCL* lets developers express parallelism and optimizations directly in the code, which is essential for performance in AI workloads. This fine-grained control over execution on different devices is valuable when tailoring algorithms to specific AI tasks.

  • Standardization and Ecosystem: SYCL* is an open standard maintained by the Khronos Group, fostering a standardized programming model for heterogeneous systems. This standardization can lead to a more robust and interoperable ecosystem, with tools and libraries that support SYCL* development.

Intel® DPC++ Compatibility Tool

The Intel DPC++ Compatibility Tool assists in migrating your existing CUDA* code to SYCL* code. Here are several ways the tool is beneficial for this migration process:

  • Code Analysis and Syntax Mapping: The Compatibility Tool analyzes existing CUDA* codebases and identifies potential areas of concern or incompatibility with SYCL*. Although CUDA* and SYCL* have different syntaxes, many concepts are similar, so the tool can automatically map CUDA* constructs to their SYCL* equivalents. This reduces the time needed to migrate a CUDA* program to SYCL*.

  • API Compatibility: CUDA* and SYCL* have distinct API sets for managing devices, memory, and other runtime features. The Compatibility Tool identifies the differences between these APIs and, wherever possible, automatically generates equivalent SYCL* or Intel® oneAPI library code. For CUDA* APIs that it cannot migrate automatically, it inserts inline comments into the generated SYCL* code that show developers where to update the code manually.
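A short script can list every diagnostic the tool left in the migrated tree, which makes it easier to find the places that need manual work. The sketch below assumes that the diagnostics appear as comments containing an identifier made of "DPCT" followed by four digits (for example, DPCT1007). Confirm this against the tool's diagnostics reference for your version. The script and its function name are illustrative and are not part of the tool.

```python
import os
import re
import sys
from typing import Iterator, Tuple

# Assumed diagnostic pattern: "DPCT" followed by four digits, e.g. DPCT1007.
DPCT_ID = re.compile(r"\bDPCT\d{4}\b")
# Migrated sources use *.dp.cpp, which also ends with ".cpp".
SOURCE_SUFFIXES = (".cpp", ".hpp", ".h")


def find_dpct_warnings(root: str) -> Iterator[Tuple[str, int, str]]:
    """Yield (path, line_number, diagnostic_id) for each flagged line under root."""
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in sorted(filenames):
            if not name.endswith(SOURCE_SUFFIXES):
                continue
            path = os.path.join(dirpath, name)
            with open(path, encoding="utf-8", errors="replace") as f:
                for lineno, line in enumerate(f, start=1):
                    for match in DPCT_ID.finditer(line):
                        yield path, lineno, match.group(0)


if __name__ == "__main__":
    # Default to the --out-root folder used by the dpct command in this guide.
    root = sys.argv[1] if len(sys.argv) > 1 else "./sycl"
    for path, lineno, diag in find_dpct_warnings(root):
        print(f"{path}:{lineno}: {diag}")
```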

By providing these capabilities, the Intel DPC++ Compatibility Tool simplifies the process of migrating CUDA* code to SYCL*, making SYCL* more accessible for developers. The resulting SYCL* code lets developers use diverse hardware platforms for their parallel computing workloads.

Intel® Extension for PyTorch*

Intel® Extension for PyTorch* extends PyTorch* with the latest performance optimizations for Intel hardware. The optimizations take advantage of Intel® Advanced Vector Extensions 512 (Intel® AVX-512) Vector Neural Network Instructions (VNNI) and Intel® Advanced Matrix Extensions (Intel® AMX) on Intel CPUs, as well as Intel® Xe Matrix Extensions (Intel® XMX) AI engines on Intel discrete GPUs. Intel® Extension for PyTorch* provides GPU acceleration for Intel discrete GPUs through the PyTorch* xpu device.

Figure 1: Intel® Extension for PyTorch* Structure

A C++ extension is a PyTorch mechanism for creating custom, highly efficient PyTorch operators that are defined out-of-source, that is, separately from the PyTorch backend (for details, see https://pytorch.org/tutorials/advanced/cpp_extension.html). Building on this mechanism, Intel® Extension for PyTorch* lets you create PyTorch operators with custom DPC++ kernels that run on the xpu device.

Migration

The migration solution for Deformable Convolution Networks consists of four steps, each explained in the sections below:

  1. Migrate CUDA* code of deformable convolution layers to SYCL* code using the Intel® DPC++ Compatibility Tool.
  2. Build SYCL* code of deformable convolution layers using DpcppBuildExtension.
  3. Use xpu device in CenterNet with Intel® Extension for PyTorch*.
  4. Align memory format between Python operators and customized SYCL* operators.

Figure 2: CUDA* to SYCL* Migration Solution

Migrate CUDA* Code

The original DCNv2 code in CenterNet is incompatible with the latest version of PyTorch, so we use DCNv2_latest instead.

Figure 3: Migrate CUDA* code

  • Build DCNv2_latest, generate a build log, and migrate the code with the Compatibility Tool
# create conda environment
conda create -n dcnv2_cuda python=3.9 -y
conda activate dcnv2_cuda

# install PyTorch according to your CUDA version, for example:
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia  

# clone repo
git clone https://github.com/lucasjinreal/DCNv2_latest

# build and generate log file
cd DCNv2_latest
python3 setup.py build develop 2>&1 | tee build_log.txt

# source environment after Intel® DPC++ Compatibility Tool is installed
source <oneapi root path>/dpcpp-ct/latest/env/vars.sh

# convert CUDA code to DPC++ code
intercept-build --parse-build-log build_log.txt
dpct -p=./ --in-root=./ --out-root=./sycl --keep-original-code

The generated DPC++ code is output to the sycl/src/cuda folder:

  DCNv2_latest
  |--- sycl
        |--- MainSourceFiles.yaml
        |--- src
              |--- cpu
              |--- cuda
              |     |---*.h
              |     |---*.dp.cpp
              |--- *.h
              |--- *.cpp

  • Convert CUDA* stream to SYCL* queue

    During this conversion, we found that c10::cuda::getCurrentCUDAStream() was not converted automatically and needs to be converted manually, as shown here:

    • Create utils.h in sycl/src/cuda

      #ifndef UTILS_H
      #define UTILS_H
      
      #include <xpu/Macros.h>
      #include <xpu/Stream.h>
      #include <c10/core/Device.h>
      #include <c10/core/impl/VirtualGuardImpl.h>
      
      inline sycl::queue& getCurrentXPUQueue() 
      {
        auto device_type = c10::DeviceType::XPU; 
        c10::impl::VirtualGuardImpl impl(device_type);
        c10::Stream dpcpp_stream = impl.getStream(impl.getDevice());
        return xpu::get_queue_from_stream(dpcpp_stream);
      }
      
      #endif
      
    • Replace all c10::cuda::getCurrentCUDAStream() calls with &getCurrentXPUQueue() in sycl/src/cuda, and include utils.h in each modified file

      For example, in sycl/src/cuda/dcn_v2_cuda.dp.cpp:

      // #include <ATen/cuda/CUDAContext.h>
      #include "utils.h"
      
      // modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(), 
      modulated_deformable_im2col_cuda(&getCurrentXPUQueue(),
      
  • Check every place marked "You need to rewrite this code" and comment out all CUDA*-related code:

    // comment out all C10_CUDA_CHECK statements
    // C10_CUDA_CHECK(0);
    
    // comment out all statements that call xxx.is_cuda()
    // ... xxx.is_cuda() ...;
    
  • Merge DCNv2 SYCL* code into CenterNet

    • Clone CenterNet repo

      git clone https://github.com/xingyizhou/CenterNet
      
    • Remove CUDA* code of DCNv2

      cd CenterNet/src/lib/models/networks/DCNv2/src
      rm -rf ./*
      
    • Copy SYCL* code of DCNv2

      cp -r <DCNv2_latest Root Path>/sycl/src/cuda sycl
      cp <DCNv2_latest Root Path>/sycl/src/*.h .
      cp <DCNv2_latest Root Path>/sycl/src/*.cpp .
      
    • Modify CenterNet/src/lib/models/networks/DCNv2/src/dcn_v2.h as below:

      #pragma once
      
      // #include "cpu/vision.h"
      
      // #ifdef WITH_CUDA
      // #include "cuda/vision.h"
      // #endif
      
      #include "sycl/vision.h"
      
      at::Tensor
      dcn_v2_forward(const at::Tensor &input,
                    const at::Tensor &weight,
                    const at::Tensor &bias,
                    const at::Tensor &offset,
                    const at::Tensor &mask,
                    const int kernel_h,
                    const int kernel_w,
                    const int stride_h,
                    const int stride_w,
                    const int pad_h,
                    const int pad_w,
                    const int dilation_h,
                    const int dilation_w,
                    const int deformable_group)
      {
          return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
                                      kernel_h, kernel_w,
                                      stride_h, stride_w,
                                      pad_h, pad_w,
                                      dilation_h, dilation_w,
                                      deformable_group);
      }
      
      std::vector<at::Tensor>
      dcn_v2_backward(const at::Tensor &input,
                      const at::Tensor &weight,
                      const at::Tensor &bias,
                      const at::Tensor &offset,
                      const at::Tensor &mask,
                      const at::Tensor &grad_output,
                      int kernel_h, int kernel_w,
                      int stride_h, int stride_w,
                      int pad_h, int pad_w,
                      int dilation_h, int dilation_w,
                      int deformable_group)
      {
          return dcn_v2_cuda_backward(input,
                                      weight,
                                      bias,
                                      offset,
                                      mask,
                                      grad_output,
                                      kernel_h, kernel_w,
                                      stride_h, stride_w,
                                      pad_h, pad_w,
                                      dilation_h, dilation_w,
                                      deformable_group);
      }
      
      std::tuple<at::Tensor, at::Tensor>
      dcn_v2_psroi_pooling_forward(const at::Tensor &input,
                                  const at::Tensor &bbox,
                                  const at::Tensor &trans,
                                  const int no_trans,
                                  const float spatial_scale,
                                  const int output_dim,
                                  const int group_size,
                                  const int pooled_size,
                                  const int part_size,
                                  const int sample_per_part,
                                  const float trans_std)
      {
          return dcn_v2_psroi_pooling_cuda_forward(input,
                                                      bbox,
                                                      trans,
                                                      no_trans,
                                                      spatial_scale,
                                                      output_dim,
                                                      group_size,
                                                      pooled_size,
                                                      part_size,
                                                      sample_per_part,
                                                      trans_std);
      }
      
      std::tuple<at::Tensor, at::Tensor>
      dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
                                    const at::Tensor &input,
                                    const at::Tensor &bbox,
                                    const at::Tensor &trans,
                                    const at::Tensor &top_count,
                                    const int no_trans,
                                    const float spatial_scale,
                                    const int output_dim,
                                    const int group_size,
                                    const int pooled_size,
                                    const int part_size,
                                    const int sample_per_part,
                                    const float trans_std)
      {
          return dcn_v2_psroi_pooling_cuda_backward(out_grad,
                                                      input,
                                                      bbox,
                                                      trans,
                                                      top_count,
                                                      no_trans,
                                                      spatial_scale,
                                                      output_dim,
                                                      group_size,
                                                      pooled_size,
                                                      part_size,
                                                      sample_per_part,
                                                      trans_std);
      }
      
    • Replace dcn_v2.py in CenterNet/src/lib/models/networks/DCNv2 with the version from DCNv2_latest

      cd CenterNet/src/lib/models/networks/DCNv2
      cp <DCNv2_latest root path>/dcn_v2.py .
      
      # after copying, change the import statement in dcn_v2.py as follows
      # import _ext as _backend
      from ._ext import dcn_v2 as _backend
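
The manual replacement of c10::cuda::getCurrentCUDAStream() described in these steps can also be scripted before the sources are copied into CenterNet. The sketch below rewrites every *.dp.cpp file in one directory. It replaces the call with &getCurrentXPUQueue() and, like the manual edit to dcn_v2_cuda.dp.cpp, comments out the ATen CUDA context include and adds #include "utils.h" in its place. The function name is hypothetical. The script assumes the call always appears with empty parentheses and does not touch header files, so review the resulting diff before building.

```python
import os
import re
from typing import List

CUDA_STREAM_CALL = "c10::cuda::getCurrentCUDAStream()"
XPU_QUEUE_CALL = "&getCurrentXPUQueue()"
UTILS_INCLUDE = '#include "utils.h"'
# Matches an active (not commented-out) include of the CUDA context header.
CUDA_CONTEXT_RE = re.compile(r"^#include <ATen/cuda/CUDAContext\.h>[ \t]*$", re.MULTILINE)


def migrate_stream_calls(directory: str) -> List[str]:
    """Rewrite CUDA stream calls in the *.dp.cpp files of directory.

    Returns the paths of the files that were modified.
    """
    changed = []
    for name in sorted(os.listdir(directory)):
        if not name.endswith(".dp.cpp"):
            continue
        path = os.path.join(directory, name)
        with open(path, encoding="utf-8") as f:
            text = f.read()
        if CUDA_STREAM_CALL not in text:
            continue
        text = text.replace(CUDA_STREAM_CALL, XPU_QUEUE_CALL)
        if UTILS_INCLUDE not in text:
            if CUDA_CONTEXT_RE.search(text):
                # Mirror the manual edit: comment out the CUDA header, add utils.h.
                text = CUDA_CONTEXT_RE.sub(
                    lambda m: "// " + m.group(0) + "\n" + UTILS_INCLUDE, text, count=1)
            else:
                text = UTILS_INCLUDE + "\n" + text
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
        changed.append(path)
    return changed


if __name__ == "__main__":
    import sys
    target = sys.argv[1] if len(sys.argv) > 1 else "sycl/src/cuda"
    for path in migrate_stream_calls(target):
        print("updated", path)
```

Running the script a second time changes nothing, because the CUDA stream call no longer appears in the files.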
      

Build DPC++ Extension for DCNv2

Figure 4: Build DPC++ Extension

  • Install Intel® oneAPI Base Toolkit

    • At a minimum, the Intel® oneAPI DPC++/C++ Compiler, Intel® DPC++ Compatibility Tool, Intel® oneAPI Math Kernel Library, and Intel® oneAPI Deep Neural Network Library are required.

    • Source the environment variables after installation:

      source <oneapi root path>/setvars.sh
      
  • Create python environment

    • Create conda environment

      conda create -n centernet_xpu python=3.9 -y
      conda activate centernet_xpu
      
    • Follow the installation guide for Intel® Extension for PyTorch*[4] to install its latest xpu version.

    • Install other dependencies

cd CenterNet
pip install -r requirements.txt
  • Create CenterNet/src/lib/models/networks/DCNv2/setup.py

import os
import glob
from setuptools import find_packages, setup
import intel_extension_for_pytorch as ipex
from intel_extension_for_pytorch.xpu.cpp_extension import DPCPPExtension, DpcppBuildExtension


def get_extensions():
    this_dir = os.path.dirname(os.path.abspath(__file__))
    extensions_dir = os.path.join(this_dir, "src")

    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
    source_sycl = glob.glob(os.path.join(extensions_dir, "sycl", "*.cpp"))
    sources = main_file + source_sycl
    sources = [os.path.join(extensions_dir, s) for s in sources]

    extension = DPCPPExtension

    ext_modules = [
        extension(
            name='_ext.dcn_v2',
            sources=sources,
            include_dirs=ipex.xpu.cpp_extension.include_paths(),
        )
    ]

    return ext_modules

setup(
    name="DCNv2_latest",
    version="0.1",
    author="lucasjinreal",
    url="https://github.com/lucasjinreal/DCNv2_latest",
    description="deformable convolutional networks v2",
    packages=find_packages(exclude=("configs", "tests")),
    ext_modules=get_extensions(),
    cmdclass={"build_ext": DpcppBuildExtension},
)
  • Build extension
mkdir _ext
python3 setup.py build develop
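
After the build finishes, a quick import check can confirm that the extension exposes the operators defined in dcn_v2.h. The sketch below assumes three things. First, it runs from the DCNv2 directory, so the module is importable as _ext.dcn_v2 (the name given in setup.py). Second, the pybind11 bindings use the same names as the C++ wrapper functions in dcn_v2.h. Third, importing torch and intel_extension_for_pytorch first lets the extension's shared-library dependencies resolve. Adjust the list if the bindings in your copy of the sources differ. The helper names are hypothetical.

```python
import importlib
from typing import List, Sequence

# Operator names taken from the wrappers in DCNv2/src/dcn_v2.h; assumed to be
# the names registered in the pybind11 module.
EXPECTED_OPS = (
    "dcn_v2_forward",
    "dcn_v2_backward",
    "dcn_v2_psroi_pooling_forward",
    "dcn_v2_psroi_pooling_backward",
)


def missing_ops(module: object, expected: Sequence[str] = EXPECTED_OPS) -> List[str]:
    """Return the expected operator names that module does not expose as callables."""
    return [name for name in expected if not callable(getattr(module, name, None))]


def check_extension(module_name: str = "_ext.dcn_v2") -> List[str]:
    # Load torch and Intel Extension for PyTorch before the compiled extension.
    import torch  # noqa: F401
    import intel_extension_for_pytorch  # noqa: F401
    module = importlib.import_module(module_name)
    return missing_ops(module)


if __name__ == "__main__":
    missing = check_extension()
    print("missing operators:", missing if missing else "none")
```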

Use 'xpu' device in CenterNet with Intel® Extension for PyTorch*

  • Add the --xpus parameter in CenterNet/src/lib/opts.py
class opts(object):
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        ...
        self.parser.add_argument('--xpus', default='-1', 
                                help='-1 for no xpu, use comma for multiple xpus')
        ...

    def parse(self, args=''):
        ...
        opt.xpus_str = opt.xpus
        opt.xpus = [int(xpu) for xpu in opt.xpus.split(',')]
        opt.xpus = [i for i in range(len(opt.xpus))] if opt.xpus[0] >=0 else [-1] 
        ...
        if opt.debug > 0:
            ...
            opt.xpus = [opt.xpus[0]]
            ...
        ...
  • Set opt.device accordingly in CenterNet/src/lib/detectors/base_detector.py
class BaseDetector(object):
    def __init__(self, opt):
        if opt.xpus[0] >=0:
            opt.device = torch.device('xpu')
        elif opt.gpus[0] >= 0:
            opt.device = torch.device('cuda')
        else:
            opt.device = torch.device('cpu')
  • Replace all calls to torch.cuda.synchronize() as shown here, and add import intel_extension_for_pytorch as ipex in each modified Python file.
import intel_extension_for_pytorch as ipex
...
# torch.cuda.synchronize()
if self.opt.device.type == 'cuda':
    torch.cuda.synchronize()
elif self.opt.device.type == 'xpu':
    torch.xpu.synchronize()
...
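
The option parsing and device selection above can be checked in isolation. The pure-Python sketch below mirrors that logic without argparse or torch. `parse_ids` reproduces the remapping in opts.parse: a non-negative list becomes 0..n-1, and anything else becomes [-1]. `pick_device_type` reproduces the precedence in BaseDetector: xpu first, then cuda, then cpu. Both function names are hypothetical.

```python
from typing import List


def parse_ids(spec: str) -> List[int]:
    """Mirror opts.parse: '-1' -> [-1]; '2,3' -> [0, 1] (remapped to 0..n-1)."""
    ids = [int(s) for s in spec.split(",")]
    return list(range(len(ids))) if ids[0] >= 0 else [-1]


def pick_device_type(xpus: List[int], gpus: List[int]) -> str:
    """Mirror BaseDetector: prefer xpu, then cuda, otherwise cpu."""
    if xpus[0] >= 0:
        return "xpu"
    if gpus[0] >= 0:
        return "cuda"
    return "cpu"


if __name__ == "__main__":
    # Equivalent of running demo.py with --xpus 0 and the default --gpus 0.
    print(pick_device_type(parse_ids("0"), parse_ids("0")))
```

In BaseDetector, the returned string is what gets passed to torch.device(...). Because --xpus is checked first, passing --xpus 0 selects the xpu device even when --gpus is left at its default.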

Align Memory Format

By default, the oneDNN library uses a blocked memory layout[5] for weights to achieve better performance on Intel® Data Center GPU Flex Series and Intel® Arc™ GPUs. On Intel® Data Center GPU Max Series, by contrast, the plain data format[6] is used by default, which is NCHW in PyTorch. To get correct output on Intel® Data Center GPU Flex Series and Intel® Arc™ GPUs, and better performance on Intel® Data Center GPU Max Series, the memory format must be aligned between the layers implemented by oneDNN and the custom SYCL* operators.
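The mismatch comes from how the same logical NCHW tensor can be laid out in memory. The sketch below computes element strides for a contiguous (NCHW) tensor and for a channels_last (NHWC) tensor. The strides are reported in NCHW dimension order, as PyTorch's Tensor.stride() does. A custom kernel that computes offsets as n*C*H*W + c*H*W + h*W + w silently reads the wrong elements when it receives a channels_last tensor, which is why the tensors passed to the SYCL* kernels are converted to the contiguous format. The helper names are illustrative.

```python
from typing import Sequence, Tuple


def contiguous_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (NCHW) strides: the last dimension varies fastest."""
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def channels_last_strides(shape: Sequence[int]) -> Tuple[int, ...]:
    """Strides of an NHWC layout, reported in (N, C, H, W) order."""
    n, c, h, w = shape
    return (h * w * c, 1, w * c, c)


def flat_offset(index: Sequence[int], strides: Sequence[int]) -> int:
    """Memory offset of a logical (n, c, h, w) index for the given strides."""
    return sum(i * s for i, s in zip(index, strides))


if __name__ == "__main__":
    shape = (2, 3, 4, 5)
    print(contiguous_strides(shape))     # (60, 20, 5, 1)
    print(channels_last_strides(shape))  # (60, 1, 15, 3)
    # The same logical element lives at different memory offsets.
    idx = (0, 1, 2, 3)
    print(flat_offset(idx, contiguous_strides(shape)),
          flat_offset(idx, channels_last_strides(shape)))  # 33 40
```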

  • At the beginning of the process function, convert the input images and the model to the NHWC (channels_last) memory format, as shown here for src/lib/detectors/ctdet.py:
def process(self, images, return_time=False):
  images = images.to(memory_format=torch.channels_last)
  self.model = self.model.to(memory_format=torch.channels_last)
  with torch.no_grad():
    output = self.model(images)[-1]
    ...
  ...
  • When calling custom operators, convert their input tensors to the contiguous (NCHW) memory format, as shown here for modulated_deformable_im2col_cuda:
modulated_deformable_im2col_cuda(
  &getCurrentXPUQueue(),
  input.to(input.options(), true, false, at::MemoryFormat::Contiguous).data_ptr<scalar_t>(),
  offset.to(offset.options(), true, false, at::MemoryFormat::Contiguous).data_ptr<scalar_t>(),
  mask.to(mask.options(), true, false, at::MemoryFormat::Contiguous).data_ptr<scalar_t>(),
  batch, channels, height, width,
  height_out, width_out, kernel_h, kernel_w,
  pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
  deformable_group,
  columns.data_ptr<scalar_t>());
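  • Replace all .view calls with .reshape in CenterNet/src/lib/models/networks/DCNv2/src/sycl/. The .view method never copies data and raises an error when the tensor's strides are incompatible with the requested shape, as they are for channels_last tensors; .reshape returns a view when possible and a copy otherwise.

  • For example, in CenterNet/src/lib/models/networks/DCNv2/src/sycl/dcn_v2_cuda.dp.cpp: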

...

// auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w});
auto weight_flat = weight.reshape({channels_out, channels * kernel_h * kernel_w});

...

// output = at::add(output, product.view({batch, channels_out, height_out, width_out}));
output = at::add(output, product.reshape({batch, channels_out, height_out, width_out}));

...

// auto weight_flat = weight.view({channels_out, channels*kernel_h*kernel_w});
auto weight_flat = weight.reshape({channels_out, channels*kernel_h*kernel_w});

...

// auto grad_output_n_flat = grad_output_n.view({channels_out, height_out*width_out});
auto grad_output_n_flat = grad_output_n.reshape({channels_out, height_out*width_out});

...

// grad_weight = at::add(grad_weight, product.view({channels_out, channels, kernel_h, kernel_w}));
grad_weight = at::add(grad_weight, product.reshape({channels_out, channels, kernel_h, kernel_w}));

...

// auto ones_flat = ones.view({height_out*width_out});
auto ones_flat = ones.reshape({height_out*width_out});

...
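
The .view calls fail here because they never copy. A run of dimensions can only be merged into one when the dimensions are laid out back to back in memory, that is, when stride[i] == stride[i+1] * size[i+1] for every adjacent pair in the run. The pure-Python sketch below applies that rule to the weight flattening above. It simplifies PyTorch's actual rule by not special-casing size-1 dimensions, and the function name is hypothetical.

```python
from typing import Sequence


def can_merge_dims(shape: Sequence[int], strides: Sequence[int],
                   start: int, end: int) -> bool:
    """True if dims start..end (inclusive) can be merged without a copy.

    Simplified form of the view rule: each dim must step exactly over the
    whole block spanned by the next one. Size-1 dims are not special-cased.
    """
    return all(strides[i] == strides[i + 1] * shape[i + 1]
               for i in range(start, end))


if __name__ == "__main__":
    # Weight of shape (channels_out, channels, kernel_h, kernel_w).
    co, ci, kh, kw = 8, 4, 3, 3
    shape = (co, ci, kh, kw)
    nchw = (ci * kh * kw, kh * kw, kw, 1)
    nhwc = (ci * kh * kw, 1, kw * ci, ci)  # channels_last
    # Flattening (channels, kernel_h, kernel_w) into one dim, as weight_flat does:
    print(can_merge_dims(shape, nchw, 1, 3))  # True  -> view works
    print(can_merge_dims(shape, nhwc, 1, 3))  # False -> view fails, reshape copies
```

Because the model is converted to channels_last in the process function, its weights follow the second layout. Only .reshape can produce weight_flat for them.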

Run Demo

  • Download a model (for example, ctdet_coco_dla_2x) from the CenterNet model zoo to CenterNet/models
  • Run demo
cd CenterNet/src
python demo.py --load_model ../models/ctdet_coco_dla_2x.pth --demo ../images/34501842524_3c858b3080_k.jpg --xpus 0 --debug 1 ctdet
  • Result

Figure 5: Result

Conclusion

The Khronos SYCL* C++ standard is the open path for developing heterogeneous code that runs across multiple architectures. The Intel® DPC++ Compatibility Tool assists in migrating existing CUDA* code to SYCL* code, and Intel® Extension for PyTorch* lets you build the generated SYCL* code as a PyTorch extension and run it on Intel GPUs.

Acknowledgement

We would like to thank Hui Wu from Intel® Extension for PyTorch* development team, and David Kinder for the article review.

Reference

[1] J. Dai, H. Qi, Y. Xiong, Y. Li, G. Zhang, H. Hu, and Y. Wei. Deformable convolutional networks. In ICCV, 2017.
https://github.com/CharlesShang/DCNv2
https://github.com/lucasjinreal/DCNv2_latest

[2] "Objects as Points"
https://arxiv.org/abs/1904.07850
https://github.com/xingyizhou/CenterNet

[3] Intel® DPC++ Compatibility Tool
https://www.intel.com/content/www/us/en/developer/tools/oneapi/dpc-compatibility-tool.html

[4] Intel® Extension for PyTorch*
https://github.com/intel/intel-extension-for-pytorch

[5] Blocked Layout in oneDNN
https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html#blocked-layout

[6] Plain Data Format
https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html#plain-data-formats

Notices & Disclaimers

© Intel Corporation. Intel, the Intel logo, and other Intel marks are trademarks of Intel Corporation or its subsidiaries. Other names and brands may be claimed as the property of others.