Detect Dynamic Shapes

Learn how to find dynamic data and operations (ops) in your models, and ways to reduce the resulting recompilations for better performance.

Dynamic Shapes and How to Detect Them

Dynamic shapes refer to model behavior in which varying input data or certain operators produce output tensors whose shapes vary from step to step.

Dynamicity usually introduces recompilations, which slow down model execution. To optimize a model's speed:

  1. Identify whether the model has dynamic inputs or ops.
  2. If possible, mitigate the issues by following the steps in Handling Dynamic Shapes.

This article discusses some tools to detect dynamic inputs and ops.

Types of Dynamicity

Two main areas generate dynamic shapes:

  • Inputs: The result of varying input shapes during training, such as varying sentence lengths in language models or differing image resolutions in an image model.
  • Ops: Generated by ops whose output shape depends on the actual input data rather than only on the input shapes; in other words, ops whose output shapes cannot be inferred from their input shapes alone (see the short illustration below).
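
For instance, torch.unique is one such op; its output size depends on how many duplicate values the input contains. A minimal illustration (runs on CPU):

import torch

# Both inputs have shape (3,), but the output shape depends on the values
print(torch.unique(torch.tensor([1, 1, 2])).shape)  # torch.Size([2])
print(torch.unique(torch.tensor([1, 2, 3])).shape)  # torch.Size([3])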

Look for Dynamicity

To look for dynamicity in your model, check for general recompilations using the Dynamic Shape automated support feature in Intel Gaudi products:

  • Set the environment flags (see the shell example after this list):
    PT_HPU_METRICS_FILE=/root/metricslog.json PT_HPU_METRICS_DUMP_TRIGGERS=process_exit,metric_change
    This gives a broad sense of the recompilations in the model. It creates a metricslog.json file that shows how often a graph_compilation is called. After a few steps, a reduction in recompilations is expected for static graphs.
  • If recompilations continue, enable automated Dynamic Shape control in Intel Gaudi products by setting:
    PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES=1
    When this variable is set, the PyTorch* bridge in Intel Gaudi products and the Graph Compiler automatically manage dynamic shapes in model scripts. Graph inputs are automatically bucketed and padded into ranges to achieve a common size, reducing recompilations and improving performance for dynamic workloads.
  • If recompilations persist, or you encounter instability and want better performance, go to the next section.
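
As a minimal shell sketch of the first step (train.py stands in for your own launch script; the grep at the end is just one quick way to count recorded compilations):

PT_HPU_METRICS_FILE=/root/metricslog.json \
PT_HPU_METRICS_DUMP_TRIGGERS=process_exit,metric_change \
python train.py

grep -o 'graph_compilation' /root/metricslog.json | wc -l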

Deeper Analysis of Model Data and Ops

The rest of this tutorial covers how to analyze your model with tools that allow you to pinpoint areas of dynamicity and make improvements.

Detect Dynamic Inputs

The data_dynamicity tool accepts a torch data loader and produces a report on the number of distinct input shapes to evaluate low versus high dynamicity in input datasets. For more strategies on mitigating high-input dynamicity by padding, see Text Datasets.

Image Datasets

Low-Input Dynamicity

The following example uses the MNIST dataset with:

  • A batch size of 7
  • Two input shapes: one for batch size 7 and one for batch size 3 (MNIST has 60,000 training images and 60,000 % 7 = 3, so the last, partial batch has size 3).


This is considered a low, acceptable amount of dynamicity in the input.

from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader

# Creating a sample MNIST dataloader
mnist_ds = torchvision.datasets.MNIST('mnist', download=True, transform=torchvision.transforms.ToTensor())
mnist_dl = DataLoader(mnist_ds, batch_size=7, num_workers=2)

# Call the dataloader dynamicity tool on the dataloader
res = data_dynamicity(mnist_dl)

For this code, the tool reports only two unique input shapes, that is, low dynamicity:

==============================================================================
|Shape                                               |Count                  |
==============================================================================
|((7, 1, 28, 28), (7,))                              |8571                   |
------------------------------------------------------------------------------
|((3, 1, 28, 28), (3,))                              |1                      |
------------------------------------------------------------------------------

Number of unique shapes:  2
There is a little dynamicity in input data shapes

High-Input Dynamicity

In comparison, the Flowers 102 dataset in the next example produces 29 different input shapes.

pip install scipy

from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader
import torch

# Join a list of images/labels into a single batched tensor.
# In this case we find the smallest height and width in the batch
# and crop every image to that size.
def collate(batch):
   dim1 = min([k[0].shape[1] for k in batch])
   dim2 = min([k[0].shape[2] for k in batch])
   images = torch.stack([k[0][:,:dim1,:dim2] for k in batch])
   labels = torch.tensor([k[1] for k in batch])
   return (images,labels)

flowers_ds = torchvision.datasets.Flowers102('flowers', download=True, transform=torchvision.transforms.ToTensor())
flowers_dl = DataLoader(flowers_ds, batch_size=7, num_workers=2, collate_fn=collate)
res = data_dynamicity(flowers_dl)
===================================================================================
|Shape                                                      |Count                |
===================================================================================
|((7, 3, 500, 500), (7,))                                   |111                  |
-----------------------------------------------------------------------------------
|((7, 3, 500, 667), (7,))                                   |5                    |
-----------------------------------------------------------------------------------
|((7, 3, 500, 528), (7,))                                   |2                    |
-----------------------------------------------------------------------------------
|((7, 3, 500, 501), (7,))                                   |2                    |
-----------------------------------------------------------------------------------
|((7, 3, 500, 542), (7,))                                   |2                    |
-----------------------------------------------------------------------------------
|((7, 3, 500, 592), (7,))                                   |1                    |
-----------------------------------------------------------------------------------
***
***
|((7, 3, 500, 549), (7,))                                   |1                    |
-----------------------------------------------------------------------------------
|((5, 3, 500, 500), (5,))                                   |1                    |
-----------------------------------------------------------------------------------
Number of unique shapes:  29
There is a lot of dynamicity in input data shapes

Depending on the use case, you can bucket images into certain fixed sizes (a bucketing sketch follows the next example) or resize/crop them to a single shape. The following example shows a center-crop solution, which makes the Flowers 102 dataset more static.

from habana_frameworks.torch.utils.experimental import data_dynamicity
import torchvision
from torch.utils.data import DataLoader
import torch

def collate(batch):
   images = torch.stack([k[0] for k in batch])
   labels = torch.tensor([k[1] for k in batch])
   return (images,labels)

# Center crop to a fixed size, applied as a transform
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.CenterCrop((300,300))])
flowers_ds = torchvision.datasets.Flowers102('flowers', download=True, transform=transform)
flowers_dl = DataLoader(flowers_ds, batch_size=7, num_workers=2, collate_fn=collate)
res = data_dynamicity(flowers_dl)
==============================================================================
|Shape                                                 |Count                |
==============================================================================
|((7, 3, 300, 300), (7,))                              |145                  |
------------------------------------------------------------------------------
|((5, 3, 300, 300), (5,))                              |1                    |
------------------------------------------------------------------------------
Number of unique shapes:  2
There is a little dynamicity in input data shapes
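
For the bucketing alternative mentioned earlier, here is a minimal sketch; the bucket sizes and the to_nearest_bucket helper are illustrative choices, not values produced by the tool:

import torchvision.transforms.functional as F

# Hypothetical fixed (height, width) buckets; pick sizes that match your data
BUCKETS = [(300, 300), (400, 400), (500, 500)]

def to_nearest_bucket(img):
    # Resize a CxHxW tensor to the smallest bucket at least as large as the
    # image, falling back to the largest bucket
    h, w = img.shape[1], img.shape[2]
    for bh, bw in BUCKETS:
        if h <= bh and w <= bw:
            return F.resize(img, [bh, bw])
    return F.resize(img, list(BUCKETS[-1]))

With a transform like this, every image lands on one of three shapes, so only a handful of input shapes reach the graph compiler.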

Text Datasets

Since sentence lengths vary, text datasets often show high input dynamicity. The following example yields 443 different shapes for the SQuAD dataset when batching with batch size 7 and padding each batch to its longest sentence:

pip install datasets
pip install transformers

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from habana_frameworks.torch.utils.experimental import data_dynamicity

# Pad to max length sentence in each batch
def collate(batch):
    def pad(item, val, maxlen):
        return torch.tensor([i + [val]*(maxlen-len(i)) for i in item])
    token = [k['token_type_ids'] for k in batch]
    attention = [k['attention_mask'] for k in batch]
    inp = [k['input_ids'] for k in batch]
    token_lens = [len(i) for i in token]
    # Find the max length sentence in this batch
    max_len = max(token_lens)
    assert token_lens == [len(i) for i in attention] == [len(i) for i in inp]
    return {'token_type_ids': pad(token, 0, max_len), 'attention_mask': pad(attention, 0, max_len), 'input_ids': pad(inp, 0, max_len)}

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

squad_dataset = load_dataset('squad')
tokenized_dataset = squad_dataset.map(lambda x: tokenizer(x['context']), batched=True)

dt = DataLoader(tokenized_dataset['train'], batch_size=7, num_workers=2, collate_fn=collate)
res = data_dynamicity(dt)
=================================================================================================================
|Shape                                                                                                   |Count |
=================================================================================================================
|((-1023680607561683160, (7, 160)), (-4748259973688274144, (7, 160)), (-5213422677791015773, (7, 160)))  |114   |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 145)), (-4748259973688274144, (7, 145)), (-5213422677791015773, (7, 145)))  |109   |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 143)), (-4748259973688274144, (7, 143)), (-5213422677791015773, (7, 143)))  |108   |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 180)), (-4748259973688274144, (7, 180)), (-5213422677791015773, (7, 180)))  |107   |
-----------------------------------------------------------------------------------------------------------------
***
***
|((-1023680607561683160, (7, 149)), (-4748259973688274144, (7, 149)), (-5213422677791015773, (7, 149)))  |99    |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 513)), (-4748259973688274144, (7, 513)), (-5213422677791015773, (7, 513)))  |1     |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 431)), (-4748259973688274144, (7, 431)), (-5213422677791015773, (7, 431)))  |1     |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (1, 159)), (-4748259973688274144, (1, 159)), (-5213422677791015773, (1, 159)))  |1     |
-----------------------------------------------------------------------------------------------------------------
Number of unique shapes:  443
There is a lot of dynamicity in input data shapes

A simple way to get static shapes is to pad all the data to the longest sentence length. However, this is computationally inefficient because the compute spent on the padded sections (which are thrown away later) is wasted.

The next example shows the same SQuAD dataset padded to the maximum sentence length, exhibiting low input dynamicity.

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from habana_frameworks.torch.utils.experimental import data_dynamicity

# Pad to max sentence length in the whole dataset
def get_collate(max_sentence):
    def collate(batch):
        def pad(item, val):
            return torch.tensor([i + [val]*(max_sentence-len(i)) for i in item])
        token = [k['token_type_ids'] for k in batch]
        attention = [k['attention_mask'] for k in batch]
        inp = [k['input_ids'] for k in batch]
        return {'token_type_ids': pad(token, 0), 'attention_mask': pad(attention, 0), 'input_ids': pad(inp, 0)}
    return collate

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

squad_dataset = load_dataset('squad')
tokenized_dataset = squad_dataset.map(lambda x: tokenizer(x['context']), batched=True)
# Find max sentence length in the whole dataset
max_sentence = max([len(dt['input_ids']) for dt in tokenized_dataset['train']])
dt = DataLoader(tokenized_dataset['train'], batch_size=7, num_workers=2, collate_fn=get_collate(max_sentence))
res = data_dynamicity(dt)
==================================================================================================================
|Shape                                                                                                    |Count |
==================================================================================================================
|((-1023680607561683160, (7, 867)), (-4748259973688274144, (7, 867)), (-5213422677791015773, (7, 867)))   |12514 |
------------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (1, 867)), (-4748259973688274144, (1, 867)), (-5213422677791015773, (1, 867)))   |1     |
------------------------------------------------------------------------------------------------------------------
Number of unique shapes:  2
There is a little dynamicity in input data shapes

Bucketing reduces compilations without wasting computation on padding every batch to the longest sentence in the dataset. Select a hyperparameter (the number of buckets) and divide the range between the shortest and longest sentence lengths in the dataset into that many buckets. Then, for each batch, find its longest sentence and pad the batch up to the smallest bucket that fits it.

For a case study using Wav2vec, see Case Study.

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import torch
from habana_frameworks.torch.utils.experimental import data_dynamicity
import numpy as np

# Compute bucket boundaries as percentiles of the size distribution
def get_buckets(sizes, num_buckets):
   buckets = np.unique(
      np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            interpolation="lower",
      )[1:]
   )
   return buckets

# Find the largest sentence in the batch
# Then find the bucket just larger than it, and pad everything to that
def get_collate(buckets):
    def collate(batch):
        def pad(item, val):
            max_in_batch = max([len(i) for i in item])
            nearest_bucket = np.where(buckets>=max_in_batch)[0][0]
            return torch.tensor([i + [val]*(buckets[nearest_bucket]-len(i)) for i in item])
        token = [k['token_type_ids'] for k in batch]
        attention = [k['attention_mask'] for k in batch]
        inp = [k['input_ids'] for k in batch]
        return {'token_type_ids': pad(token, 0), 'attention_mask': pad(attention, 0), 'input_ids': pad(inp, 0)}
    return collate

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

squad_dataset = load_dataset('squad')
tokenized_dataset = squad_dataset.map(lambda x: tokenizer(x['context']), batched=True)
buckets = get_buckets([len(dt['input_ids']) for dt in tokenized_dataset['train']], 5)
dt = DataLoader(tokenized_dataset['train'], batch_size=7, num_workers=2, collate_fn=get_collate(buckets))
res = data_dynamicity(dt)
=================================================================================================================
|Shape                                                                                                   |Count |
=================================================================================================================
|((-1023680607561683160, (7, 867)), (-4748259973688274144, (7, 867)), (-5213422677791015773, (7, 867)))  |4543  |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 207)), (-4748259973688274144, (7, 207)), (-5213422677791015773, (7, 207)))  |3350  |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 164)), (-4748259973688274144, (7, 164)), (-5213422677791015773, (7, 164)))  |2277  |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 138)), (-4748259973688274144, (7, 138)), (-5213422677791015773, (7, 138)))  |1388  |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (7, 114)), (-4748259973688274144, (7, 114)), (-5213422677791015773, (7, 114)))  |956   |
-----------------------------------------------------------------------------------------------------------------
|((-1023680607561683160, (1, 164)), (-4748259973688274144, (1, 164)), (-5213422677791015773, (1, 164)))  |1     |
-----------------------------------------------------------------------------------------------------------------
Number of unique shapes:  6
There is some dynamicity in input data shapes

The model now has only six input shapes: sentence lengths are grouped into buckets, and each batch is padded up to its bucket with the smallest amount of padding possible.
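
To see how get_buckets behaves, here is a small illustration with toy sentence lengths (it reuses the get_buckets function defined above; the expected output was worked out by hand from the interpolation="lower" rule):

import numpy as np

sizes = [5, 7, 9, 12, 40, 41, 60, 100]
# num_buckets=4 samples the 0/25/50/75/100th percentiles and drops the 0th
print(get_buckets(sizes, num_buckets=4))  # [  7  12  41 100]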

Detect Dynamic Ops

Now that you know how to detect dynamic inputs, you can detect dynamic ops in models. Dynamic ops are operations whose output shapes cannot be predicted from their input shapes alone.
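
For example, boolean-mask indexing is one such op. A minimal illustration (runs on CPU):

import torch

a = torch.tensor([1.0, -2.0, 3.0, -4.0])
b = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Both inputs have shape (4,), but the output shape depends on the values
print(a[a > 0].shape)  # torch.Size([2])
print(b[b > 0].shape)  # torch.Size([4])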

Here is an example: a simple toy model run for four steps. The input shape changes at the third step, which is expected to cause a recompilation. However, the model also contains a dynamic op, so the tool identifies the module that might be dynamic.

The following code example can be copied into a Python* file (dyn_ops.py) and run from a terminal window.

from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
import torch

class InnerNet(torch.nn.Module):
   def __init__(self):
      super(InnerNet, self).__init__()
      self.conv = torch.nn.Conv2d(1, 8, 3, 3)

   def forward(self, x):
      x = torch.flatten(self.conv(x), 1)
      x = x[x>0] # This is dynamic
      return x.sum()

net = torch.nn.Sequential(torch.nn.ReLU(), InnerNet()).to('hpu')
net = detect_recompilation_auto_model(net) # wrap model in dynamic op detection tool

for bs in [20,20,30,30]: # Input shape changes at 3rd step
   inp = torch.rand(bs, 1, 50, 50).to('hpu')
   print(net(inp))  
net.analyse_dynamicity() # Call this after a few steps to generate the dynamicity report

The detect_recompilation_auto_model tool outputs two tables and corresponding .csv files.

The first table shows what happens at each step, while the second table shows which module or submodule recompiled the most times. Let's analyze the first table:

 

| Step | Recompiling Modules | New In | New Out | Class | Location | Comment |
|------|---------------------|--------|---------|-------|----------|---------|
| 0 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
| 0 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
| 0 | Net/1 | True | True | __main__.InnerNet | dyn_ops.py | Recompiled due to new input shape |
| 0 | Net | True | True | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
| 1 | Net/1 | False | False | __main__.InnerNet | dyn_ops.py | Already processed input shape still recompiled. The issue may be dynamic ops. |
| 1 | Net | False | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Already processed input shape still recompiled. The issue could be due to dyn ops or a dynamic child. |
| 2 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
| 2 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
| 2 | Net/1 | True | False | __main__.InnerNet | dyn_ops.py | Recompiled due to new input shape |
| 2 | Net | True | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
| 3 | Net/1 | False | False | __main__.InnerNet | dyn_ops.py | Already processed input shape still recompiled. The issue may be dynamic ops. |
| 3 | Net | False | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Already processed input shape still recompiled. The issue could be due to dyn ops or a dynamic child. |

Step 0: All modules recompile, since it is the first step.

Step 1: InnerNet and Net recompile. The Comment column shows:

  • InnerNet might be dynamic because it recompiled even though none of its children recompiled.
  • Net might not itself be dynamic; it might have recompiled only because its child (InnerNet) recompiled.

Step 2: A new input shape is seen, so every module recompiles, as expected (shown in the Comment column).

Step 3: Confirms that InnerNet has dynamic ops. To summarize, the possible comments from the tool are:

  1. Recompiled due to new input shape.
  2. Already processed input shape still recompiled. The issue could be due to dynamic ops or a dynamic child.
  3. Already processed input shape still recompiled. The issue may be dynamic ops.
  4. Already processed input shape still recompiled and has a new output shape. The issue may be dynamic ops.

Note: This tool takes a long time to run, so we recommend the following:

  • Run for a small number of steps.
  • Run on a single Intel Gaudi card (without distributed training).
  • While the tool can detect and ignore recompilations due to new input shapes, pass in same-shape inputs where possible to save time running the tool.

With static inputs, the tool can focus only on finding dynamic ops.

In the next example, the dynamic portion is replaced with a static equivalent. When the detect_recompilation_auto_model tool is run again, only dynamicity from the inputs appears.

from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
import torch


class InnerNet(torch.nn.Module):
   def __init__(self):
      super(InnerNet, self).__init__()
      self.conv = torch.nn.Conv2d(1, 8, 3, 3)

   def forward(self, x):
      x = torch.flatten(self.conv(x), 1)
      #x = x[x>0] # This is dynamic, replacing in next line with static implementation
      x = torch.where(x>0, x, torch.zeros_like(x))
      return x.sum()

net = torch.nn.Sequential(torch.nn.ReLU(), InnerNet()).to('hpu')
net = detect_recompilation_auto_model(net)

for bs in [20,20,30,30]: # Input shape changes at 3rd step
   inp = torch.rand(bs, 1, 50, 50).to('hpu')
   print(net(inp))
net.analyse_dynamicity() # Call this after a few steps to generate the dynamicity report

 

| Step | Recompiling Modules | New In | New Out | Class | Location | Comment |
|------|---------------------|--------|---------|-------|----------|---------|
| 0 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
| 0 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
| 0 | Net/1 | True | True | __main__.InnerNet | dyn_ops_static.py | Recompiled due to new input shape |
| 0 | Net | True | True | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |
| 2 | Net/0 | True | True | torch.nn.modules.activation.ReLU | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/activation.py | Recompiled due to new input shape |
| 2 | Net/1/conv | True | True | torch.nn.modules.conv.Conv2d | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/conv.py | Recompiled due to new input shape |
| 2 | Net/1 | True | False | __main__.InnerNet | dyn_ops_static.py | Recompiled due to new input shape |
| 2 | Net | True | False | torch.nn.modules.container.Sequential | /usr/local/lib/python3.8/dist-packages/torch/nn/modules/container.py | Recompiled due to new input shape |

Detect Dynamic Sections

This example focuses on detecting dynamic sections in a real model, Faster R-CNN.

wget https://ultralytics.com/assets/coco128.zip
unzip coco128.zip

import torchvision, os
from PIL import Image
import torchvision.transforms as T

import habana_frameworks.torch.core as htcore
device = 'hpu'

#load model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval() # set to evaluation mode
model = model.to(device) # move model to device

from habana_frameworks.torch.utils.experimental import detect_recompilation_auto_model
model = detect_recompilation_auto_model(model, waittime=0.3)

for idx, k in enumerate(os.listdir('coco128/images/train2017/')):
    img = Image.open('coco128/images/train2017/' + k).resize((600,600))
    img = T.ToTensor()(img).to(device)
    print('inp shape:', img.shape)
    pred = model([img])
    htcore.mark_step()
    if idx == 6: # just running first few images
        break
    print('done img', idx)
model.analyse_dynamicity()

 

The outputs show the following:

| Step | Recompiling Modules | New In | New Out | Class | Location | Comment |
|------|---------------------|--------|---------|-------|----------|---------|
| 1 | Net/roi_heads/box_roi_pool | False | False | torchvision.ops.poolers.MultiScaleRoIAlign | /usr/local/lib/python3.8/dist-packages/torchvision/ops/poolers.py | Already processed input shape still recompiled. The issue may be dynamic ops. |
| 1 | Net/roi_heads | False | True | torchvision.models.detection.roi_heads.RoIHeads | /usr/local/lib/python3.8/dist-packages/torchvision/models/detection/roi_heads.py | Already processed input shape still recompiled and has new output shape. The issue may be dynamic ops. |

This output indicates that the MultiScaleRoIAlign and RoIHeads classes contain dynamic ops.


You can rewrite these sections as static or move the operation to a CPU. For strategies, see Mitigation Techniques for Dynamic Ops.
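
As a rough illustration of the CPU fallback, here is a minimal sketch; dynamic_filter is a hypothetical helper, and x is assumed to be a tensor already on the HPU:

import torch

def dynamic_filter(x):
    # Run the data-dependent op on the CPU ...
    x_cpu = x.to('cpu')
    y_cpu = x_cpu[x_cpu > 0]  # dynamic boolean-mask indexing
    # ... then move the variable-shaped result back to the HPU
    return y_cpu.to('hpu')

Moving an op to the CPU trades recompilations for device-to-host transfers, so measure both variants before settling on one.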

Copyright© 2023 Habana Labs, Ltd. an Intel Company.

Licensed under the Apache* License, Version 2.0 (the “License”)

You may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.