Fused Optimizers and Custom Ops for Intel Gaudi

The Intel® Gaudi® AI accelerator provides its own implementations of complex PyTorch ops, customized for Gaudi devices. Replacing the stock ops with these Gaudi versions can improve model performance.

Fused Optimizers

The following fused optimizers are supported:

  • FusedAdagrad

  • FusedAdamW

  • FusedEMA

  • FusedLamb

  • FusedSGD

  • Functional FusedAdamW

  • FusedLars

The following example demonstrates optimizer usage with the FusedAdagrad optimizer:

import torch
import torch.nn as nn
from habana_frameworks.torch.hpex.optimizers import FusedAdagrad

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# Define the optimizer with all parameters
optimizer = FusedAdagrad(
    model.parameters(),
    lr=0.01,
    lr_decay=0.001,
    weight_decay=0.01,
    initial_accumulator_value=0.1,
    eps=1e-10
)

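# Dummy input and target (named to avoid shadowing the built-in input())
inputs = torch.randn(32, 10)
targets = torch.randn(32, 1)

# Define a loss function (kept separate from the computed loss value)
criterion = nn.MSELoss()

# Forward pass
outputs = model(inputs)
loss = criterion(outputs, targets)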

# Zero the parameter gradients and call backward
optimizer.zero_grad()
loss.backward()

# Perform a single optimization step
optimizer.step()

FusedAdagrad

FusedAdagrad is a fused implementation of the Adagrad optimizer for Gaudi devices. Refer to the original PyTorch op documentation - torch.optim.Adagrad:

class habana_frameworks.torch.hpex.optimizers.FusedAdagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)
Parameters
  • params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.

  • lr – (float, optional) - Learning rate (default: 1e-2).

  • lr_decay – (float, optional) - Learning rate decay (default: 0).

  • weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0).

  • initial_accumulator_value – (float, optional) - Initial accumulator value (default: 0).

  • eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-10).

FusedAdamW

FusedAdamW is a fused implementation of the AdamW optimizer for Gaudi devices. For the original implementation, refer to the Hugging Face AdamW optimizer documentation:

class habana_frameworks.torch.hpex.optimizers.FusedAdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, bias_correction=True, moments_dtype=None)
Parameters
  • params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.

  • lr – (float, optional) - Learning rate (default: 1e-3).

  • betas – (Tuple[float, float], optional) - Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)).

  • eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-6).

  • weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0.0).

  • bias_correction – (bool, optional) - Whether to use bias correction (default: True).

  • moments_dtype – (Optional[Union[torch.dtype, Tuple[torch.dtype, torch.dtype]]], optional) - Data type for moments (default: None).
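The params argument accepts either a flat iterable of parameters or a list of dicts defining parameter groups. The following is a minimal sketch of the second form, assuming FusedAdamW follows the standard torch.optim.Optimizer parameter-group convention. The split of biases into a no-decay group is an illustrative choice, and the fallback to torch.optim.AdamW on CPU is an assumption added only so the snippet also runs on a machine without the Gaudi software stack:

```python
import torch
import torch.nn as nn

try:
    # Gaudi path: importing the core module makes the "hpu" device available.
    import habana_frameworks.torch.core as htcore  # noqa: F401
    from habana_frameworks.torch.hpex.optimizers import FusedAdamW as AdamWImpl
    device = torch.device("hpu")
except ImportError:
    # Assumption for illustration only: stock AdamW on CPU.
    from torch.optim import AdamW as AdamWImpl
    device = torch.device("cpu")

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1)).to(device)

# Two parameter groups: weights get weight decay, biases do not.
decay = [p for n, p in model.named_parameters() if not n.endswith("bias")]
no_decay = [p for n, p in model.named_parameters() if n.endswith("bias")]

optimizer = AdamWImpl(
    [
        {"params": decay, "weight_decay": 0.01},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=1e-3,
    betas=(0.9, 0.999),
    eps=1e-6,
)
```

Options given inside a group dict override the keyword defaults for that group only.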

FusedEMA

FusedEMA is a fused implementation of the EMA optimizer for Gaudi devices. Refer to the original PyTorch op documentation - torch.optim.swa_utils.AveragedModel:

class habana_frameworks.torch.hpex.movingavrg.FusedEMA(model, decay=0.9999, updates=0)
Parameters
  • model – (nn.Module) - Model to use with EMA

  • decay – (float, optional) - Decay parameter. It scales the exponential ramp ‘decay * (1 - exp(-x / 2000))’, where x is the update count (default: 0.9999).

  • updates – (float, optional) - Update counter, incremented by 1 on every update. It is the input ‘x’ to the exponential function above (default: 0).
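To see how decay and updates interact, the ramp from the parameter description can be evaluated directly. This is a plain-Python sketch of the formula only, not of the FusedEMA kernel; the function name effective_decay and the argument name tau (for the constant 2000) are hypothetical:

```python
import math


def effective_decay(updates, decay=0.9999, tau=2000.0):
    """Evaluate decay * (1 - exp(-x / 2000)) for x = updates."""
    return decay * (1.0 - math.exp(-updates / tau))


# Print the ramp at a few update counts.
for x in (0, 100, 2000, 20000):
    print(f"updates={x:>6}  effective decay={effective_decay(x):.4f}")
```

The value starts at 0 for updates=0 and approaches decay as the counter grows, so the ramp is smallest at the start of training.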

FusedLamb

FusedLamb implements a version of the LAMB optimizer customized for Gaudi devices. LAMB was proposed in the paper Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:

class habana_frameworks.torch.hpex.optimizers.FusedLamb(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, amsgrad=False, adam_w_mode=True, grad_averaging=True, set_grad_none=True, max_grad_norm=1.0, use_lamb=False, fused=False, dtype=None)
Parameters
  • params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.

  • lr – (float, optional) - Learning rate (default: 1e-3).

  • bias_correction – (bool, optional) - Whether to use bias correction (default: True).

  • betas – (Tuple[float, float], optional) - Coefficients used for computing running averages of gradient and its norm (default: (0.9, 0.999)).

  • eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-6).

  • weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0).

  • amsgrad – (bool, optional) - Whether to use the AMSGrad variant of this algorithm (default: False).

  • adam_w_mode – (bool, optional) - Apply L2 regularization or weight decay. True for decoupled weight decay (also known as AdamW) (default: True).

  • grad_averaging – (bool, optional) - Whether to apply (1-beta2) to grad when calculating running averages of gradient (default: True).

  • set_grad_none – (bool, optional) - Whether to set grad to None when zero_grad() method is called (default: True).

  • max_grad_norm – (float, optional) - Value used to clip global grad norm. If set to None, prior gradient clipping is disabled (default: 1.0).

  • use_lamb – (bool, optional) - When set to True, force calculation of trust_ratio used for gradient update when weight_decay is 0 (default: False).

  • dtype – (torch.dtype, optional) - The desired data type of the parameters (grads). If specified, the parameters will be cast to this data type. (default: None)

FusedSGD

FusedSGD is a fused implementation of the SGD optimizer for Gaudi devices. Refer to the original PyTorch op documentation - torch.optim.SGD:

class habana_frameworks.torch.hpex.optimizers.FusedSGD(params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False)
Parameters
  • params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.

  • lr – (float) - Learning rate.

  • momentum – (float, optional) - Momentum factor (default: 0).

  • dampening – (float, optional) - Dampening for momentum (default: 0).

  • weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0).

  • nesterov – (bool, optional) - Enables Nesterov momentum (default: False).

Functional FusedAdamW

Functional FusedAdamW is a functional implementation of the AdamW optimizer for Gaudi devices. This functional version of FusedAdamW is based on torch.distributed.optim._FunctionalAdamW. It can be enabled with habana_frameworks.torch.hpex.optimizers.distributed.FusedAdamW:

class habana_frameworks.torch.hpex.optimizers.distributed.FusedAdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, _allow_empty_param_list=False, moments_dtype=None)
Parameters
  • params – (List[torch.Tensor]) - List of parameters to optimize or dicts defining parameter groups.

  • lr – (float, optional) - Learning rate (default: 1e-3).

  • betas – (Tuple[float, float], optional) - Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)).

  • eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-6).

  • weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0.0).

  • _allow_empty_param_list – (bool, optional) - Retained for PyTorch compatibility (default: False).

  • moments_dtype – (Optional[Union[torch.dtype, Tuple[torch.dtype, torch.dtype]]], optional) - Data type for moments (default: None).

FusedLars

FusedLars is a fused implementation of the LARS optimizer for Gaudi devices. For more details, refer to the LARS optimizer paper:

class habana_frameworks.torch.hpex.optimizers.FusedLars(optimizer, skip_mask, eeta=0.001, eps=1e-08)
Parameters
  • optimizer – (torch.optim.Optimizer) - The base optimizer to be wrapped by FusedLars.

  • skip_mask – (torch.Tensor) - Mask to skip certain layers from LARS scaling.

  • eeta – (float, optional) - LARS coefficient as used in the paper (default: 0.001).

  • eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-8).
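For orientation, the layer-wise trust ratio described in the LARS paper can be written in a few lines of stock PyTorch. This is a sketch of the published formula, eeta * ||w|| / (||g|| + weight_decay * ||w|| + eps), and not the FusedLars kernel; the function name lars_trust_ratio, the weight_decay argument, and the fallback ratio of 1.0 for zero norms are assumptions made for illustration:

```python
import torch


def lars_trust_ratio(weight, grad, eeta=0.001, eps=1e-8, weight_decay=0.0):
    """Layer-wise LARS scaling factor for one parameter tensor."""
    w_norm = weight.norm(2)
    g_norm = grad.norm(2)
    # Assumed convention: leave the step unscaled when either norm is zero.
    if w_norm == 0 or g_norm == 0:
        return torch.tensor(1.0)
    return eeta * w_norm / (g_norm + weight_decay * w_norm + eps)


weight = torch.ones(4)          # L2 norm is 2
grad = torch.full((4,), 0.5)    # L2 norm is 1
ratio = lars_trust_ratio(weight, grad)
```

Layers excluded from LARS scaling (the role of skip_mask above) would simply not have this factor applied.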

Note

For models using Lazy mode execution, mark_step() must be added right after loss.backward() and optimizer.step().
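The placement described in the note can be sketched as a single training step. The snippet assumes mark_step is taken from habana_frameworks.torch.core and uses FusedSGD as the optimizer. The fallback to torch.optim.SGD on CPU with a no-op mark_step is an assumption added only so the snippet also runs without the Gaudi software stack:

```python
import torch
import torch.nn as nn

try:
    import habana_frameworks.torch.core as htcore
    from habana_frameworks.torch.hpex.optimizers import FusedSGD as SGDImpl
    device = torch.device("hpu")
    mark_step = htcore.mark_step
except ImportError:
    # Assumption for illustration only: stock SGD on CPU, mark_step does nothing.
    from torch.optim import SGD as SGDImpl
    device = torch.device("cpu")

    def mark_step():
        pass

torch.manual_seed(0)
model = nn.Linear(10, 1).to(device)
optimizer = SGDImpl(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.MSELoss()

inputs = torch.randn(32, 10).to(device)
targets = torch.randn(32, 1).to(device)
weight_before = model.weight.detach().clone()

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
mark_step()  # right after loss.backward()
optimizer.step()
mark_step()  # right after optimizer.step()
```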

Custom Ops

FusedClipNorm

A class to perform gradient clipping by norm for parameters on Habana devices. Only norm_type 2.0 is supported. Refer to torch.nn.utils.clip_grad_norm_ for more details.

class habana_frameworks.torch.hpex.normalization.FusedClipNorm(parameters: Iterable[torch.nn.parameter.Parameter], max_norm: float)
Parameters
  • parameters – (Iterable[torch.nn.parameter.Parameter]) - The parameters whose gradients will be clipped.

  • max_norm – (float) - The maximum norm value to clip gradients to.

clip_norm(parameters: Union[Iterable[torch.nn.parameter.Parameter], torch.Tensor])

Clips the gradients of the given parameters to the maximum norm value.

Parameters

parameters – Union[Iterable[torch.nn.parameter.Parameter], torch.Tensor] - The parameters whose gradients will be clipped. If a single tensor is provided, its gradient will be clipped.

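try:
    from habana_frameworks.torch.hpex.normalization import FusedClipNorm
except ImportError:
    raise ImportError("Please install habana_torch package")

# model and args.max_grad_norm come from the surrounding training script
FusedNorm = FusedClipNorm(model.parameters(), args.max_grad_norm)

# Call after loss.backward() and before optimizer.step()
FusedNorm.clip_norm(model.parameters())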
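As a point of comparison, clipping by L2 norm can be reproduced on any device with torch.nn.utils.clip_grad_norm_, the stock utility referenced above. The sketch below shows the effect on the global gradient norm. It does not call FusedClipNorm, and whether the fused op matches it numerically is not verified here. The helper grad_l2_norm is hypothetical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 1)
max_norm = 1.0


def grad_l2_norm(parameters):
    """Global L2 norm over all gradients, as one Python float."""
    grads = [p.grad.detach().flatten() for p in parameters if p.grad is not None]
    return torch.cat(grads).norm(2).item()


# Scale the loss so the gradient norm ends up far above max_norm.
loss = (100.0 * model(torch.randn(32, 10))).pow(2).sum()
loss.backward()

norm_before = grad_l2_norm(model.parameters())
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, norm_type=2.0)
norm_after = grad_l2_norm(model.parameters())
```

After the call, all gradients are rescaled by the same factor so that their combined L2 norm does not exceed max_norm.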

Mixture of Experts Forward (MoE)

This custom op is designed to replace the MoE block in Mixtral/Llama models. It matches results from the Hugging Face reference implementation, with more tokens included. Because individual elements may diverge, the custom op was tested with cosine similarity instead of elementwise comparison. Refer to Hugging Face modeling_mixtral.py for more details.
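A cosine-similarity check of this kind can be sketched as follows. The function name cosine_match and the 0.99 threshold are hypothetical choices for illustration, not the values used to validate the op:

```python
import torch
import torch.nn.functional as F


def cosine_match(result, reference, threshold=0.99):
    """Compare two tensors by the cosine similarity of their flattened values."""
    sim = F.cosine_similarity(
        result.flatten().float(), reference.flatten().float(), dim=0
    ).item()
    return sim >= threshold, sim


torch.manual_seed(0)
reference = torch.randn(32, 64)
# Simulate small elementwise deviations between two implementations.
result = reference + 1e-3 * torch.randn(32, 64)
ok, sim = cosine_match(result, reference)
```

Unlike an elementwise comparison with a tight tolerance, this check tolerates small deviations in individual elements as long as the overall direction of the output is preserved.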

In Eager mode, this op is lowered into basic ops, which are used in the original Hugging Face implementation. In Lazy and torch.compile modes, it utilizes a dedicated kernel optimized for Gaudi. This design ensures that computational resources are used efficiently, leading to faster execution times and improved overall performance. The following data types are supported: float32, float16 and bfloat16.

class mixture_of_experts(hidden_states, expert_routing_table, router_weights, w1, w2, w3, permuted_weights, activation, experts_min, experts_max)
Input hidden_states

(Tensor) - The input tensor containing the hidden states that will be processed by the experts.

Input expert_routing_table

(Tensor) - A tensor that maps each input to the corresponding experts that will process it.

Input router_weights

(Tensor) - Weights used by the router to determine the routing probabilities for each expert.

Input w1

(TensorList) - Expert weights for the first matrix multiplication operation. TensorList size is equal to the number of experts.

Input w2

(TensorList, optional) - Expert weights for the second matrix multiplication operation. This can be concatenated with w1 for compatibility with different frameworks. TensorList size is equal to the number of experts.

Input w3

(TensorList) - Expert weights for the third matrix multiplication operation. TensorList size is equal to the number of experts.

Parameters
  • permuted_weights – (bool) - flag used to specify if expert weights are already permuted.

  • activation – (c10::string_view) - activation function used in MoE block. Supported activations are gelu, relu and silu.

  • experts_min – (int64_t) - index of the first expert handled by this device. Used for expert parallelism across devices: together with experts_max, it specifies the subset of experts each device is responsible for.

  • experts_max – (int64_t) - index of the last expert handled by this device. The range is inclusive: a single device handling all experts passes 0 and num_experts - 1.

Output

(Tensor) - The output of MoE block. Shape is equal to input shape.

import torch
import torch.nn.functional as F
import habana_frameworks.torch.core as htcore  # loads the Gaudi PyTorch bridge ("hpu" device)

hpu = torch.device("hpu")

dtype = torch.float
activation = "gelu"
hidden_dim = 64
ffn_dim = 224
num_experts = 8
num_tokens = 32
fused_weights = False
permuted_weights = False
k = 2

# Random hidden states and top-k routing
hidden_states = torch.randn((num_tokens, hidden_dim), dtype=dtype)
score = torch.randn((num_tokens, num_experts), dtype=torch.float32)
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
router_weights, expert_routing_table = torch.topk(routing_weights,
                                                  k,
                                                  dim=-1)
router_weights /= router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(dtype=dtype)

# One weight tensor per expert
w1 = [torch.randn((hidden_dim, ffn_dim), dtype=dtype).to(hpu) for _ in range(num_experts)]
w2 = [torch.randn((hidden_dim, ffn_dim), dtype=dtype).to(hpu) for _ in range(num_experts)]
w3 = [torch.randn((ffn_dim, hidden_dim), dtype=dtype).to(hpu) for _ in range(num_experts)]

result = torch.ops.hpu.mixture_of_experts(
    hidden_states.to(hpu),
    expert_routing_table.to(hpu),
    router_weights.to(hpu),
    w1,
    w2,
    w3,
    permuted_weights,
    activation,
    0,
    num_experts - 1,
)

The following variant passes a combined w12 TensorList in place of separate w1 and w2:

# Reuses the imports, hpu device, and routing tensors from the example above
w12 = [torch.randn((hidden_dim, ffn_dim), dtype=dtype).to(hpu) for _ in range(num_experts)]
w3 = [torch.randn((ffn_dim, 2 * hidden_dim), dtype=dtype).to(hpu) for _ in range(num_experts)]

result = torch.ops.hpu.mixture_of_experts(
    hidden_states.to(hpu),
    expert_routing_table.to(hpu),
    router_weights.to(hpu),
    w12,
    w3,
    permuted_weights,
    activation,
    0,
    num_experts - 1,
)
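For readers who want something to compare the op against, a Mixtral-style MoE block can be written with basic PyTorch ops. This is a reference sketch under stated assumptions, not the Gaudi kernel. The names moe_reference, gate, up and down are hypothetical, and mapping them to w1, w2 and w3 (in that order) is an assumption based only on the tensor shapes in the first example above:

```python
import torch
import torch.nn.functional as F

ACTIVATIONS = {"gelu": F.gelu, "relu": F.relu, "silu": F.silu}


def moe_reference(hidden_states, expert_routing_table, router_weights,
                  gate, up, down, activation="silu"):
    """Weighted sum of the selected experts' outputs for every token."""
    act = ACTIVATIONS[activation]
    out = torch.zeros_like(hidden_states)
    for expert, (w_gate, w_up, w_down) in enumerate(zip(gate, up, down)):
        # Tokens routed to this expert, and the top-k slot that selected it.
        token_idx, slot_idx = torch.where(expert_routing_table == expert)
        if token_idx.numel() == 0:
            continue
        x = hidden_states[token_idx]
        y = (act(x @ w_gate) * (x @ w_up)) @ w_down
        scale = router_weights[token_idx, slot_idx].unsqueeze(-1)
        out.index_add_(0, token_idx, y * scale)
    return out


torch.manual_seed(0)
num_tokens, hidden_dim, ffn_dim, num_experts, k = 32, 64, 224, 8, 2
hidden_states = torch.randn(num_tokens, hidden_dim)
routing = F.softmax(torch.randn(num_tokens, num_experts), dim=1)
router_weights, expert_routing_table = torch.topk(routing, k, dim=-1)
router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
gate = [torch.randn(hidden_dim, ffn_dim) for _ in range(num_experts)]
up = [torch.randn(hidden_dim, ffn_dim) for _ in range(num_experts)]
down = [torch.randn(ffn_dim, hidden_dim) for _ in range(num_experts)]

output = moe_reference(hidden_states, expert_routing_table, router_weights,
                       gate, up, down, activation="gelu")
```

The output has the same shape as hidden_states, consistent with the Output description above.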