Boost Performance for LLMs Using PyTorch* on Intel® Xeon® Scalable Processors: Part 2

Previously, we shared how to boost large language models (LLMs) with PyTorch* on Intel® Xeon® processors, covering best practices for general inference scenarios with publicly available LLMs. This article guides you through more advanced scenarios, such as how to achieve peak performance on a customized LLM or your own private LLM. Customized LLMs usually have specific model structures that cannot follow the general practices for optimized performance. This article describes a fine-grained optimization approach to overcome this challenge.

Read Part 1

A New Feature Supports Customizing LLMs with Private Structures 

In the past year, LLMs have flourished, with many open models contributed to the community. While researchers build their own LLMs from transformer blocks with variations in the implementation details, models with private structures are often required to secure a competitive business advantage. To help with these scenarios and improve productivity, Intel® Extension for PyTorch* v2.3 provides a module-level LLM optimization API. This new feature provides module-level optimizations for commonly used LLM modules and functionalities. LLM creators can adopt this API set to replace those basic blocks with optimized alternatives in their models and obtain peak performance.

There are three categories of module-level LLM optimization APIs: 

Category 1: Linear Post-op APIs 

Linear + post-op patterns commonly exist in the transformer blocks of LLM models, such as the Linear + activation function inside MLP and feed-forward layers. The following APIs provide fusion optimizations that fold these post operations into the linear computation for a performance speedup.

# using module init and forward 
ipex.llm.modules.LinearMul 
ipex.llm.modules.LinearGelu 
ipex.llm.modules.LinearNewGelu 
ipex.llm.modules.LinearAdd 
ipex.llm.modules.LinearAddAdd 
ipex.llm.modules.LinearSilu 
ipex.llm.modules.LinearSiluMul 
ipex.llm.modules.Linear2SiluMul 
ipex.llm.modules.LinearRelu 
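
For example, a Linear layer followed by GELU in a feed-forward block can be replaced with the fused LinearGelu module. The following is a minimal sketch with illustrative shapes; it assumes the fused module wraps an existing torch.nn.Linear, following the same convention as the LinearAdd and Linear2SiluMul usage shown later in this article.

# Minimal sketch (illustrative shapes): fuse Linear + GELU by wrapping an existing torch.nn.Linear 
import torch 
import intel_extension_for_pytorch as ipex 
 
linear = torch.nn.Linear(4096, 11008) 
fused_linear_gelu = ipex.llm.modules.LinearGelu(linear) 
 
x = torch.randn(1, 32, 4096) 
y = fused_linear_gelu(x)  # computes gelu(linear(x)) as a single fused operation 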

Category 2: Attention-Related APIs 

Attention layers are the most important part of LLM models; this is where multi-head attention (MHA) and the key/value cache are computed to support LLM generation tasks. The following APIs provide higher-performing attention layers for customized LLMs, such as IndirectAccessKVCacheAttention, RotaryEmbedding, normalization, VarlenAttention, and PagedAttention.

# using module init and forward 
ipex.llm.modules.RotaryEmbedding 
ipex.llm.modules.RMSNorm 
ipex.llm.modules.FastLayerNorm 
ipex.llm.modules.VarlenAttention 
ipex.llm.modules.PagedAttention 
ipex.llm.modules.IndirectAccessKVCacheAttention 
 
# using as functions 
ipex.llm.functional.rotary_embedding 
ipex.llm.functional.rms_norm 
ipex.llm.functional.fast_layer_norm 
ipex.llm.functional.indirect_access_kv_cache_attention 
ipex.llm.functional.varlen_attention 
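
For instance, a hand-written RMSNorm implementation can be replaced with the optimized module, as the Llama 3 walkthrough below also does. The following is a minimal sketch with an illustrative hidden size and input shape; the constructor arguments follow the RMSNorm(hidden_size, eps=...) usage shown later in this article.

# Minimal sketch (illustrative shapes): drop-in use of the optimized RMSNorm module 
import torch 
import intel_extension_for_pytorch as ipex 
 
hidden_size = 4096 
norm = ipex.llm.modules.RMSNorm(hidden_size, eps=1e-6) 
 
hidden_states = torch.randn(1, 32, hidden_size) 
normalized = norm(hidden_states) 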

Category 3: Generation-Related APIs 

Hugging Face* provides different generation functions as APIs to support different LLM decoding strategies. The following APIs provide support for some of those generation functions with optimizations like prompt sharing.
 

# using for optimizing huggingface generation APIs with prompt sharing 
ipex.llm.generation.hf_beam_sample 
ipex.llm.generation.hf_beam_search 
ipex.llm.generation.hf_greedy_search 
ipex.llm.generation.hf_sample 

 

How to Apply the Optimized Software Stack for an LLM 

As a demonstration of applying the Intel Extension for PyTorch module-level LLM optimization APIs, we used Llama 3 from Hugging Face* Transformers. Four layers of Llama 3 (as in most LLMs) need to be modified: LlamaAttention, LlamaMLP, LlamaDecoderLayer, and LlamaModel. The functions and modules that can be optimized are: 

  • Fusion of Linear2SiluMul at the LlamaMLP layer 
  • Fusion of LinearAdd from the LlamaAttention and LlamaMLP layers with the residual add from the LlamaDecoderLayer layer 
  • Fusion of RotaryEmbedding and IndirectAccessKVCacheAttention at the LlamaAttention layer 
  • Fusion of RMSNorm at the LlamaDecoderLayer and LlamaModel layers 

Besides the previous optimizations, for the LlamaForCausalLM layer we must rewrite its prepare_inputs_for_generation and _reorder_cache functions to work with IndirectAccessKVCacheAttention, as shown in this example.

The following sections give more details. 

Optimization for LlamaAttention

class LlamaAttention(nn.Module): 
    def __init__(……): 
        # default attributes init and naming are the same as transformers design. 
        # changes to init IndirectAccessKVCacheAttention and RotaryEmbedding from ipex.llm layers 
        self.IPEXAttention = ipex.llm.modules.IndirectAccessKVCacheAttention( 
            self.max_position_embeddings 
        ) 
        self.ipex_rotary_emb = ipex.llm.modules.RotaryEmbedding( 
            self.max_position_embeddings, 
            self.head_dim, 
            self.rope_theta, 
            self.config.architectures[0], 
        ) 
    def forward(……): 
        # QKV projections, reshape and get kv_seq_len are skipped 
        # changes to apply RotaryEmbedding  
        key_states = self.ipex_rotary_emb( 
            key_states, 
            position_ids, 
            self.num_key_value_heads, 
            self.head_dim, 
            self.head_dim // 2, 
            self.head_dim, 
            kv_seq_len, 
        ) 
        query_states = self.ipex_rotary_emb( 
            query_states, 
            position_ids, 
            self.num_heads, 
            self.head_dim, 
            self.head_dim // 2, 
            self.head_dim, 
            kv_seq_len, 
        ) 
        # changes to apply IndirectAccessKVCacheAttention 
        (attn_output, attn_weights, past_key_value) = self.IPEXAttention( 
            query_states, 
            key_states, 
            value_states, 
            math.sqrt(self.head_dim), 
            past_key_value, 
            None, 
            attention_mask, 
        ) 
        # self.o_proj is moved to LlamaDecoderLayer to enable linear+add fusion 
        return attn_output, attn_weights, past_key_value 

Optimization for LlamaMLP and LlamaDecoderLayer

class LlamaMLP(nn.Module): 
    def __init__(……): 
        # default attributes init and naming are the same as transformers design. 
    def forward(……): 
        # changes to apply Linear2SiluMul fusion from ipex.llm layers 
        if not hasattr(self, "ipex_fusion"): 
            self.ipex_fusion = ipex.llm.modules.Linear2SiluMul( 
                self.gate_proj, self.up_proj 
            ) 
        # self.down_proj is moved to LlamaDecoderLayer to enable linear+add fusion 
        return self.ipex_fusion(x) 
 
class LlamaDecoderLayer(nn.Module): 
    def __init__(……): 
        # default attributes init and naming are the same as transformers design. 
        # changes to apply RMSNorm from ipex.llm layers 
        self.input_layernorm = ipex.llm.modules.RMSNorm( 
            config.hidden_size, eps=config.rms_norm_eps 
        ) 
        self.post_attention_layernorm = ipex.llm.modules.RMSNorm( 
            config.hidden_size, eps=config.rms_norm_eps 
        ) 
    def forward(……): 
        # get residual, run norm layer, run attention layer are skipped 
 
        # changes to apply LinearAdd fusion from ipex.llm layers 
        if not hasattr(self, "ipex_fusion_mha_linear_add"): 
            self.ipex_fusion_mha_linear_add = ipex.llm.modules.LinearAdd(self.self_attn.o_proj) 
        hidden_states = self.ipex_fusion_mha_linear_add(hidden_states, residual) 
 
        # get residual, run norm layer, run mlp layer are skipped 
 
        # changes to apply LinearAdd fusion from ipex.llm layers 
        if not hasattr(self, "ipex_fusion_mlp_linear_add"): 
            self.ipex_fusion_mlp_linear_add = ipex.llm.modules.LinearAdd(self.mlp.down_proj) 
        hidden_states = self.ipex_fusion_mlp_linear_add(hidden_states, residual) 
 
        outputs = (hidden_states,) 
        outputs += (present_key_value,) 
        return outputs 

Optimization for LlamaModel

class LlamaModel(PreTrainedModel): 
    def __init__(……): 
        # default attributes init and naming are the same as transformers design. 
        # changes to apply RMSNorm from ipex.llm layers 
        self.norm = ipex.llm.modules.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 
    def forward(……): # no changes from transformers design 

To apply optimizations when calling model.generate() from the Hugging Face text generation API:

# Taking beam search as example here 
hf_beam_search = ipex.llm.generation.hf_beam_search.__get__(model, model.__class__)  
setattr(model, "beam_search", hf_beam_search)  

Full Optimization Application Case and Related Code

Benchmark Results Demonstrate LLM Performance 

The following benchmark results are based on the previous optimizations, running on a single 4th gen Intel® Xeon® Scalable processor. The benchmarks compare performance before and after applying the module-level optimization API from Intel Extension for PyTorch. Benchmarking is performed with input token lengths from 1K to 8K to simulate different usage scenarios and is run with the bfloat16 data type, using Llama 3 as the example model. 

As shown in Figures 1 and 2, the module-level LLM optimization API delivers approximately a 5x improvement in first token latency across the different input token sizes, while for next token latency the benefit grows with input size, from about 3x on a 1K-token input to more than 5x on an 8K-token input. 

Figure 1. First token latency improvement with a module-level LLM optimization API 

Figure 2. Next token latency improvement with a module-level LLM optimization API

Conclusion

The new module-level LLM optimization API from Intel Extension for PyTorch provides good LLM inference performance for customized or private LLMs. Furthermore, this new API set can be applied to LLM-serving frameworks like vLLM and Text Generation Inference (TGI), which also have their own customized LLM models; work to adopt this API set in these frameworks is ongoing to provide good LLM performance on Intel Xeon processors. We are also upstreaming these optimizations into PyTorch and related ecosystem projects to improve performance for LLM models more broadly.

We encourage you to also check out and incorporate Intel’s other AI and machine learning framework optimizations and end-to-end portfolio of tools into your AI workflow. Learn about the unified, open, standards-based oneAPI programming model that forms the foundation of the Intel® AI portfolio to help you prepare, build, deploy, and scale your AI solutions.

Framework Optimizations

AI Development Software

Product and Performance Information  

Performance benchmarking configuration: Amazon Web Services (AWS)* m7i.metal-48xl, Intel® Xeon® Platinum 8488C CPU, 96 cores, 192 threads, microcode 0x2b000590; Ubuntu* 22.04.4 LTS, 6.5.0-1016-aws; PyTorch v2.3.0; Intel Extension for PyTorch v2.3.0; DeepSpeed v0.14.0; Model: Llama 3 8B; Token Length: 128/1024/2048/8064 (in), 128 (out); Tested on one socket, BS=1, beam=4, precision: bfloat16. Tested by Intel on April 30, 2024.