Intel Gaudi GPU Migration Toolkit APIs
This document provides a full list of Python API calls and their compatibility with the Intel® Gaudi® AI accelerator. For further details on using the GPU Migration toolkit, refer to GPU Migration Toolkit.
GPU Migration APIs Support
The table below describes which APIs are supported by the GPU Migration toolkit. Each API is classified as hpu_match, hpu_modified, or hpu_mismatch in the HPU Implementation column. These terms are described in detail below.
hpu_match
Intel Gaudi implementations replicate the functionality of API calls that the GPU Migration toolkit maps as hpu_match. Because the calls match functionally, no further configuration is required.
The GPU Migration toolkit maps GPU arguments to HPU arguments. For example:

```python
x = torch.ones(1, device="cuda")  # GPU Migration changes the argument `device` from "cuda" to "hpu".
```
The GPU Migration toolkit maps HPU output values to GPU output values. For example:

```python
backend = torch.distributed.get_backend()  # GPU Migration changes the output backend from "hccl" to "nccl" to align with GPU-based code.
```
The GPU Migration toolkit maps GPU calls to HPU calls. For example:

```python
torch.cuda.memory_usage()  # GPU Migration maps `torch.cuda.memory_usage` to `torch.hpu.memory_usage`.
```
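Putting these mappings together, below is a minimal sketch of how a migrated script behaves. It assumes a Gaudi machine with the Intel Gaudi PyTorch bridge installed; importing habana_frameworks.torch.gpu_migration activates the toolkit and must precede the CUDA-style calls.

```python
import habana_frameworks.torch.gpu_migration  # noqa: F401 -- activates the migration hooks
import torch

x = torch.ones(1, device="cuda")   # transparently allocated on HPU
print(x.device)                    # reports the migrated device
print(torch.cuda.is_available())   # forwarded to torch.hpu.is_available()
```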
hpu_modified
Certain Gaudi-specific calls are used to replicate GPU-specific functionality and guarantee compatibility with HPU. For example, the GPU Migration toolkit maps torch.cuda.FloatTensor to torch.FloatTensor followed by torch.Tensor.to("hpu"). As a result, torch.cuda.FloatTensor is no longer a tensor type and can only be used to create a tensor on HPU.
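As a hedged sketch of what this change means in practice (the checks below are illustrative, not part of the toolkit):

```python
import habana_frameworks.torch.gpu_migration  # noqa: F401 -- activates the migration hooks
import torch

t = torch.cuda.FloatTensor([1.0, 2.0])  # under migration: creates a float32 tensor on HPU
print(t.device, t.dtype)

# Because torch.cuda.FloatTensor is no longer a tensor type after migration,
# type-based checks should rely on dtype/device attributes instead:
assert t.dtype == torch.float32
# isinstance(t, torch.cuda.FloatTensor)  # patterns like this may no longer work
```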
Please review the limitations stated in each call's description in the table below. The number of calls in this category will decrease in future releases.
hpu_mismatch
These calls are currently incompatible with the Gaudi implementation, so the GPU Migration toolkit either creates an inactive call or raises an exception.

Inactive Call - Since the functions cannot be executed on HPU, the GPU Migration toolkit converts them into inactive calls (no-ops). For example:

```python
torch.cuda.empty_cache()
```

NotImplementedError Exception - Certain API calls cannot be processed by either the GPU Migration toolkit or the Intel Gaudi software, causing the training process to be interrupted with a NotImplementedError exception. For example:

```python
torch.cuda.comm.broadcast
```
If you encounter one of the hpu_mismatch scenarios described above, generate a log file as described in the Enabling GPU Migration Logging section, manually modify the mismatching functionality, and restart the training.
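For illustration, here is a hedged sketch of such a manual modification, replacing the unsupported torch.cuda.comm.broadcast with the collective recommended in the table below. It assumes a torchrun-style multi-process launch:

```python
import habana_frameworks.torch.gpu_migration  # noqa: F401 -- activates the migration hooks
import torch
import torch.distributed as dist

# Assumes a torchrun-style launch; under GPU Migration, the "nccl" backend
# argument is transparently changed to "hccl".
dist.init_process_group(backend="nccl")

tensor = torch.ones(4, device="cuda")  # migrated to HPU

# Before: torch.cuda.comm.broadcast(tensor, devices=[0, 1])  # raises NotImplementedError
dist.broadcast(tensor, src=0)  # portable replacement: every rank receives rank 0's tensor

dist.destroy_process_group()
```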
The table below lists all the converted calls and explains their compatibility with HPU.
| Call | HPU Implementation | Description |
|---|---|---|
| apex.amp.frontend.initialize | hpu_modified | Uses native torch.autocast with device_type="hpu" instead. |
| apex.normalization.fused_layer_norm.FusedLayerNorm | hpu_modified | Uses torch.nn.functional.layer_norm. |
| apex.normalization.fused_layer_norm.FusedRMSNorm | hpu_modified | Uses a manual Python RMSNorm implementation; ignores the elementwise_affine option. |
| apex.normalization.fused_layer_norm.MixedFusedLayerNorm | hpu_modified | Uses torch.nn.functional.layer_norm. |
| apex.normalization.fused_layer_norm.MixedFusedRMSNorm | hpu_modified | Maps to FusedRMSNorm. |
| apex.optimizers.FusedLAMB | hpu_match | Maps apex.optimizers.FusedLAMB to hpex.optimizers.FusedLamb. |
| apex.optimizers.FusedAdagrad | hpu_modified | Does not support adagrad_w_mode. |
| apex.optimizers.FusedAdam | hpu_modified | Uses hpex.optimizers.FusedAdamW if adam_w_mode is True; otherwise uses torch.optim.Adam. |
| apex.optimizers.FusedSGD | hpu_modified | Ignores wd_after_momentum and materialize_master_grads. |
| apex.optimizers.FusedAdagrad.zero_grad | hpu_match | Maps apex.optimizers.FusedAdagrad.zero_grad to hpex.optimizers.FusedAdagrad.zero_grad. |
| apex.optimizers.FusedAdam.zero_grad | hpu_match | Maps apex.optimizers.FusedAdam.zero_grad to hpex.optimizers.FusedAdamW.zero_grad if adam_w_mode is True; otherwise maps to torch.optim.Adam.zero_grad. |
| apex.optimizers.FusedLAMB.zero_grad | hpu_match | Maps apex.optimizers.FusedLAMB.zero_grad to hpex.optimizers.FusedLamb.zero_grad. |
| apex.optimizers.FusedSGD.zero_grad | hpu_match | Maps apex.optimizers.FusedSGD.zero_grad to hpex.optimizers.FusedSGD.zero_grad. |
| apex.parallel.DistributedDataParallel | hpu_mismatch | Raises NotImplementedError. Use native torch.nn.parallel.DistributedDataParallel instead. |
| pynvml.nvml.nvmlDeviceClearCpuAffinity | hpu_match | Maps pynvml.nvmlDeviceClearCpuAffinity to pyhlml.hlmlDeviceClearCpuAffinity. |
| pynvml.nvml.nvmlDeviceGetClockInfo | hpu_match | Maps pynvml.nvmlDeviceGetClockInfo to pyhlml.hlmlDeviceGetClockInfo. |
| pynvml.nvml.nvmlDeviceGetCount | hpu_match | Maps pynvml.nvmlDeviceGetCount to pyhlml.hlmlDeviceGetCount. |
| pynvml.nvml.nvmlDeviceGetCpuAffinity | hpu_match | Maps pynvml.nvmlDeviceGetCpuAffinity to pyhlml.hlmlDeviceGetCpuAffinity. |
| pynvml.nvml.nvmlDeviceGetCurrentClocksThrottleReasons | hpu_match | Maps pynvml.nvmlDeviceGetCurrentClocksThrottleReasons to pyhlml.hlmlDeviceGetCurrentClocksThrottleReasons. |
| pynvml.nvml.nvmlDeviceGetHandleByIndex | hpu_match | Maps pynvml.nvmlDeviceGetHandleByIndex to pyhlml.hlmlDeviceGetHandleByIndex. |
| pynvml.nvml.nvmlDeviceGetHandleByUUID | hpu_match | Maps pynvml.nvmlDeviceGetHandleByUUID to pyhlml.hlmlDeviceGetHandleByUUID. |
| pynvml.nvml.nvmlDeviceGetMaxClockInfo | hpu_match | Maps pynvml.nvmlDeviceGetMaxClockInfo to pyhlml.hlmlDeviceGetMaxClockInfo. |
| pynvml.nvml.nvmlDeviceGetMemoryErrorCounter | hpu_match | Maps pynvml.nvmlDeviceGetMemoryErrorCounter to pyhlml.hlmlDeviceGetMemoryErrorCounter. |
| pynvml.nvml.nvmlDeviceGetMemoryInfo | hpu_match | Maps pynvml.nvmlDeviceGetMemoryInfo to pyhlml.hlmlDeviceGetMemoryInfo. |
| pynvml.nvml.nvmlDeviceGetMinorNumber | hpu_match | Maps pynvml.nvmlDeviceGetMinorNumber to pyhlml.hlmlDeviceGetMinorNumber. |
| pynvml.nvml.nvmlDeviceGetName | hpu_match | Maps pynvml.nvmlDeviceGetName to pyhlml.hlmlDeviceGetName. |
| pynvml.nvml.nvmlDeviceGetPerformanceState | hpu_match | Maps pynvml.nvmlDeviceGetPerformanceState to pyhlml.hlmlDeviceGetPerformanceState. |
| pynvml.nvml.nvmlDeviceGetPersistenceMode | hpu_match | Maps pynvml.nvmlDeviceGetPersistenceMode to pyhlml.hlmlDeviceGetPersistenceMode. |
| pynvml.nvml.nvmlDeviceGetPowerManagementDefaultLimit | hpu_match | Maps pynvml.nvmlDeviceGetPowerManagementDefaultLimit to pyhlml.hlmlDeviceGetPowerManagementDefaultLimit. |
| pynvml.nvml.nvmlDeviceGetPowerUsage | hpu_match | Maps pynvml.nvmlDeviceGetPowerUsage to pyhlml.hlmlDeviceGetPowerUsage. |
| pynvml.nvml.nvmlDeviceGetSerial | hpu_match | Maps pynvml.nvmlDeviceGetSerial to pyhlml.hlmlDeviceGetSerial. |
| pynvml.nvml.nvmlDeviceGetTemperature | hpu_match | Maps pynvml.nvmlDeviceGetTemperature to pyhlml.hlmlDeviceGetTemperature. |
| pynvml.nvml.nvmlDeviceGetTemperatureThreshold | hpu_match | Maps pynvml.nvmlDeviceGetTemperatureThreshold to pyhlml.hlmlDeviceGetTemperatureThreshold. |
| pynvml.nvml.nvmlDeviceGetTotalEnergyConsumption | hpu_match | Maps pynvml.nvmlDeviceGetTotalEnergyConsumption to pyhlml.hlmlDeviceGetTotalEnergyConsumption. |
| pynvml.nvml.nvmlDeviceGetUUID | hpu_match | Maps pynvml.nvmlDeviceGetUUID to pyhlml.hlmlDeviceGetUUID. |
| pynvml.nvml.nvmlDeviceGetUtilizationRates | hpu_match | Maps pynvml.nvmlDeviceGetUtilizationRates to pyhlml.hlmlDeviceGetUtilizationRates. |
| pynvml.nvml.nvmlDeviceGetViolationStatus | hpu_match | Maps pynvml.nvmlDeviceGetViolationStatus to pyhlml.hlmlDeviceGetViolationStatus. |
| pynvml.nvml.nvmlDeviceRegisterEvents | hpu_match | Maps pynvml.nvmlDeviceRegisterEvents to pyhlml.hlmlDeviceRegisterEvents. |
| pynvml.nvml.nvmlDeviceSetCpuAffinity | hpu_match | Maps pynvml.nvmlDeviceSetCpuAffinity to pyhlml.hlmlDeviceSetCpuAffinity. |
| pynvml.nvml.nvmlEventSetCreate | hpu_match | Maps pynvml.nvmlEventSetCreate to pyhlml.hlmlEventSetCreate. |
| pynvml.nvml.nvmlEventSetFree | hpu_match | Maps pynvml.nvmlEventSetFree to pyhlml.hlmlEventSetFree. |
| pynvml.nvml.nvmlEventSetWait | hpu_match | Maps pynvml.nvmlEventSetWait to pyhlml.hlmlEventSetWait. |
| pynvml.nvml.nvmlInit | hpu_match | Maps pynvml.nvmlInit to pyhlml.hlmlInit. |
| pynvml.nvml.nvmlShutdown | hpu_match | Maps pynvml.nvmlShutdown to pyhlml.hlmlShutdown. |
| torch.autocast | hpu_match | Changes device_type to "hpu" and dtype to None, leaving the data type casting decision to the HPU autocast engine. |
| torch.compile | hpu_match | Maps the inductor backend to hpu_backend. For further details, refer to PyTorch. |
| torch.empty_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.rand_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.randint | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.randint_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.randn_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.zeros | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.zeros_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.arange | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.range | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.full | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.full_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.eye | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.ones | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.ones_like | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.tensor | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.as_tensor | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. |
| torch.empty | hpu_modified | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. Ignores the pin_memory option. |
| torch.randn | hpu_modified | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. Ignores the pin_memory option. |
| torch.rand | hpu_modified | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. Ignores the pin_memory option. |
| torch.empty_strided | hpu_modified | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. Ignores the pin_memory option. |
| torch.randperm | hpu_modified | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16 if PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1. Ignores the pin_memory option. |
| torch.Tensor.cuda | hpu_match | Changes the function to torch.Tensor.to("hpu"). |
| torch.Tensor.half | hpu_match | Changes the function to torch.Tensor.to(torch.bfloat16). |
| torch.Tensor.pin_memory | hpu_match | Pins memory to the HPU device instead of the GPU. |
| torch.Tensor.to | hpu_match | Changes device arguments from "cuda" to "hpu" and dtype from torch.float16 to torch.bfloat16. |
| torch.Tensor.type | hpu_match | Maps to torch.Tensor.type("torch.<Scalar>Tensor") and torch.Tensor.to("hpu") if dtype is torch.hpu.<Scalar>Tensor. If dtype is None, changes the tensor type to the same type as the CUDA tensor. |
| torch.Tensor.record_stream | hpu_mismatch | Inactive call. |
| torch.Tensor.numpy | hpu_modified | If the tensor dtype is torch.bfloat16, changes the function to torch.Tensor.to(torch.float16) followed by torch.Tensor.numpy. |
| torch._C._cuda_setDevice | hpu_match | Maps torch._C._cuda_setDevice to torch.hpu.set_device. |
| torch._C._cuda_getDevice | hpu_match | Maps torch._C._cuda_getDevice to torch.hpu.current_device. |
| torch._C._cuda_setStream | hpu_mismatch | Inactive call. |
| torch.amp.autocast | hpu_match | Changes device_type to "hpu" and dtype to None, leaving the data type casting decision to the HPU autocast engine. |
| torch.cuda.StreamContext | hpu_match | Maps torch.cuda.StreamContext to torch.hpu.StreamContext. |
| torch.cuda.can_device_access_peer | hpu_match | Maps torch.cuda.can_device_access_peer to torch.hpu.can_device_access_peer. |
| torch.cuda.current_device | hpu_match | Maps torch.cuda.current_device to torch.hpu.current_device. |
| torch.cuda.current_stream | hpu_match | Maps torch.cuda.current_stream to torch.hpu.current_stream. |
| torch.cuda.default_stream | hpu_match | Maps torch.cuda.default_stream to torch.hpu.default_stream. |
| torch.cuda.device_count | hpu_match | Maps torch.cuda.device_count to torch.hpu.device_count. |
| torch.cuda.get_device_name | hpu_match | Maps torch.cuda.get_device_name to torch.hpu.get_device_name. |
| torch.cuda.get_gencode_flags | hpu_match | Maps torch.cuda.get_gencode_flags to torch.hpu.get_gencode_flags. |
| torch.cuda.init | hpu_match | Maps torch.cuda.init to torch.hpu.init. |
| torch.cuda.is_available | hpu_match | Maps torch.cuda.is_available to torch.hpu.is_available. |
| torch.cuda.is_bf16_supported | hpu_match | Maps torch.cuda.is_bf16_supported to torch.hpu.is_bf16_supported. |
| torch.cuda.is_initialized | hpu_match | Maps torch.cuda.is_initialized to torch.hpu.is_initialized. |
| torch.cuda.memory_usage | hpu_match | Maps torch.cuda.memory_usage to torch.hpu.memory_usage. |
| torch.cuda.set_device | hpu_match | Maps torch.cuda.set_device to torch.hpu.set_device. |
| torch.cuda.set_stream | hpu_match | Maps torch.cuda.set_stream to torch.hpu.set_stream. |
| torch.cuda.set_sync_debug_mode | hpu_match | Maps torch.cuda.set_sync_debug_mode to torch.hpu.set_sync_debug_mode. |
| torch.cuda.synchronize | hpu_match | Maps torch.cuda.synchronize to torch.hpu.synchronize. |
| torch.cuda.utilization | hpu_match | Maps torch.cuda.utilization to torch.hpu.utilization. |
| torch.cuda.BFloat16Tensor | hpu_modified | Maps torch.cuda.BFloat16Tensor to torch.BFloat16Tensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.BoolTensor | hpu_modified | Maps torch.cuda.BoolTensor to torch.BoolTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.ByteTensor | hpu_modified | Maps torch.cuda.ByteTensor to torch.ByteTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.CharTensor | hpu_modified | Maps torch.cuda.CharTensor to torch.CharTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.CudaError | hpu_modified | Uses RuntimeError. |
| torch.cuda.DoubleTensor | hpu_modified | Maps torch.cuda.DoubleTensor to torch.DoubleTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.FloatTensor | hpu_modified | Maps torch.cuda.FloatTensor to torch.FloatTensor + torch.Tensor.to("hpu"). torch.cuda.FloatTensor is no longer a tensor type and can only be used to create a tensor on HPU. |
| torch.cuda.HalfTensor | hpu_modified | Maps torch.cuda.HalfTensor to torch.HalfTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.IntTensor | hpu_modified | Maps torch.cuda.IntTensor to torch.IntTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.LongTensor | hpu_modified | Maps torch.cuda.LongTensor to torch.LongTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.ShortTensor | hpu_modified | Maps torch.cuda.ShortTensor to torch.ShortTensor + torch.Tensor.to("hpu"). See torch.cuda.FloatTensor for details. |
| torch.cuda.check_error | hpu_modified | Uses RuntimeError. |
| torch.cuda.cudart | hpu_modified | Returns None. |
| torch.cuda.current_blas_handle | hpu_modified | Returns 0. |
| torch.cuda.get_arch_list | hpu_modified | Returns a constructed list of all nvcc archs. |
| torch.cuda.get_device_capability | hpu_modified | Returns the latest CUDA capability. |
| torch.cuda.get_device_properties | hpu_modified | Returns a constructed _CudaDeviceProperties. |
| torch.cuda.get_sync_debug_mode | hpu_modified | Uses the PT_ENABLE_HABANA_STREAMASYNC environment variable to determine the sync debug mode. |
| torch.cuda.ipc_collect | hpu_modified | Returns None. |
| torch.cuda.stream | hpu_modified | Returns a torch.cuda.StreamContext. |
| torch.cuda.amp.GradScaler | hpu_modified | Sets the enabled argument to False. GradScaler prevents gradient underflow for ops with FP16 inputs; HPU uses BF16 in training, so scaling is not required. |
| torch.cuda.amp.GradScaler.state_dict | hpu_match | Returns the state of the scaler with values corresponding to disabled mode. |
| torch.cuda.comm.broadcast | hpu_mismatch | Raises NotImplementedError. Use torch.distributed.broadcast instead. |
| torch.cuda.comm.broadcast_coalesced | hpu_mismatch | Raises NotImplementedError. Use torch.distributed.broadcast instead. |
| torch.cuda.comm.gather | hpu_mismatch | Raises NotImplementedError. Use torch.distributed.gather instead. |
| torch.cuda.comm.reduce_add | hpu_mismatch | Raises NotImplementedError. Use torch.distributed.reduce instead. |
| torch.cuda.comm.reduce_add_coalesced | hpu_mismatch | Raises NotImplementedError. Use torch.distributed.reduce instead. |
| torch.cuda.comm.scatter | hpu_mismatch | Raises NotImplementedError. Use torch.distributed.scatter instead. |
| torch.cuda.graphs.CUDAGraph | hpu_match | Maps torch.cuda.CUDAGraph to habana_frameworks.torch.hpu.graphs.HPUGraph. |
| torch.cuda.graphs.graph | hpu_match | Maps torch.cuda.graph to habana_frameworks.torch.hpu.graphs.graph. |
| torch.cuda.graphs.is_current_stream_capturing | hpu_match | Maps torch.cuda.graphs.is_current_stream_capturing to torch.hpu.graphs.is_current_stream_capturing. |
| torch.cuda.graphs.make_graphed_callables | hpu_match | Maps torch.cuda.make_graphed_callables to habana_frameworks.torch.hpu.graphs.make_graphed_callables. |
| torch.cuda.graphs.graph_pool_handle | hpu_mismatch | Returns None. |
| torch.cuda.memory.max_memory_allocated | hpu_match | Maps torch.cuda.max_memory_allocated to torch.hpu.max_memory_allocated. |
| torch.cuda.memory.max_memory_cached | hpu_match | Maps torch.cuda.max_memory_cached to torch.hpu.max_memory_reserved. |
| torch.cuda.memory.max_memory_reserved | hpu_match | Maps torch.cuda.max_memory_reserved to torch.hpu.max_memory_reserved. |
| torch.cuda.memory.mem_get_info | hpu_match | Maps torch.cuda.mem_get_info to torch.hpu.mem_get_info. |
| torch.cuda.memory.memory_allocated | hpu_match | Maps torch.cuda.memory_allocated to torch.hpu.memory_allocated. |
| torch.cuda.memory.memory_cached | hpu_match | Maps torch.cuda.memory_cached to torch.hpu.memory_reserved. |
| torch.cuda.memory.memory_reserved | hpu_match | Maps torch.cuda.memory_reserved to torch.hpu.memory_reserved. |
| torch.cuda.memory.memory_stats | hpu_match | Maps torch.cuda.memory_stats to torch.hpu.memory_stats. |
| torch.cuda.memory.memory_summary | hpu_match | Maps torch.cuda.memory_summary to torch.hpu.memory_summary. |
| torch.cuda.memory.reset_max_memory_allocated | hpu_match | Maps torch.cuda.reset_max_memory_allocated to torch.hpu.reset_max_memory_allocated. |
| torch.cuda.memory.reset_peak_memory_stats | hpu_match | Maps torch.cuda.reset_peak_memory_stats to torch.hpu.reset_peak_memory_stats. Resets the peak stats tracked by the memory allocator. |
| torch.cuda.memory.caching_allocator_alloc | hpu_mismatch | Inactive call. |
| torch.cuda.memory.caching_allocator_delete | hpu_mismatch | Inactive call. |
| torch.cuda.memory.empty_cache | hpu_mismatch | Inactive call. |
| torch.cuda.memory.memory_snapshot | hpu_mismatch | Inactive call. |
| torch.cuda.memory.reset_max_memory_cached | hpu_mismatch | Inactive call. |
| torch.cuda.memory.set_per_process_memory_fraction | hpu_mismatch | Inactive call. |
| torch.cuda.memory.list_gpu_processes | hpu_modified | Prints the running processes and their HPU memory usage by calling torch.hpu.memory_reserved for the given device. |
| torch.cuda.nvtx.range | hpu_mismatch | Inactive call. |
| torch.cuda.random.get_rng_state | hpu_match | Maps torch.cuda.random.get_rng_state to torch.hpu.random.get_rng_state. |
| torch.cuda.random.get_rng_state_all | hpu_match | Maps torch.cuda.random.get_rng_state_all to torch.hpu.random.get_rng_state_all. |
| torch.cuda.random.initial_seed | hpu_match | Maps torch.cuda.random.initial_seed to torch.hpu.random.initial_seed. |
| torch.cuda.random.manual_seed | hpu_match | Maps torch.cuda.random.manual_seed to torch.hpu.random.manual_seed. |
| torch.cuda.random.manual_seed_all | hpu_match | Maps torch.cuda.random.manual_seed_all to torch.hpu.random.manual_seed_all. |
| torch.cuda.random.seed | hpu_match | Maps torch.cuda.random.seed to torch.hpu.random.seed. |
| torch.cuda.random.seed_all | hpu_match | Maps torch.cuda.random.seed_all to torch.hpu.random.seed_all. |
| torch.cuda.random.set_rng_state | hpu_match | Maps torch.cuda.random.set_rng_state to torch.hpu.random.set_rng_state. |
| torch.cuda.random.set_rng_state_all | hpu_match | Maps torch.cuda.random.set_rng_state_all to torch.hpu.random.set_rng_state_all. |
| torch.cuda.streams.Stream | hpu_match | Maps torch.cuda.Stream to torch.hpu.Stream. |
| torch.cuda.streams.ExternalStream | hpu_modified | Maps torch.cuda.streams.ExternalStream to torch.hpu.Stream. |
| torch.distributed.barrier | hpu_match | Ignores device_ids, which is valid only for the NCCL backend. |
| torch.distributed.get_backend | hpu_match | Returns nccl as the backend. |
| torch.distributed.init_process_group | hpu_match | Changes the backend from nccl to hccl. |
| torch.distributed.is_nccl_available | hpu_match | Attempts to import habana_frameworks.torch.distributed.hccl; returns True on success and False otherwise. |
| torch.distributed.c10d_logger._get_msg_dict | hpu_match | Removes nccl_version from msg_dict. |
| torch.distributed.fsdp.FullyShardedDataParallel | hpu_match | Changes device_id arguments from "cuda" to "hpu". |
| torch.nn.DataParallel | hpu_mismatch | Raises NotImplementedError. Use torch.nn.parallel.DistributedDataParallel instead. |
| torch.nn.Embedding | hpu_modified | When using sparse gradients, the whole layer and its inputs are moved to the CPU. |
| torch.nn.EmbeddingBag | hpu_modified | When using sparse gradients, the whole layer and its inputs are moved to the CPU. |
| torch.nn.functional.scaled_dot_product_attention | hpu_modified | Uses Intel Gaudi FusedSDPA when any of math_sdp_enabled (with dropout == 0.0), mem_efficient_sdp_enabled, or flash_sdp_enabled is True; otherwise uses either the torch-level reference implementation (for dropout > 0.0) or native torch.nn.functional.scaled_dot_product_attention. For more information, see https://docs.habana.ai/en/latest/PyTorch/Python_Packages.html#hpex-kernels-fusedsdpa. |
| torch.nn.parallel.DistributedDataParallel | hpu_match | Sets device_ids and output_device to None. |
| torch.ops.aten._scaled_dot_product_flash_attention | hpu_modified | Uses Intel Gaudi FusedSDPA with enable_recompute = False. |
| torch.ops.aten._scaled_dot_product_efficient_attention | hpu_modified | Uses Intel Gaudi FusedSDPA with enable_recompute = True. Ignores the return_debug_mask and dropout_mask arguments. |
| torch.ops.aten._scaled_dot_product_attention_math | hpu_modified | Uses the torch-level reference implementation. |
| torch.optim.Adadelta | hpu_modified | Sets the foreach parameter to False. |
| torch.optim.Adam | hpu_modified | Sets the foreach parameter to False. |
| torch.optim.SGD | hpu_modified | Sets the foreach parameter to False. |
| torch.optim.RMSprop | hpu_modified | Sets the foreach parameter to False. |
| torch.serialization.load | hpu_match | Changes map_location to enable HPU mapping instead of CUDA mapping. |
| torch.utils.DeviceContext | hpu_match | Changes a device_type of "cuda" to "hpu" in the device context ('with' statement). |
| torch.utils.data.DataLoader | hpu_match | If pin_memory is True and pin_memory_device is None or "cuda", sets pin_memory_device to "hpu". |
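Many tensor-creation entries above share the PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION behavior. Below is a minimal sketch of what that flag implies, assuming it is set in the environment before the script is launched:

```python
# Run with: PT_HPU_CONVERT_FP16_TO_BF16_FOR_MIGRATION=1 python script.py
import habana_frameworks.torch.gpu_migration  # noqa: F401 -- activates the migration hooks
import torch

x = torch.zeros(8, device="cuda", dtype=torch.float16)
print(x.device)  # hpu
print(x.dtype)   # expected torch.bfloat16 with the flag set; torch.float16 otherwise
```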