Contributor: Feng Ding
Introduction
This article helps developers who want to write custom ops, debug, or otherwise work with the source code of Intel® Extension for TensorFlow*: it sorts out the code structure, points to the code worth reading, and collects related resources.
The TensorFlow* community proposed the PluggableDevice (link) architecture, a plug-in mechanism that allows devices to be registered in TensorFlow without changing TensorFlow code, so that accelerators and TensorFlow* can be integrated seamlessly. Based on it, Intel released its high-performance plugin, Intel® Extension for TensorFlow* (source code), which allows TensorFlow* code to run on Intel XPUs (GPU and CPU).
The PluggableDevice mechanism has four main components: the PluggableDevice type, custom operations and kernels, device execution and memory management, and a custom graph optimization pass. More details here.
The code structure of Intel® Extension for TensorFlow* is shown in the figure below.
PluggableDevice
TensorFlow extends its device class hierarchy with a standardized pluggable device named PluggableDevice, built on top of StreamExecutor. Any new third-party device that wants to integrate with the current TensorFlow stack only needs to implement the StreamExecutor C API.
Related Code:
itex/core/devices/xpu_device.cc SE_InitPlugin()
itex/core/devices/bfc_allocator.cc
tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc
When you import tensorflow, TensorFlow loads the library tensorflow-plugins/libitex_gpu.so or tensorflow-plugins/libitex_cpu.so and invokes SE_InitPlugin(). DEVICE_XPU_NAME is "XPU". When a session is created, xpu_create_stream_executor creates a StreamExecutor, which implements memory management (allocate/free/merge), memcpy (host-to-device, device-to-host, device-to-device), stream (queue) management, event management, timers, and so on.
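A quick way to confirm that the plugin loaded and the XPU device was registered (a minimal sketch; it assumes an ITEX package is installed alongside TensorFlow):

import tensorflow as tf  # loads tensorflow-plugins/libitex_*.so and invokes SE_InitPlugin()

# The plugin registers the "XPU" device type, so it appears here:
print(tf.config.list_physical_devices("XPU"))

# Ops placed on the XPU device run through the plugin's StreamExecutor.
with tf.device("/XPU:0"):
    x = tf.constant([1.0, 2.0]) + tf.constant([3.0, 4.0])
print(x)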
For more details, click here.
More DPC++/SYCL code, for example aligned_alloc_device and aligned_alloc_host, is defined in intel/llvm.
Pluggable Graph
TensorFlow provides a plug-in mechanism with a C API to register custom graph optimizers.
Related Code:
itex/core/graph/xpu_graph.cc TF_InitGraph() -> Optimizer_Optimize()
tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc
itex/core/utils/protobuf/config.proto
- RunRemapper
Finds fusion patterns and rewrites the graph accordingly.
For example, _ITEXFusedAddV2WithSoftmax implements the fused AddV2 + Softmax: in remapper.cc, FindAddV2WithSoftmax() matches the pattern and AddFusedAddV2WithSoftmaxNode() modifies the graph.
The test code: test/tensorflow/python/grappler/addv2_with_softmax_pattern.py
- RunAutoMixedPrecision
Following a defined sequence of algorithm steps, converts eligible nodes to FP16 or BF16 implementations and inserts Cast nodes.
For more details, click here.
- RunOneDnnGraph
The oneDNN Graph optimizer; disabled by default (see the configuration sketch after this list).
- RunOneDnnLayout
Rewrites and replaces GPU operators and fused operators.
Related Code: itex/core/graph/xpu_optimizer.cc itex/core/graph/onednn_layout/onednn_layout.cc RunOneDnnLayout() -> CheckForNodeRewrite() -> GetRewriteInfo()
- RunNativeLayout
Rewrites and replaces CPU operators and fused operators.
Related Code: itex/core/graph/xpu_optimizer.cc itex/core/graph/native_layout/native_layout.cc RunNativeLayout() -> CheckForNodeNativeFormat() -> GetNativeFormatInfo()
For more details, click here.
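The passes above can be toggled through the protobuf config listed under Related Code (itex/core/utils/protobuf/config.proto). A minimal sketch, assuming the itex Python API exposes these fields as in recent releases; field names may differ between versions:

import intel_extension_for_tensorflow as itex

# Field names follow itex/core/utils/protobuf/config.proto; treat this
# as a sketch rather than a stable API reference.
graph_options = itex.GraphOptions(
    auto_mixed_precision=itex.ON,  # RunAutoMixedPrecision
    onednn_graph=itex.OFF,         # RunOneDnnGraph (the default)
)
itex.set_config(itex.ConfigProto(graph_options=graph_options))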
Pluggable Kernel
TensorFlow provides a plug-in mechanism with a C API to register custom kernel and op implementations.
Related Code:
itex/core/kernels/xpu_kernel.cc TF_InitKernel
itex/core/utils/op_kernel.cc RegisterCPUKernels() RegisterGPUKernels()
tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc
For more details, click here.
Pluggable Profiler
TensorFlow provides a plug-in mechanism with a C API to implement and register pluggable profilers.
Related Code:
itex/core/profiler/gpu_profiler.cc TF_InitProfiler
tensorflow/core/common_runtime/pluggable_device/pluggable_device_plugin_init.cc
For more details, click here.
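With the plugin's profiler registered via TF_InitProfiler, the standard TensorFlow profiling API drives it. A minimal sketch (the XPU device string assumes an ITEX GPU build):

import tensorflow as tf

tf.profiler.experimental.start("./logdir")
with tf.device("/XPU:0"):
    x = tf.random.normal([1024, 1024])
    y = tf.matmul(x, x)  # device activity captured by the pluggable profiler
tf.profiler.experimental.stop()
# View the trace: tensorboard --logdir ./logdir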
How to Build ITEX from Source Code
Refer to https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/install/how_to_build.md
Note:
$ bazel build -c opt --config=gpu //itex/tools/pip_package:build_pip_package
Change "-c opt" to "-c dbg" to build debug information into the binary; then you can debug with gdb.
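For example, attach gdb to the Python process that loads the plugin (your_script.py is a placeholder for any script that imports TensorFlow):
$ gdb --args python your_script.py
(gdb) run
(gdb) bt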
Verbose Modes for Debugging
- ITEX Verbose
$ export ITEX_VERBOSE=1
The value can be set to 1, 2, 3, or 4; level 4 also dumps the graph.
Related Code: itex/core/graph/xpu_optimizer.cc DumpGraphDefToFile
For more details, click here.
- oneDNN Verbose
$ export ONEDNN_VERBOSE=1
oneDNN verbose mode enables tracing execution of oneDNN primitives and collection of basic statistics like execution time and primitive parameters. When verbose mode is enabled oneDNN will print out information to `stdout`.
oneDNN version is set here.
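Both switches can also be set from Python before TensorFlow is imported (a minimal sketch):

import os

# Set before importing TensorFlow so the plugin sees them at load time.
os.environ["ITEX_VERBOSE"] = "4"    # 1-4; level 4 also dumps the graph
os.environ["ONEDNN_VERBOSE"] = "1"  # trace oneDNN primitive execution

import tensorflow as tf

x = tf.random.normal([4, 4])
print(tf.nn.relu(x))  # verbose output appears on stdout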
How to Write a Custom Op
- Refer to the ResizeBilinear op
itex/core/kernels/cpu/resize_bilinear_op.cc, itex/core/kernels/gpu/resize_bilinear_op.cc, itex/core/kernels/onednn/block/resize_bilinear_op.cc, itex/core/kernels/onednn/block/resize_op.h: ResizeBilinearOp CPU and GPU implementations and REGISTER_KERNEL_BUILDER.
itex/core/ops/nn_ops.cc, itex/core/ops/onednn/onednn_nn_ops.cc: register the op's inputs, attrs, and so on (Register_ITEXResizeBilinearOp, Register_OneDnnResizeBilinearOp).
itex/core/ops/op_init.cc, itex/core/ops/op_init.h: add Register_ITEXResizeBilinearOp() or Register_OneDnnResizeBilinearOp() to RegisterOps().
itex/core/kernels/gpu/BUILD (for GPU), itex/core/kernels/cpu/BUILD (for CPU): add the kernel implementation to the build system.
itex/core/ops/onednn/BUILD, itex/core/ops/BUILD: add the op to the build system.
itex/core/graph/onednn_layout/onednn_layout.cc: map ResizeBilinear to the oneDNN GPU implementation.
itex/core/graph/native_layout/native_layout.cc: map ResizeBilinear to the oneDNN CPU implementation.
Build the whole project as described in "How to Build ITEX from Source Code" above; the kernel implementation will be compiled into libitex_gpu.so or libitex_cpu.so.
TensorFlow's standard resize_bilinear op is mapped to a oneDNN implementation in RunOneDnnLayout, or to a CPU implementation in RunNativeLayout, as described in the Pluggable Graph section above.
- Test resize_bilinear
We can call TensorFlow's standard tf.image.resize_bilinear interface, or call resize_bilinear through load_ops_library. Both reach the same underlying implementation; turn on verbose mode to see the details.

import tensorflow as tf
import numpy as np
from intel_extension_for_tensorflow.python.ops.load_ops_library import load_ops_library

tf.compat.v1.disable_eager_execution()
np.set_printoptions(precision=3)

resize_shape = (10, 10)
a = np.ones((1, 2, 2, 1), dtype=np.float32)
a[0, 0, 0, 0] = 5.0
a[0, 1, 1, 0] = 5.0
b = tf.constant(a, dtype=tf.float32)
c = tf.compat.v1.image.resize_bilinear(b, resize_shape)
d = load_ops_library.resize_bilinear(b, resize_shape)

with tf.compat.v1.Session() as sess:
    np_c = sess.run(c)
    print(np_c[0, :, :, 0])
    np_d = sess.run(d)
    print(np_d[0, :, :, 0])
- Refer to the ItexRnn op
itex/core/kernels/gpu/rnn_ops.cc, itex/core/kernels/gpu/rnn_ops.h, itex/core/kernels/gpu/rnn_ops_gpu.cc, itex/core/kernels/gpu/rnn_ops_gpu.h: define RnnOp and RnnGradOp, and register "ItexRnn" and "ItexRnnGrad" via REGISTER_KERNEL_BUILDER("ItexRnn").
itex/core/ops/rnn_ops.cc: register the ITEXRnn op's inputs, attrs, and so on.
itex/core/kernels/gpu/BUILD: add the GPU kernel implementation to the build system.
itex/core/ops/BUILD: add the op implementation to the build system.
itex/core/ops/op_init.cc, itex/core/ops/op_init.h: add Register_ITEXRnnOp() and Register_ITEXRnnGradOp() to RegisterOps().
itex/python/ops/ops_grad.py: register ItexRnnGrad.
itex/python/ops/recurrent.py: register the ItexLSTM class and its custom Python API in the itex package.
- Test itex_rnn
import tensorflow as tf
import intel_extension_for_tensorflow as itex

inputs = tf.random.normal([32, 10, 8])

lstm = itex.ops.ItexLSTM(4)
output = lstm(inputs)
print(output.shape)

lstm = itex.ops.ItexLSTM(4, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
print(whole_seq_output.shape)
print(final_memory_state.shape)
print(final_carry_state.shape)
- Refer to _ITEXFusedAddV2WithSoftmax
itex/core/kernels/gpu/softmax_op.cc: define AddV2WithSoftmaxOp and register "_ITEXFusedAddV2WithSoftmax" via REGISTER_KERNEL_BUILDER("_ITEXFusedAddV2WithSoftmax").
itex/core/kernels/gpu/softmax_op_functor.h: kernel implementation, AddV2WithSoftmaxFunctor.
itex/core/ops/nn_ops.cc: register inputs, attrs, and so on (Register_ITEXFusedAddV2WithSoftmaxOp).
itex/core/kernels/gpu/BUILD: add the GPU kernel implementation to the build system.
itex/core/ops/BUILD: add the op implementation to the build system.
itex/core/ops/op_init.cc, itex/core/ops/op_init.h: add Register_ITEXFusedAddV2WithSoftmaxOp to RegisterOps().
itex/core/graph/remapper/remapper.cc: find the AddV2 + Softmax pattern and modify the graph.
test/tensorflow/python/grappler/addv2_with_softmax_pattern.py: test case.
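To exercise the pattern outside the unit test, build a graph containing AddV2 followed by Softmax and run it on a GPU build with ITEX_VERBOSE enabled, then look for _ITEXFusedAddV2WithSoftmax in the log (a minimal sketch):

import os
os.environ["ITEX_VERBOSE"] = "1"  # set before import to see the remapper log

import tensorflow as tf

@tf.function
def add_softmax(x, bias):
    # AddV2 -> Softmax is the pattern RunRemapper fuses into
    # _ITEXFusedAddV2WithSoftmax.
    return tf.nn.softmax(x + bias)

x = tf.random.normal([2, 8])
bias = tf.random.normal([2, 8])
print(add_softmax(x, bias))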
More References
https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/guide/itex_ops.md
https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/guide/itex_fusion.md
https://github.com/tensorflow/community/blob/master/rfcs/20190814-kernel-and-op-registration.md
https://github.com/tensorflow/community/blob/master/rfcs/20210513-pluggable-profiler-for-tensorflow.md
https://github.com/intel/intel-extension-for-tensorflow/blob/main/docs/guide/aamp_tune.md
https://github.com/intel/llvm