Weight Sharing

Weight sharing is a technique in which the weights of a module are shared among two or more layers. With PyTorch on Gaudi, weights can be shared only if they are created inside the module. You can find an example of weight sharing in the BERT Pre-Training example on GitHub. The snippet below demonstrates sharing two parameters of a module before moving it to the HPU device.

import torch
import habana_frameworks.torch.core as ht  # loads the Habana PyTorch bridge that provides the HPU device

# Example module
class WeightShareModule(torch.nn.Module):
    def __init__(self):
        super(WeightShareModule, self).__init__()
        self.a = torch.nn.Parameter(torch.ones([2]))
        self.b = torch.nn.Parameter(torch.ones([2]))

    def forward(self, input):
        c = self.a * input + self.b * input
        return c

module = WeightShareModule()
# Share the weights: module.a and module.b now refer to the same parameter
module.a = module.b
# Move the module to the HPU device
module.to("hpu")
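
After these steps, both attributes reference a single parameter, so an update made through one is visible through the other. The following is a minimal sketch of how you might verify the sharing and run a forward pass; it assumes the habana_frameworks package is installed and an HPU device is available:

# Verification sketch (assumes an HPU device is available)
assert module.a is module.b                 # same Parameter object
assert len(list(module.parameters())) == 1  # duplicate parameters are filtered out

# Forward pass on HPU: with a == b == 1, the output is 2 * input
x = torch.ones([2], device="hpu")
out = module(x)
print(out.cpu())  # tensor([2., 2.])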