Weight Sharing

Weight sharing is a technique in which the module weights are shared among two or more layers. When using PyTorch with Gaudi, you may need to share the weights after the model has been moved to the HPU device, as shown below. For a complete example of weight sharing, see the BERT Pre-Training example on GitHub.

import torch
# Importing the Habana bridge makes the "hpu" device available to PyTorch
import habana_frameworks.torch.core as ht

# Example module with two parameters of the same shape
class WeightShareModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.ones([2]))
        self.b = torch.nn.Parameter(torch.ones([2]))

    def forward(self, x):
        c = self.a * x + self.b * x
        return c

module = WeightShareModule()
# Move the module to the HPU device
module = module.to("hpu")
# Share the weights after the module has been moved to the HPU device;
# module.a and module.b now refer to the same Parameter
module.a = module.b
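The sketch below checks what the tie actually does. It is a minimal sketch under stated assumptions: it falls back to CPU when `habana_frameworks` cannot be imported, so it runs without Gaudi hardware, and the `DEVICE` constant and `build_shared_module` helper are hypothetical names introduced only for illustration. After the tie, both attributes refer to one `Parameter`, `parameters()` reports it only once, and its gradient accumulates the contributions from both uses in `forward`.

```python
import torch

# Assumption: use the HPU if the Habana PyTorch bridge can be imported,
# otherwise fall back to CPU so this sketch runs anywhere. (If the package
# is installed but no Gaudi device is present, .to("hpu") may still fail.)
try:
    import habana_frameworks.torch.core  # noqa: F401  (registers "hpu")
    DEVICE = "hpu"
except ImportError:
    DEVICE = "cpu"


class WeightShareModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.ones([2]))
        self.b = torch.nn.Parameter(torch.ones([2]))

    def forward(self, x):
        return self.a * x + self.b * x


def build_shared_module(device=DEVICE):
    """Hypothetical helper: move the module first, then tie the weights."""
    module = WeightShareModule().to(device)
    # Assigning a Parameter re-registers attribute "a" to point at "b"
    module.a = module.b
    return module


module = build_shared_module()

# Both attributes now refer to the same Parameter object ...
assert module.a is module.b
# ... so parameters() deduplicates it and reports a single tensor.
print(len(list(module.parameters())))  # 1

x = torch.tensor([1.0, 2.0], device=DEVICE)
module(x).sum().backward()
# The shared parameter p is used twice in forward, so the gradient of
# sum(p*x + p*x) with respect to p is 2*x.
print(module.b.grad.cpu())  # tensor([2., 4.])
```

Because an optimizer built from `module.parameters()` sees only one tensor, a single update changes `module.a` and `module.b` together.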