TorchServe Inference Server with Gaudi

This document provides instructions on deploying PyTorch models using TorchServe with the Intel® Gaudi® 2 AI accelerator. TorchServe is a flexible and easy-to-use tool for serving and scaling PyTorch models in production. For more details, refer to PyTorch’s TorchServe documentation.

Installation

  1. Install the Intel Gaudi software as described in the Installation Guide. The supported versions are listed in the Support Matrix.

  2. Clone the repository below and install the dependencies. Make sure to include the --skip_torch_install flag so that the Intel Gaudi PyTorch package is not overridden. Then, install the torch-model-archiver and torch-workflow-archiver:

    git clone https://github.com/pytorch/serve.git
    cd serve
    python ./ts_scripts/install_dependencies.py --skip_torch_install
    pip install torch-model-archiver torch-workflow-archiver
    
  3. Install torchserve:

    pip install torchserve
    

    Or, build from source:

    python ./ts_scripts/install_dependencies.py --skip_torch_install --environment=dev
    python ./ts_scripts/install_from_src.py
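
    Optionally, verify the installation before deploying a model. These checks are not part of the original instructions; they only confirm that TorchServe is available and that the Intel Gaudi PyTorch bridge is still intact after installing the dependencies:

    torchserve --version
    python -c "import torch, habana_frameworks.torch.hpu as hthpu; print(torch.__version__, hthpu.is_available())"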
    
  4. Deploy a model using TorchServe. See ResNet50 model example and BERT model example with custom handler.

ResNet50 Model Example

This section outlines how to deploy a ResNet50 model using TorchServe, based on this tutorial. The example uses the default image classification handler together with an additional handler that sets the device to hpu (see hpu_image_classifier.py).
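
Conceptually, such a handler only needs to subclass the stock TorchServe image classifier handler and place the model on the Gaudi device. The sketch below is illustrative only and is not the hpu_image_classifier.py file shipped with the example, which may differ in detail:

    # Illustrative sketch only -- not the handler shipped with the example.
    import torch
    import habana_frameworks.torch.core as htcore  # importing this registers the hpu device
    from ts.torch_handler.image_classifier import ImageClassifier

    class HPUImageClassifier(ImageClassifier):
        def initialize(self, ctx):
            # Let the default handler load the model, then move it to the Gaudi device.
            super().initialize(ctx)
            self.device = torch.device("hpu")
            self.model.to(self.device)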

  1. Configure torch.compile. In this example, the following config is provided in model-config.yaml file:

    echo "minWorkers: 1
    maxWorkers: 1
    pt2:
      compile:
        enable: True
        backend: hpu_backend" > model-config.yaml
    

    Note

    • The number of maxWorkers you deploy should be equal to or smaller than the number of cards available in your server or container.

    • Batch inference can also be configured in the model-config.yaml file using the batchSize and maxBatchDelay parameters. The backend waits until batchSize requests have been aggregated or until maxBatchDelay milliseconds have elapsed, whichever comes first. You can read more about batch inference in the TorchServe documentation.
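
    For illustration, a model-config.yaml with batching enabled could look like the following. The batchSize and maxBatchDelay values below are placeholders, not recommendations:

    minWorkers: 1
    maxWorkers: 1
    batchSize: 4
    maxBatchDelay: 100
    pt2:
      compile:
        enable: True
        backend: hpu_backend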

  2. Download the pre-trained model and prepare the model archive:

    wget https://download.pytorch.org/models/resnet50-11ad3fa6.pth
    mkdir model_store
    PT_HPU_LAZY_MODE=0 torch-model-archiver --model-name resnet-50 --version 1.0 --model-file model.py \
        --serialized-file resnet50-11ad3fa6.pth --export-path model_store \
        --extra-files ../../image_classifier/index_to_name.json --handler hpu_image_classifier.py \
        --config-file model-config.yaml
    

    Note

    PT_HPU_LAZY_MODE=0 disables lazy mode. Gaudi’s integration with PyTorch supports three modes of operation: eager, lazy, and compile (in beta). Lazy is currently the default, so disabling it via the PT_HPU_LAZY_MODE variable selects eager mode. Compile mode builds on eager mode and additionally requires explicit model compilation, which is enabled by the following lines in the model-config.yaml file:

    pt2:
      compile:
        enable: True
        backend: hpu_backend
    

    See Runtime Environment Variables for more details.
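
    The --model-file argument in the torch-model-archiver command above points to a model.py that defines the model class. The sketch below follows the TorchServe image_classifier examples and is shown for reference only; use the model.py provided in the example directory:

    # model.py -- minimal ResNet50 definition matching the torchvision weights used above.
    # Shown for reference only; the example directory provides its own file.
    from torchvision.models.resnet import Bottleneck, ResNet

    class ImageClassifier(ResNet):
        def __init__(self):
            # Bottleneck blocks with layout [3, 4, 6, 3] correspond to ResNet50.
            super(ImageClassifier, self).__init__(Bottleneck, [3, 4, 6, 3])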

  3. Start the TorchServe server using the following command:

    PT_HPU_LAZY_MODE=0 torchserve --start --ncs  --model-store model_store --models resnet-50.mar --disable-token-auth --enable-model-api
    

    Below is the console output, which confirms that the server has started on HPU and that the model was run in compile mode:

    2024-06-25T14:21:09,470 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG - ============================= HABANA PT BRIDGE CONFIGURATION ===========================
    2024-06-25T14:21:09,472 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_HPU_LAZY_MODE = 0
    2024-06-25T14:21:09,472 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_RECIPE_CACHE_PATH =
    2024-06-25T14:21:09,472 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_CACHE_FOLDER_DELETE = 0
    2024-06-25T14:21:09,472 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_HPU_RECIPE_CACHE_CONFIG =
    2024-06-25T14:21:09,473 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_HPU_MAX_COMPOUND_OP_SIZE = 9223372036854775807
    2024-06-25T14:21:09,473 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_HPU_LAZY_ACC_PAR_MODE = 1
    2024-06-25T14:21:09,474 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG -  PT_HPU_ENABLE_REFINE_DYNAMIC_SHAPES = 0
    2024-06-25T14:21:09,474 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG - ---------------------------: System Configuration :---------------------------
    2024-06-25T14:21:09,474 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG - Num CPU Cores : 16
    2024-06-25T14:21:09,474 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG - CPU RAM       : 61711860 KB
    2024-06-25T14:21:09,475 [WARN ] W-9000-resnet-50_1.0-stderr MODEL_LOG - ------------------------------------------------------------------------------
    2024-06-25T14:21:09,505 [INFO ] W-9000-resnet-50_1.0-stdout MODEL_LOG - Compiled model with backend hpu_backend
    

    Note

    --disable-token-auth disables token authorization. This option is used here only for example purposes. Please refer to the TorchServe documentation, which describes the process of serving the model using tokens.
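
    You can optionally confirm that the model has been registered by querying the TorchServe management API, which listens on port 8081 by default:

    curl http://127.0.0.1:8081/models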

  4. Run inference. torch.compile requires a warm-up phase to reach optimal performance. Ensure you run at least as many inferences as the maxWorkers specified before measuring performance:

    # Open a new terminal
    cd  examples/pt2/torch_compile_hpu
    curl http://127.0.0.1:8080/predictions/resnet-50 -T ../../image_classifier/kitten.jpg
    

    The expected output is a JSON-formatted classification file with probabilities. For example:

    {
      "tabby": 0.2724992632865906,
      "tiger_cat": 0.1374046504497528,
      "Egyptian_cat": 0.046274710446596146,
      "lynx": 0.003206699388101697,
      "lens_cap": 0.002257900545373559
    }
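
    Since the first requests trigger graph compilation, you can send a few warm-up requests before timing anything, for example:

    # Simple warm-up loop; adjust the count to at least the number of workers.
    for i in $(seq 1 4); do
        curl -s http://127.0.0.1:8080/predictions/resnet-50 -T ../../image_classifier/kitten.jpg > /dev/null
    done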
    
  5. Stop TorchServe:

    torchserve --stop
    

BERT Model Example with Custom Handler

To deploy a BERT model for sequence classification using TorchServe, you need a custom handler that performs the sequence classification task. Unlike a default handler, a custom handler allows you to tailor the model’s behavior to your specific needs.

  1. Save your fine-tuned BERT model with Hugging Face:

    1. Install the Hugging Face transformers library:

      pip install transformers
      
    2. Get a standard pre-trained BERT model and tokenizer from Hugging Face:

      from transformers import AutoTokenizer, AutoModelForSequenceClassification
      model_name = "bert-large-uncased"
      tokenizer = AutoTokenizer.from_pretrained(model_name)
      model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
      
    3. Fine-tune your model on a task-specific dataset since the model’s last layer will be randomly initialized. Once your BERT model is fine-tuned, save it together with the tokenizer:

      save_directory = './saved_model'
      tokenizer.save_pretrained(save_directory)
      model.save_pretrained(save_directory)
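
      Fine-tuning itself is beyond the scope of this example. Purely as an illustration, and continuing from the model and tokenizer loaded above, a minimal training loop to run before saving might look like the sketch below, where train_texts and train_labels are hypothetical placeholders for your labeled data:

      # Illustrative sketch only -- replace train_texts/train_labels with your dataset.
      import torch

      train_texts = ["first example sentence", "second example sentence"]
      train_labels = [1, 0]

      encodings = tokenizer(train_texts, padding=True, truncation=True, return_tensors="pt")
      labels = torch.tensor(train_labels)

      optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
      model.train()
      for _ in range(3):  # a few passes over the tiny placeholder batch
          optimizer.zero_grad()
          outputs = model(**encodings, labels=labels)
          outputs.loss.backward()
          optimizer.step()
      model.eval()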
      
    4. Add your label mappings to index_to_name.json file:

      echo '{"0":"Not Accepted","1":"Accepted"}' > index_to_name.json
      
  2. Prepare a custom TorchServe handler. Because standard TorchServe handlers do not natively support Hugging Face models, a custom handler must be used. The approach is based on this tutorial, which shows the deployment of transformer models for various other tasks. The custom handler implements the standard TorchServe interface. For more details on creating custom TorchServe handlers, refer to this guide:

    from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
    import habana_frameworks.torch.core as htcore
    import torch
    import torch.nn.functional as F
    from ts.torch_handler.base_handler import BaseHandler
    from ts.handler_utils.timer import timed
    from ts.utils.util import load_label_mapping, map_class_to_label
    import sys
    import os
    
    class HPUBertHandler(BaseHandler):
        def initialize(self, ctx):
            model_dir = ctx.system_properties.get('model_dir')
            self.device = torch.device('hpu')
            self.config = AutoConfig.from_pretrained(model_dir)
            model = AutoModelForSequenceClassification.from_pretrained(
                model_dir, config=self.config).to(self.device)
            self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
            model.eval()
            self.model = torch.compile(model, backend='hpu_backend')
            mapping_file_path = os.path.join(model_dir, 'index_to_name.json')
            self.mapping = load_label_mapping(mapping_file_path)
            self.initialized = True
    
        @timed
        def preprocess(self, requests):
            texts = [req.get('data') or req.get('body') for req in requests]
            texts = [text.decode('utf-8') if isinstance(text, (bytes, bytearray))
                    else text for text in texts]
            inputs = self.tokenizer(texts, return_tensors='pt')
            return inputs.to(self.device)
    
        @timed
        def inference(self, inputs):
            with torch.no_grad():
                outputs = self.model(**inputs)
            return outputs.logits
    
        @timed
        def postprocess(self, inference_output):
            probs = F.softmax(inference_output, dim=1).tolist()
            return map_class_to_label(probs, self.mapping)
    

    During initialization, the supplied tokenizer and pre-trained model are loaded. The model’s performance is optimized with torch.compile. Upon receiving a request, the handler processes an individual input sequence by decoding, applying appropriate tokenization, and executing inference.

    In the initialize function, the handler loads the model and tokenizer. The pre-trained config, model and tokenizer are retrieved from the archive:

    self.config = AutoConfig.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_dir, config=self.config).to(self.device)
    self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
    

    The model is also moved to the hpu device, as set by self.device = torch.device('hpu'). The model is wrapped with self.model = torch.compile(model, backend='hpu_backend'); the actual compilation is triggered by the first inference request, which improves performance on subsequent requests. The mapping between the returned indexes and their labels is loaded to help interpret the inference results:

    mapping_file_path = os.path.join(model_dir, 'index_to_name.json')
    self.mapping = load_label_mapping(mapping_file_path)
    

    After the inference request is supplied to the handler, it is retrieved and decoded inside the preprocess function:

    def preprocess(self, requests):
        texts = [req.get('data') or req.get('body') for req in requests]
        texts = [text.decode('utf-8') if isinstance(text, (bytes, bytearray))
                else text for text in texts]
    

    The text is transformed with the tokenizer, inputs = self.tokenizer(texts, return_tensors='pt'), and moved to the hpu device with return inputs.to(self.device). The inference function runs the input data through the pre-trained model:

    def inference(self, inputs):
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits
    

    Finally, inside the postprocess function, raw model outputs are converted into probabilities and mapped to human-readable labels using a predefined mapping:

    def postprocess(self, inference_output):
        probs = F.softmax(inference_output, dim=1).tolist()
        return map_class_to_label(probs, self.mapping)
    
  3. Create the BERT model archive. Aggregate the model and tokenizer configuration, the associated parameters, and the provided handler into a consolidated model archive suitable for TorchServe deployment:

    PT_HPU_LAZY_MODE=0 torch-model-archiver --model-name bert --serialized-file saved_model/model.safetensors \
    --export-path model_store --version 1.0 --handler hpu_bert_handler.py \
    --extra-files "saved_model/config.json,index_to_name.json,saved_model/vocab.txt,saved_model/tokenizer.json,saved_model/tokenizer_config.json,saved_model/special_tokens_map.json" \
    --config-file model-config.yaml
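
    Note that this section does not redefine model-config.yaml; the archive above is assumed to use the same torch.compile configuration as the ResNet50 example, i.e.:

    minWorkers: 1
    maxWorkers: 1
    pt2:
      compile:
        enable: True
        backend: hpu_backend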
    
  4. Start TorchServe with BERT model:

    PT_HPU_LAZY_MODE=0 torchserve --start --ncs --model-store model_store --models bert.mar --disable-token-auth --enable-model-api
    
  5. Prepare a sample input by creating a file containing the sequence to run inference on:

    echo "The quick brown fox jumps over the lazy dog." > sample_text.txt
    
  6. Request inference on the sample:

    curl http://127.0.0.1:8080/predictions/bert -T sample_text.txt
    

    The result contains the predicted class probabilities:

    {
        "Not Accepted": 0.6127049922943115,
        "Accepted": 0.3872949182987213
    }
    
  7. Stop TorchServe:

    torchserve --stop