Fused Optimizers and Custom Ops for Intel Gaudi¶
The Intel® Gaudi® AI accelerator provides its own implementation of complex PyTorch ops customized for Gaudi devices. Replacing these complex ops with custom Gaudi versions enhances model performance.
Fused Optimizers¶
The following fused optimizers are supported:
FusedAdagrad
FusedAdamW
FusedEMA
FusedLamb
FusedSGD
Functional FusedAdamW
FusedLars
Following is an example demonstrating optimizer usage with the FusedAdagrad optimizer:
import torch
import torch.nn as nn
from habana_frameworks.torch.hpex.optimizers import FusedAdagrad

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# Define the optimizer with all parameters
optimizer = FusedAdagrad(
    model.parameters(),
    lr=0.01,
    lr_decay=0.001,
    weight_decay=0.01,
    initial_accumulator_value=0.1,
    eps=1e-10
)

# Dummy input and target
inputs = torch.randn(32, 10)
target = torch.randn(32, 1)

# Define a loss function (kept separate from the computed loss value)
loss_fn = nn.MSELoss()

# Forward pass
output = model(inputs)
loss = loss_fn(output, target)

# Zero the parameter gradients and call backward
optimizer.zero_grad()
loss.backward()

# Perform a single optimization step
optimizer.step()
FusedAdagrad¶
FusedAdagrad is a fused implementation of the Adagrad optimizer for Gaudi devices. Refer to the original PyTorch op documentation - torch.optim.Adagrad:
class habana_frameworks.torch.hpex.optimizers.FusedAdagrad(params, lr=0.01, lr_decay=0, weight_decay=0, initial_accumulator_value=0, eps=1e-10)¶
Parameters
params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.
lr – (float, optional) - Learning rate (default: 1e-2).
lr_decay – (float, optional) - Learning rate decay (default: 0).
weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0).
initial_accumulator_value – (float, optional) - Initial accumulator value (default: 0).
eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-10).
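To make the roles of these parameters concrete, the following is a minimal scalar sketch of one Adagrad update, written after the formulas used by torch.optim.Adagrad. It is an illustration only: the helper name adagrad_step is hypothetical, and the fused Gaudi kernel is assumed to follow the same arithmetic but is not shown here.

```python
import math

def adagrad_step(param, grad, state_sum, step, lr=0.01, lr_decay=0.0,
                 weight_decay=0.0, eps=1e-10):
    """One scalar Adagrad update (hypothetical helper, 1-based `step`).

    Returns the new (param, state_sum) pair.
    """
    # weight_decay: L2 penalty folded into the gradient.
    grad = grad + weight_decay * param
    # lr_decay: the effective learning rate shrinks with the step count.
    clr = lr / (1.0 + (step - 1) * lr_decay)
    # state_sum starts at initial_accumulator_value and accumulates grad^2.
    state_sum = state_sum + grad * grad
    # eps keeps the denominator away from zero.
    param = param - clr * grad / (math.sqrt(state_sum) + eps)
    return param, state_sum

# Usage: two steps on one weight, accumulator initialised to 0.1.
w, acc = 1.0, 0.1
w, acc = adagrad_step(w, grad=0.5, state_sum=acc, step=1)
w, acc = adagrad_step(w, grad=0.5, state_sum=acc, step=2, lr_decay=0.001)
```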
FusedAdamW¶
FusedAdamW is a fused implementation of the AdamW optimizer for Gaudi devices. Refer to the Hugging Face AdamW optimizer documentation:
class habana_frameworks.torch.hpex.optimizers.FusedAdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, bias_correction=True, moments_dtype=None)¶
Parameters
params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.
lr – (float, optional) - Learning rate (default: 1e-3).
betas – (Tuple[float, float], optional) - Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)).
eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-6).
weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0.0).
bias_correction – (bool, optional) - Whether to use bias correction (default: True).
moments_dtype – (Optional[Union[torch.dtype, Tuple[torch.dtype, torch.dtype]]], optional) - Data type for moments (default: None).
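As a rough guide to how betas, eps, weight_decay and bias_correction interact, here is a scalar sketch of one AdamW step modelled on the Hugging Face reference optimizer. The helper name adamw_step is hypothetical, moments_dtype is not modelled, and the fused Gaudi kernel may differ in detail.

```python
import math

def adamw_step(param, grad, exp_avg, exp_avg_sq, step, lr=1e-3,
               betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0,
               bias_correction=True):
    """One scalar AdamW update (hypothetical helper, 1-based `step`)."""
    beta1, beta2 = betas
    # Running averages of the gradient and of its square.
    exp_avg = beta1 * exp_avg + (1.0 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1.0 - beta2) * grad * grad
    step_size = lr
    if bias_correction:
        # Compensate for the zero-initialised moments in early steps.
        step_size = lr * math.sqrt(1.0 - beta2 ** step) / (1.0 - beta1 ** step)
    param = param - step_size * exp_avg / (math.sqrt(exp_avg_sq) + eps)
    # Decoupled weight decay: applied to the weight, not added to the gradient.
    param = param - lr * weight_decay * param
    return param, exp_avg, exp_avg_sq

# Usage: first step from zero-initialised moments.
w, m, v = adamw_step(1.0, grad=1.0, exp_avg=0.0, exp_avg_sq=0.0, step=1)
```

With bias correction enabled, the first step moves the weight by roughly lr regardless of the gradient's magnitude, which is why the moments start from zero without stalling training.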
FusedEMA¶
FusedEMA is a fused implementation of the EMA optimizer for Gaudi devices. Refer to the original PyTorch op documentation - torch.optim.swa_utils.AveragedModel:
class habana_frameworks.torch.hpex.movingavrg.FusedEMA(model, decay=0.9999, updates=0)¶
Parameters
model – (nn.Module) - Model to use with EMA
decay – (float, optional) - Decay parameter. The effective decay is scaled by the exponential function ‘decay * (1 - exp(-x / 2000))’ (default: 0.9999).
updates – (float, optional) - Counter incremented by 1 on every update; it is the input ‘x’ to the exponential function above (default: 0).
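The decay schedule above can be explored numerically with a short sketch. The function ema_decay implements the documented expression; ema_update additionally assumes the standard EMA blend (ema = d * ema + (1 - d) * weight), which the text above does not spell out, so treat that part as an assumption. Both names are hypothetical.

```python
import math

def ema_decay(updates, decay=0.9999):
    """Effective decay after `updates` steps: decay * (1 - exp(-x / 2000))."""
    return decay * (1.0 - math.exp(-updates / 2000.0))

def ema_update(ema_weights, model_weights, updates, decay=0.9999):
    """Blend model weights into the EMA copy (assumed standard EMA rule)."""
    updates += 1  # the counter is incremented once per update
    d = ema_decay(updates, decay)
    blended = [d * e + (1.0 - d) * m for e, m in zip(ema_weights, model_weights)]
    return blended, updates

# Usage: early updates track the model closely because d starts near zero.
ema, n = ema_update([0.0, 0.0], [1.0, 2.0], updates=0)
```

The ramp means the averaged weights follow the live model almost exactly at the start of training and only approach the configured decay after several thousand updates.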
FusedLamb¶
FusedLamb implements a version of the LAMB optimizer customized for Gaudi devices. LAMB was proposed in Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
class habana_frameworks.torch.hpex.optimizers.FusedLamb(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, amsgrad=False, adam_w_mode=True, grad_averaging=True, set_grad_none=True, max_grad_norm=1.0, use_lamb=False, fused=False, dtype=None)¶
Parameters
params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.
lr – (float, optional) - Learning rate (default: 1e-3).
bias_correction – (bool, optional) - Whether to use bias correction (default: True).
betas – (Tuple[float, float], optional) - Coefficients used for computing running averages of gradient and its norm (default: (0.9, 0.999)).
eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-6).
weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0).
amsgrad – (bool, optional) - Whether to use the AMSGrad variant of this algorithm (default: False).
adam_w_mode – (bool, optional) - Apply L2 regularization or weight decay. True for decoupled weight decay (also known as AdamW) (default: True).
grad_averaging – (bool, optional) - Whether to apply (1-beta2) to grad when calculating running averages of gradient (default: True).
set_grad_none – (bool, optional) - Whether to set grad to None when zero_grad() method is called (default: True).
max_grad_norm – (float, optional) - Value used to clip global grad norm. If set to None, prior gradient clipping is disabled (default: 1.0).
use_lamb – (bool, optional) - When set to True, force calculation of trust_ratio used for gradient update when weight_decay is 0 (default: False).
dtype – (torch.dtype, optional) - The desired data type of the parameters (grads). If specified, the parameters will be cast to this data type. (default: None)
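The trust ratio mentioned for use_lamb can be sketched as follows. This follows the layer-wise ratio from the LAMB paper; the fallback to 1.0 when weight_decay is 0 and use_lamb is False is an assumption drawn from the parameter description above, and the helper names are hypothetical rather than part of the Gaudi API.

```python
import math

def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))

def lamb_trust_ratio(weights, update, weight_decay=0.0, use_lamb=False):
    """Layer-wise trust ratio ||w|| / ||update|| (hypothetical helper)."""
    # Assumption: with weight_decay == 0 the ratio is only computed when
    # use_lamb forces it; otherwise the plain update is used (ratio 1.0).
    if weight_decay == 0.0 and not use_lamb:
        return 1.0
    w_norm, u_norm = l2_norm(weights), l2_norm(update)
    if w_norm == 0.0 or u_norm == 0.0:
        return 1.0  # avoid dividing by zero for all-zero layers
    return w_norm / u_norm

def lamb_apply(weights, update, lr=1e-3, **kwargs):
    """Apply an already-computed Adam-style update scaled by the trust ratio."""
    ratio = lamb_trust_ratio(weights, update, **kwargs)
    return [w - lr * ratio * u for w, u in zip(weights, update)]

# Usage: a layer with ||w|| = 5 and ||update|| = 1 gets a 5x larger step.
new_w = lamb_apply([3.0, 4.0], [0.6, 0.8], lr=0.1, use_lamb=True)
```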
FusedSGD¶
FusedSGD is a fused implementation of the SGD optimizer for Gaudi devices. Refer to the original PyTorch op documentation - torch.optim.SGD:
class habana_frameworks.torch.hpex.optimizers.FusedSGD(params, lr, momentum=0, dampening=0, weight_decay=0, nesterov=False)¶
Parameters
params – (iterable) - Iterable of parameters to optimize or dicts defining parameter groups.
lr – (float) - Learning rate.
momentum – (float, optional) - Momentum factor (default: 0).
weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0).
dampening – (float, optional) - Dampening for momentum (default: 0).
nesterov – (bool, optional) - Enables Nesterov momentum (default: False).
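The interplay of momentum, dampening and nesterov is easiest to see in a scalar sketch. The version below mirrors the update rule documented for torch.optim.SGD; sgd_step is a hypothetical helper, and the fused Gaudi kernel is assumed, not confirmed, to compute the same result.

```python
def sgd_step(param, grad, buf, lr, momentum=0.0, dampening=0.0,
             weight_decay=0.0, nesterov=False):
    """One scalar SGD update; `buf` is the momentum buffer (None at first)."""
    # weight_decay: L2 penalty folded into the gradient.
    grad = grad + weight_decay * param
    if momentum != 0.0:
        if buf is None:
            buf = grad  # first step: the buffer is seeded with the gradient
        else:
            # dampening reduces how much of the new gradient enters the buffer
            buf = momentum * buf + (1.0 - dampening) * grad
        # Nesterov looks ahead along the momentum direction.
        grad = grad + momentum * buf if nesterov else buf
    return param - lr * grad, buf

# Usage: two momentum steps with a constant gradient.
w, b = sgd_step(1.0, 0.5, None, lr=0.1, momentum=0.9)
w, b = sgd_step(w, 0.5, b, lr=0.1, momentum=0.9)
```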
Functional FusedAdamW¶
Functional FusedAdamW is a functional implementation of the AdamW optimizer for Gaudi devices. It is based on torch.distributed.optim._FunctionalAdamW and can be enabled with habana_frameworks.torch.hpex.optimizers.distributed.FusedAdamW:
class habana_frameworks.torch.hpex.optimizers.distributed.FusedAdamW(params, lr=0.001, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.0, _allow_empty_param_list=False, moments_dtype=None)¶
Parameters
params – (List[torch.Tensor]) - List of parameters to optimize or dicts defining parameter groups.
lr – (float, optional) - Learning rate (default: 1e-3).
betas – (Tuple[float, float], optional) - Coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)).
eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-6).
weight_decay – (float, optional) - Weight decay (L2 penalty) (default: 0.0).
_allow_empty_param_list – (bool, optional) - Retained for PyTorch compatibility (default: False).
moments_dtype – (Optional[Union[torch.dtype, Tuple[torch.dtype, torch.dtype]]], optional) - Data type for moments (default: None).
FusedLars¶
FusedLars is a fused implementation of the LARS optimizer for Gaudi devices. For more details, refer to the LARS optimizer paper:
class habana_frameworks.torch.hpex.optimizers.FusedLars(optimizer, skip_mask, eeta=0.001, eps=1e-08)¶
Parameters
optimizer – (torch.optim.Optimizer) - The base optimizer to be wrapped by FusedLars.
skip_mask – (torch.Tensor) - Mask to skip certain layers from LARS scaling.
eeta – (float, optional) - LARS coefficient as used in the paper (default: 0.001).
eps – (float, optional) - Term added to the denominator to improve numerical stability (default: 1e-8).
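For orientation, the layer-wise scaling that LARS applies can be sketched in plain Python. The formula below is the commonly used form of the paper's trust ratio; the weight_decay term, the skip flag, and the helper names are assumptions made for illustration, and the text above does not specify how skip_mask encodes skipped layers.

```python
import math

def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))

def lars_trust_ratio(weights, grads, eeta=0.001, eps=1e-8, weight_decay=0.0):
    """eeta * ||w|| / (||g|| + weight_decay * ||w|| + eps) for one layer."""
    w_norm, g_norm = l2_norm(weights), l2_norm(grads)
    if w_norm == 0.0 or g_norm == 0.0:
        return 1.0  # leave all-zero layers unscaled
    return eeta * w_norm / (g_norm + weight_decay * w_norm + eps)

def lars_scale_grads(weights, grads, skip=False, **kwargs):
    """Scale a layer's gradients unless the layer is excluded from LARS.

    `skip` stands in for one entry of skip_mask (hypothetical encoding).
    """
    if skip:
        return list(grads)
    ratio = lars_trust_ratio(weights, grads, **kwargs)
    return [ratio * g for g in grads]

# Usage: a layer with ||w|| = 5 and ||g|| = 1 is scaled by about 0.005.
scaled = lars_scale_grads([3.0, 4.0], [0.6, 0.8])
```

The scaled gradients would then be handed to the wrapped base optimizer, which applies its own learning rate and momentum.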
Note
For models using Lazy mode execution, mark_step() must be added right after loss.backward() and optimizer.step().
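The placement described in the note can be sketched as a training-step helper. To keep the sketch runnable without Gaudi hardware, mark_step is passed in as a callable; on a Gaudi device it would typically be htcore.mark_step from habana_frameworks.torch.core. The train_step name and its argument list are hypothetical.

```python
def train_step(model, optimizer, loss_fn, inputs, target, mark_step):
    """Run one Lazy mode training step with mark_step() in both places.

    `mark_step` is any zero-argument callable; on Gaudi, pass
    htcore.mark_step (import habana_frameworks.torch.core as htcore).
    """
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), target)
    loss.backward()
    mark_step()  # right after loss.backward()
    optimizer.step()
    mark_step()  # right after optimizer.step()
    return loss
```

Passing the callable explicitly also makes it easy to substitute a no-op (lambda: None) when the same training loop runs in a mode that does not need it.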
Custom Ops¶
FusedClipNorm¶
A class to perform gradient clipping by norm for parameters on Habana devices. Only norm_type 2.0 is supported. Refer to torch.nn.utils.clip_grad_norm_ for more details.
class habana_frameworks.torch.hpex.normalization.FusedClipNorm(parameters: Iterable[torch.nn.parameter.Parameter], max_norm: float)¶
Parameters
parameters – (Iterable[torch.nn.parameter.Parameter]) - The parameters whose gradients will be clipped.
max_norm – (float) - The maximum norm value to clip gradients to.
clip_norm(parameters: Union[Iterable[torch.nn.parameter.Parameter], torch.Tensor])¶
Clips the gradients of the given parameters to the maximum norm value.
Parameters
parameters – (Union[Iterable[torch.nn.parameter.Parameter], torch.Tensor]) - The parameters whose gradients will be clipped. If a single tensor is provided, its gradient will be clipped.
try:
    from habana_frameworks.torch.hpex.normalization import FusedClipNorm
except ImportError:
    raise ImportError("Please install habana_torch package")

# `model` and `args.max_grad_norm` come from the surrounding training script
fused_norm = FusedClipNorm(model.parameters(), args.max_grad_norm)
fused_norm.clip_norm(model.parameters())
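What clipping by the 2-norm does numerically can be shown with a small pure-Python sketch. It follows the behaviour documented for torch.nn.utils.clip_grad_norm_ with norm_type 2.0; the helper name and the list-of-lists gradient layout are illustrative assumptions, not the FusedClipNorm implementation.

```python
import math

def clip_grad_norm_l2(grads, max_norm, eps=1e-6):
    """Clip per-layer gradient lists by their global L2 norm.

    Returns (clipped_grads, total_norm_before_clipping).
    """
    # Global norm over every gradient element of every layer.
    total_norm = math.sqrt(sum(g * g for layer in grads for g in layer))
    # Scale down only when the norm exceeds max_norm; never scale up.
    coef = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * coef for g in layer] for layer in grads]
    return clipped, total_norm

# Usage: two layers whose combined gradient norm is 5.0, clipped to 1.0.
clipped, norm_before = clip_grad_norm_l2([[3.0], [4.0]], max_norm=1.0)
```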
Mixture of Experts Forward (MoE)¶
Note
This feature is supported on Gaudi 3 and Gaudi 2 only.
This custom op is designed to replace the MoE block in Mixtral and LLaMA models. In Eager mode, this op is lowered into the basic ops used in the original Hugging Face implementation. Refer to Hugging Face modeling_mixtral.py for more details. In Lazy and torch.compile modes, it uses a dedicated kernel optimized for Gaudi, which reduces execution time compared with running the individual ops. The following data types are supported: float32, float16 and bfloat16.
The op supports the following:
MLP/Feed Forward structure of Mixtral and LLaMA, replacing the expert blocks along with the subsequent weighted sum in the case where a token is sent to multiple experts.
Both the HuggingFace flavor of the MLP, where all 3 GEMM operations are separate, as well as the vLLM use case, where the first two GEMM operations are fused.
Each expert weight numbering is based on the order of the corresponding Linear operations and differs from the Mixtral HF/vLLM numbering. This means that for HF Mixtral, w2 and w3 are swapped, and for vLLM Mixtral, w13 is replaced by w12. In the HF LLaMA notation, w1 corresponds to gate_proj, w2 corresponds to up_proj, and w3 corresponds to down_proj.
class mixture_of_experts(hidden_states, expert_routing_table, router_weights, w1, w2, w12, w3, permuted_weights, activation, experts_min, experts_max)¶
Input hidden_states – (Tensor) - The input tensor containing the hidden states that will be processed by the experts.
Input expert_routing_table – (Tensor) - A tensor that maps each input to the corresponding experts that will process it.
Input router_weights – (Tensor) - Weights used by the router to determine the routing probabilities for each expert.
Input w1 – (TensorList, optional) - Expert weights for the first matrix multiplication operation for the non-fused GEMM flavor. TensorList size is equal to the number of experts.
Input w2 – (TensorList, optional) - Expert weights for the second matrix multiplication operation for the non-fused GEMM flavor. This can be concatenated with w1 for compatibility with different frameworks. TensorList size is equal to the number of experts.
Input w12 – (TensorList, optional) - Expert weights for the first matrix multiplication operation for the fused GEMM flavor. TensorList size is equal to the number of experts.
Input w3 – (TensorList) - Expert weights for the last matrix multiplication operation. TensorList size is equal to the number of experts.
Parameters
permuted_weights – (bool) - Flag that specifies whether the expert weights are already permuted.
activation – (c10::string_view) - Activation function used in the MoE block. Supported activations are gelu, relu and silu.
experts_min – (int64_t) - Used for expert parallelism across devices. Index of the first expert in the subset this device is responsible for (inclusive).
experts_max – (int64_t) - Used for expert parallelism across devices. Index of the last expert in the subset this device is responsible for (inclusive).
Output – (Tensor) - The output of the MoE block. Shape is equal to the input shape.
# Non-fused GEMM flavor: w1, w2 and w3 are passed as separate tensor lists
import habana_frameworks.torch.core as htcore
import torch
import torch.nn.functional as F

dtype = torch.float
activation = "gelu"
hidden_dim = 64
ffn_dim = 224
num_experts = 8
num_tokens = 32
fused_weights = False
permuted_weights = False
k = 2

hidden_states = torch.randn((num_tokens, hidden_dim), dtype=dtype)
score = torch.randn((num_tokens, num_experts), dtype=torch.float32)
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
# Select the top-k experts per token and renormalize their weights
router_weights, expert_routing_table = torch.topk(routing_weights,
                                                  k,
                                                  dim=-1)
router_weights /= router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(dtype=dtype)

w1 = [torch.randn((hidden_dim, ffn_dim), dtype=dtype).to("hpu") for _ in range(num_experts)]
w2 = [torch.randn((hidden_dim, ffn_dim), dtype=dtype).to("hpu") for _ in range(num_experts)]
w3 = [torch.randn((ffn_dim, hidden_dim), dtype=dtype).to("hpu") for _ in range(num_experts)]

result = torch.ops.hpu.mixture_of_experts(
    hidden_states.to("hpu"),
    expert_routing_table.to("hpu"),
    router_weights.to("hpu"),
    w1,
    w2,
    w3,
    permuted_weights,
    activation,
    0,
    num_experts - 1,
)
# Fused GEMM flavor: the first two GEMM weights are passed together as w12
import habana_frameworks.torch.core as htcore
import torch
import torch.nn.functional as F

dtype = torch.float
activation = "gelu"
hidden_dim = 64
ffn_dim = 224
num_experts = 8
num_tokens = 32
fused_weights = True
permuted_weights = False
k = 2

hidden_states = torch.randn((num_tokens, hidden_dim), dtype=dtype)
score = torch.randn((num_tokens, num_experts), dtype=torch.float32)
routing_weights = F.softmax(score, dim=1, dtype=torch.float32)
# Select the top-k experts per token and renormalize their weights
router_weights, expert_routing_table = torch.topk(routing_weights,
                                                  k,
                                                  dim=-1)
router_weights /= router_weights.sum(dim=-1, keepdim=True)
router_weights = router_weights.to(dtype=dtype)

w12 = [torch.randn((hidden_dim, 2 * ffn_dim), dtype=dtype).to("hpu") for _ in range(num_experts)]
w3 = [torch.randn((ffn_dim, hidden_dim), dtype=dtype).to("hpu") for _ in range(num_experts)]

result = torch.ops.hpu.mixture_of_experts(
    hidden_states.to("hpu"),
    expert_routing_table.to("hpu"),
    router_weights.to("hpu"),
    w12,
    w3,
    permuted_weights,
    activation,
    0,
    num_experts - 1,
)
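For intuition about what the op computes, the following pure-Python reference walks through the non-fused flavor for a handful of tokens. It is a sketch under stated assumptions: weights use the (hidden_dim, ffn_dim) and (ffn_dim, hidden_dim) layouts from the examples above, w1/w2/w3 play the gate/up/down roles described earlier, and the names moe_reference, matvec and silu are hypothetical. It is not the Gaudi kernel and is far too slow for real use.

```python
import math

def matvec(x, w):
    """Row vector x (length n) times matrix w (n x m) -> list of length m."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def silu(v):
    return v / (1.0 + math.exp(-v))

def moe_reference(hidden_states, expert_routing_table, router_weights,
                  w1, w2, w3, experts_min, experts_max, act=silu):
    """Weighted sum of gated-MLP expert outputs for each token."""
    out = []
    for x, experts, weights in zip(hidden_states, expert_routing_table,
                                   router_weights):
        y = [0.0] * len(x)
        for e, rw in zip(experts, weights):
            # Experts outside [experts_min, experts_max] belong to other devices.
            if not experts_min <= e <= experts_max:
                continue
            gate = [act(v) for v in matvec(x, w1[e])]            # first GEMM
            up = matvec(x, w2[e])                                # second GEMM
            down = matvec([g * u for g, u in zip(gate, up)], w3[e])  # last GEMM
            y = [a + rw * d for a, d in zip(y, down)]
        out.append(y)
    return out

# Usage: one token routed to two identical experts with equal weights.
w1 = [[[1.0], [0.0]], [[1.0], [0.0]]]
w2 = [[[0.0], [1.0]], [[0.0], [1.0]]]
w3 = [[[1.0, 1.0]], [[1.0, 1.0]]]
result = moe_reference([[1.0, 2.0]], [[0, 1]], [[0.5, 0.5]], w1, w2, w3, 0, 1)
```

Restricting experts_min and experts_max to a sub-range yields only that device's partial sum; under expert parallelism the partial results from all devices would then be added together.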