Run Inference Using UINT4

This guide provides the steps required to enable UINT4 inference on your Intel® Gaudi® 2 AI accelerator. When running inference on large language models (LLMs), high memory usage is often the bottleneck. Using the UINT4 data type therefore halves the required memory bandwidth compared to running inference in FP8.

Note

The following is currently supported:

  • GPTQ - Weight-Only-Quantization (WOQ) method.

  • nn.Linear module.

  • Single device only.

  • Lazy mode only (default).

  • The pre-quantized model should be in BF16 only.

  • Tested on Hugging Face Optimum for Intel Gaudi models only.

Prerequisites

  • The INC package, neural_compressor.torch.quantization, is available in the Intel Gaudi PyTorch package or Docker image, as detailed in the Installation Guide.

  • Intel Gaudi utilizes the INC API to load models with 4-bit checkpoints and adjust them to run on Gaudi 2. INC supports models that were quantized to 4-bit using Weight-Only Quantization (WOQ).
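
As a quick optional sanity check, you can confirm that the INC PyTorch quantization API is importable in your environment. This is a minimal sketch; the version print is only illustrative:

import neural_compressor
from neural_compressor.torch.quantization import load  # API used later in this guide

print(f"Intel Neural Compressor version: {neural_compressor.__version__}")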

Quantizing PyTorch Models to UINT4

Quantize the model for GPTQ using the run_clm_no_trainer.py script provided in the Neural Compressor GitHub repository:

python -u run_clm_no_trainer.py \
        --model <model_name_or_path> \
        --dataset <DATASET_NAME> \
        --quantize \
        --output_dir <tuned_checkpoint> \
        --tasks "lambada_openai" \
        --batch_size <batch_size> \
        --woq_algo GPTQ \
        --woq_bits 4 \
        --woq_group_size 128 \
        --woq_scheme asym \
        --woq_use_mse_search \
        --gptq_use_max_length

Note

  • Typical LLMs such as meta-llama/Llama-2-7b-hf, EleutherAI/gpt-j-6B, and facebook/opt-125m have been validated with this script.

  • For more information on the GPTQ and WOQ config flags, refer to this code.
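
The same GPTQ weight-only quantization can also be driven from Python through the INC API. The following is a minimal sketch, assuming the GPTQConfig/prepare/convert entry points of neural_compressor.torch.quantization and the save() method on the converted model; the model name, calibration text, and output directory are illustrative, and parameter names may vary between INC releases:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from neural_compressor.torch.quantization import GPTQConfig, prepare, convert

model_name = "facebook/opt-125m"   # one of the models validated with the script above
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 4-bit asymmetric weight-only GPTQ, mirroring the main --woq_* flags above;
# further options correspond to the other --woq_*/--gptq_* flags of the script
quant_config = GPTQConfig(bits=4, group_size=128, use_sym=False)

# Insert observers, run calibration data through the model, then convert the weights
model = prepare(model, quant_config)
calib = tokenizer("Intel Gaudi accelerates large language model inference.", return_tensors="pt")
model(**calib)                      # calibration forward pass; use a real dataset in practice
model = convert(model)

model.save("tuned_checkpoint")      # directory later consumed by the INC load API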

Loading a WOQ Checkpoint Saved by INC

You can load the checkpoint created in the previous section, or an existing checkpoint, using the steps below. See the Llama 2 7B model for an example of a model using UINT4:

  1. Import habana_frameworks.torch.core:

    import habana_frameworks.torch.core as htcore
    
  2. Call the INC load API and target the Gaudi device:

    from neural_compressor.torch.quantization import load
    from transformers import AutoModelForCausalLM
    org_model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        **model_kwargs,
    )
    
    model = load(
        model_name_or_path=args.quantized_inc_model_path,
        format="default",
        device="hpu",
        original_model=org_model,
        **model_kwargs,
    )
    # NOTICE: INC currently loads the model as float32, so a dtype conversion is
    # required as a temporary workaround (model_dtype is the target dtype, e.g. torch.bfloat16)
    model = model.to(model_dtype)
    
  3. Set the following when running your model (a minimal generation sketch using the loaded model follows these steps). SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false is an experimental flag that yields better performance. --quantized_inc_model_path provides the path to the INC quantized model checkpoint files:

    SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=true <model run command> --model_name_or_path <model_name_or_path> --quantized_inc_model_path <tuned_checkpoint>
    

    Note

    • SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false will be removed in a future release.

    • model = model.to(model_dtype) will be removed in a future release.
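
To exercise the model loaded in step 2, you can run a short generation pass on the HPU. The following is a minimal sketch; it assumes model_dtype is torch.bfloat16 (the only supported pre-quantized data type listed above) and reuses args.model_name_or_path for the tokenizer:

import torch
from transformers import AutoTokenizer

model_dtype = torch.bfloat16                      # assumption: BF16, per the supported data types above
model = model.to(model_dtype).to("hpu")
model.eval()

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
inputs = tokenizer("Deep learning is", return_tensors="pt").to("hpu")

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
htcore.mark_step()                                # flush any pending lazy-mode graphs
print(tokenizer.decode(outputs[0], skip_special_tokens=True))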

Loading a Hugging Face WOQ Checkpoint Using INC

You can load a Hugging Face checkpoint using the steps below. See the Llama 2 7B model for an example of a model using UINT4:

  1. Import habana_frameworks.torch.core:

    import habana_frameworks.torch.core as htcore
    
  2. Call the INC load API and target the Gaudi device:

    from neural_compressor.torch.quantization import load
    model = load(
        model_name_or_path=args.model_name_or_path,
        format="huggingface",
        device="hpu",
        **model_kwargs,
    )
    
  3. Set the following when running your model (an end-to-end sketch of this flow follows these steps). SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false is an experimental flag that yields better performance. --load_quantized_model_with_inc invokes the INC load API:

    SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false ENABLE_EXPERIMENTAL_FLAGS=true <model run command> --load_quantized_model_with_inc
    

    Note

    SRAM_SLICER_SHARED_MME_INPUT_EXPANSION_ENABLED=false will be removed in a future release.
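
Putting the steps in this section together, the following is a minimal end-to-end sketch. The checkpoint name is an illustrative public 4-bit GPTQ checkpoint (substitute your own), and model_kwargs are omitted for brevity:

import torch
import habana_frameworks.torch.core as htcore
from transformers import AutoTokenizer
from neural_compressor.torch.quantization import load

# Illustrative 4-bit GPTQ checkpoint from the Hugging Face Hub; substitute your own
model_name_or_path = "TheBloke/Llama-2-7B-Chat-GPTQ"

model = load(
    model_name_or_path=model_name_or_path,
    format="huggingface",
    device="hpu",
)
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
inputs = tokenizer("Deep learning is", return_tensors="pt").to("hpu")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Run the script with the environment flags shown in step 3 above.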