Horovod-based Scaling of Gaudi on TensorFlow

mpirun Configuration

The mpirun map-by PE attribute value may vary depending on your setup and should be calculated as: socket:PE = floor((number of physical cores) / (number of Gaudi devices per node)).

The following sample code calculates the number of physical CPU cores and the HPU count to derive the appropriate PE value, shown as MPI_PE below. It can be incorporated into any model:

export PHY_CPU_COUNT=$(lscpu --all --parse=CORE,SOCKET | grep -Ev "^#" | sort -u | wc -l)
export PHY_HPU_COUNT=$(ls /dev/hl? | wc -l)
export MPI_PE=$(($PHY_CPU_COUNT/$PHY_HPU_COUNT))

The PE value in the Model-References examples may be set to a common number to ensure functionality; however, for optimal system performance on your host CPU, use the calculation described above.
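
As an illustration only, the computed MPI_PE value can then be passed to the mpirun map-by option. The command below is a minimal sketch; the --bind-to core option and the script name are placeholders and may differ from your actual launch command:

$ mpirun -np $PHY_HPU_COUNT --bind-to core --map-by socket:PE=$MPI_PE python3 example_hvd.py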

Scale-up Using Gaudi NICs Within a Server

The below is a simple example of distributed training, based on the single-Gaudi training example detailed in Porting a Simple TensorFlow Model to Gaudi. The training model and the corresponding scripts are available in the TensorFlow Hello World Example on GitHub.

The highlighted lines of code are added for distributed training.

import tensorflow as tf
from habana_frameworks.tensorflow import load_habana_module
load_habana_module()
import horovod.tensorflow.keras as hvd
# Initialization of Horovod.
hvd.init()

# Ensure only 1 process downloads the data on each node
if hvd.local_rank() == 0:
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    hvd.broadcast(0, 0)
else:
    hvd.broadcast(0, 0)
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Data partition for different workers
num_pics_per_rank = x_train.shape[0] // hvd.size()
pic_begin = num_pics_per_rank * hvd.rank()
pic_end = pic_begin + num_pics_per_rank
x_train = x_train[pic_begin:pic_end,]
y_train = y_train[pic_begin:pic_end,]

x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(10)
])
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Use hvd.size() (the number of workers) to scale the learning rate and wrap
# the optimizer with the DistributedOptimizer class provided by Horovod.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01 * hvd.size())
optimizer = hvd.DistributedOptimizer(optimizer)

callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),
]
model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
model.fit(x_train, y_train, epochs=1, batch_size=128, callbacks=callbacks)

model.evaluate(x_test, y_test)

The code above runs in multiple processes, one for each Gaudi device.

To launch distributed training on eight Gaudi devices within one host, run the following command.

Note

Open MPI is required for host communication and for launching processes. For the supported Open MPI version, refer to the Support Matrix.

$ mpirun -np 8 python3 example_hvd.py

The below is an example output:

7500/7500 [==============================] - 104s 14ms/sample - loss:
0.7289 - accuracy: 0.8361

7500/7500 [==============================] - 104s 14ms/sample - loss:
0.7916 - accuracy: 0.8051

7500/7500 [==============================] - 104s 14ms/sample - loss:
0.7939 - accuracy: 0.8053

7500/7500 [==============================] - 104s 14ms/sample - loss:
0.7928 - accuracy: 0.8093

Scale-out Across Servers

Scale-out Using AWS DL1/Host NICs

The training model and the corresponding scripts are available in the TensorFlow Hello World Example on GitHub.

A separate script to run a simple example of scale-out using host NICs is provided here. You must append the IP addresses of the two servers to the script invocation, separated by whitespace, as in the example below:

$ ./run_hvd_16gaudi_hostnic.sh 192.168.0.1 192.168.0.2

The script sets the environment variable HOROVOD_HIERARCHICAL_ALLREDUCE to 1 and invokes a command similar to the below example:

$ mpirun --allow-run-as-root \
    --mca btl_tcp_if_include 192.168.0.1/24,192.168.0.2/24 \
    --prefix /usr/local/openmpi/ \
    --host 192.168.0.1,192.168.0.1,192.168.0.1,192.168.0.1,192.168.0.1,192.168.0.1,192.168.0.1,192.168.0.1,192.168.0.2,192.168.0.2,192.168.0.2,192.168.0.2,192.168.0.2,192.168.0.2,192.168.0.2,192.168.0.2 \
    -x GC_KERNEL_PATH \
    -x HABANA_LOGS \
    -x TF_MODULES_RELEASE_BUILD \
    -x OPAL_PREFIX \
    -x PYTHONPATH \
    -x LD_LIBRARY_PATH \
    -x PATH \
    ...
    -x HOROVOD_HIERARCHICAL_ALLREDUCE \
    python3 example_hvd.py

If the workload is not running inside the container, the SSH server may listen on a different port. You can specify the port of the remote SSH server using the SSHD_PORT environment variable:

$ SSHD_PORT=22 ./run_hvd_16gaudi_hostnic.sh 192.168.0.1 192.168.0.2

Scale-out Using Gaudi NICs

The training model and the corresponding scripts are available in the TensorFlow Hello World Example on GitHub.

No changes to the model are required to run it across multiple servers; only the launch script requires changes.

A new script, run_hvd_16gaudi.sh, is provided here as an example for two servers. The scale-out ports of the Gaudi devices in one server are connected to those in the other server through a switch.

Run the script using the below command. You must append the IP addresses of two servers at the end of the script and the addresses should be separated by whitespaces, similar to the example below:

$ ./run_hvd_16gaudi.sh 192.168.0.1 192.168.0.2
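
For orientation, the script launches mpirun across both servers in a way that is conceptually similar to the host-NIC example above. The sketch below is purely illustrative; the exact command, including any Horovod- or HCCL-related environment variables it forwards, is defined by run_hvd_16gaudi.sh itself:

$ mpirun --allow-run-as-root \
    --prefix /usr/local/openmpi/ \
    --host 192.168.0.1:8,192.168.0.2:8 \
    -x GC_KERNEL_PATH \
    -x HABANA_LOGS \
    -x PYTHONPATH \
    -x LD_LIBRARY_PATH \
    -x PATH \
    ...
    python3 example_hvd.py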

By default, the shell scripts connect to port 3022; however, the port the SSH server listens on may differ between environments. If your environment requires a different remote SSH server port, you can specify it using the SSHD_PORT environment variable.

The example below uses port 22:

$ SSHD_PORT=22 ./run_hvd_16gaudi.sh 192.168.0.1 192.168.0.2

To change the port the SSH server listens on, use the below command. The example sets the port to 22:

$ /etc/init.d/ssh restart '-p 22'

Integrating Horovod with ResNet-50 Model Example

ResNet-50 model references can be found on the TensorFlow Model References GitHub page. The below steps provide an example of integrating Horovod into a Keras ResNet model.

  1. General sharding of the ImageNet dataset can be found in imagenet_preprocessing.py:

try:
  import horovod.tensorflow as hvd
except ImportError:
  hvd = None

if hvd is not None and hvd.is_initialized() and (is_training or use_distributed_eval):
  logging.info(
    'HVD sharding the dataset: input_pipeline_id=%d num_input_pipelines=%d',
    hvd.rank(), hvd.size())
  dataset = dataset.shard(hvd.size(), hvd.rank())

Note

The example code assumes that this import may fail. This avoids enforcing an artificial dependency on Horovod in single-Gaudi or TensorFlow Distributed runs. Sharding in this example is conditional and requires two things:

  1. The Horovod import succeeded - hvd is not None.

  2. Horovod was already initialized within this process - hvd.is_initialized().

  2. Define the use_horovod flag located in common.py. The default value is False:

flags.DEFINE_boolean("use_horovod", default=False, help="Use horovod")
  3. Import Horovod functions to be called in common.py:

try:
  import horovod.tensorflow as hvd
except ImportError:
  hvd = None
  4. Calculate the global batch size based on the per-card batch size and the total number of cards in common.py:

if hvd is not None and hvd.is_initialized():
  adjusted_batch_size = batch_size * hvd.size()
  5. Import Horovod functions to be called in resnet_ctl_imagenet_main.py:

try:
  import horovod.tensorflow as hvd
except ImportError as e:
  _hvd_exc = e
  hvd = None

Note

The ImportError is stored so that it can be raised later, and only if the use_horovod flag is set.

  6. Calculate the global batch size based on the per-card batch size and the total number of cards, and name the model directory according to the rank ID in resnet_ctl_imagenet_main.py:

if hvd is not None and hvd.is_initialized():
  batch_size = adjust_batch_size(flags_obj.batch_size)
  model_dir = os.path.join(flags_obj.model_dir, "worker_" + str(hvd.rank()))
  7. Initialize Horovod in resnet_ctl_imagenet_main.py:

if flags.FLAGS.use_horovod:
  if hvd is None:
    logging.error("Problem encountered during Horovod import. Please make sure that habana-horovod package is installed.")
    raise _hvd_exc

  hvd.init()
  8. Import Horovod functions to be called in resnet_runnable.py:

try:
  import horovod.tensorflow as hvd
except ImportError:
  hvd = None
# In resnet_runnable.py, average the evaluation accuracy across all workers
# when distributed evaluation is enabled.
if self.flags_obj.use_distributed_eval and hvd is not None and hvd.is_initialized():
    test_accuracy = hvd.allreduce(self.test_accuracy.result())
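
With these changes in place, a multi-card ResNet-50 run can be launched with mpirun in the same way as the Hello World example. The command below is only a sketch; apart from --use_horovod, the flags (data directory, model directory, per-card batch size) are illustrative and should match the flags actually defined by the model:

$ mpirun -np 8 python3 resnet_ctl_imagenet_main.py \
    --use_horovod \
    --data_dir /path/to/imagenet \
    --model_dir /tmp/resnet \
    --batch_size 256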