AWS Machine Learning Blog

Multi-GPU and distributed training using Horovod in Amazon SageMaker Pipe mode

There are many techniques to train deep learning models with a small amount of data. Examples include transfer learning, few-shot learning, or even one-shot learning for an image classification task and fine-tuning for language models based on a pre-trained BERT or GPT2 model. However, you may still have a use case in which you need a large amount of training data. For instance, if the images are quite different from ImageNet or your language corpus is domain specific rather than general, then it’s hard to achieve the desired model performance with transfer learning. If you are a deep learning researcher, you may want to try new ideas or approaches from scratch. In these cases, your task is to train a large deep learning model with a large dataset, which can take days, weeks, or even months if you don’t use the proper methods for training large-scale models.

In this post, I explain how to run multi-GPU training on a single instance on Amazon SageMaker, and discuss efficient multi-GPU and multi-node distributed training on Amazon SageMaker.

Basics on Horovod

When you train a model with a large amount of data, you should distribute the training across multiple GPUs on either a single instance or multiple instances. Deep learning frameworks provide their own methods to support multi-GPU training or distributed training. However, there is another way to accomplish this: using a distributed deep learning framework such as Horovod. Horovod is Uber’s open-source framework for distributed deep learning, and it’s available for use with most popular deep learning toolkits like TensorFlow, Keras, PyTorch, and Apache MXNet. It uses the all-reduce algorithm for fast distributed training rather than a parameter server approach, and includes multiple optimization methods to make distributed training faster. For more information, see Meet Horovod: Uber’s Open Source Distributed Deep Learning Framework for TensorFlow.
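To make the all-reduce idea concrete, the following is a minimal, single-process sketch of what an averaging all-reduce produces: every worker contributes its gradient vector and every worker ends up with the same element-wise mean, with no central parameter server holding the result. The function name allreduce_average and the sample gradients are made up for illustration, and the sketch says nothing about how Horovod actually moves data between GPUs.

```python
from typing import List


def allreduce_average(worker_grads: List[List[float]]) -> List[List[float]]:
    """Return what each worker holds after an averaging all-reduce.

    worker_grads[i] is the gradient vector computed by worker i.
    This is a local simulation, not Horovod's communication code.
    """
    num_workers = len(worker_grads)
    # Element-wise sum across workers, divided by the number of workers.
    averaged = [sum(values) / num_workers for values in zip(*worker_grads)]
    # Every worker receives its own copy of the same averaged vector.
    return [list(averaged) for _ in range(num_workers)]


# Four workers (for example, four GPUs), each with a two-element gradient.
grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
result = allreduce_average(grads)
```

After the call, all four entries of result are identical, which is why each worker can apply the same update to its local copy of the model.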

Preparing your data for Horovod

When you start a training job using Horovod, Horovod launches one independent worker process per GPU in the Horovod cluster. For example, four worker processes start when you run a Horovod training job on one training instance with four GPUs (one Amazon SageMaker ml.p3.8xlarge or Amazon Elastic Compute Cloud (Amazon EC2) p3.8xlarge instance). Each of the four Horovod workers reads its own part of the dataset, which has already been split into shards for data parallelism. If there are 40,000 training samples, each worker gets 10,000 training samples without duplication. If you use Horovod for distributed training or even multi-GPU training, you should prepare these data shards beforehand and let each worker read its shard from the file system. (There are deep learning frameworks that do this automatically on the fly, such as PyTorch’s DataParallel and DistributedDataParallel.)
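The arithmetic behind the 40,000-sample example can be sketched as follows. The helper shard_bounds is a hypothetical name, and the contiguous split it uses is only one reasonable way to divide samples without duplication; the post does not prescribe a particular scheme.

```python
def shard_bounds(num_samples: int, num_workers: int, rank: int) -> range:
    """Return the contiguous range of sample indexes owned by one worker."""
    base, extra = divmod(num_samples, num_workers)
    # The first `extra` workers take one additional sample each, so every
    # sample is assigned exactly once even when the split is uneven.
    start = rank * base + min(rank, extra)
    stop = start + base + (1 if rank < extra else 0)
    return range(start, stop)


# One instance with four GPUs means four Horovod workers.
shards = [shard_bounds(40_000, 4, rank) for rank in range(4)]
```

Each of the four ranges holds 10,000 indexes and no index appears in two ranges.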

The following diagram illustrates two architectures for storing shards.

You can provide a dataset for an Amazon SageMaker training job in several different ways. One typical method is to store the entire dataset in your Amazon Simple Storage Service (Amazon S3) bucket and access it when needed. Although you may use a shared file system like Amazon FSx for Lustre or Amazon Elastic File System (Amazon EFS) for data storage, you can also avoid the additional cost by retrieving data directly from Amazon S3 via two input modes available to Amazon SageMaker: File mode and Pipe mode.

In File mode, when the training job is launched in Amazon SageMaker, the defined dataset is transferred from the specified S3 bucket to the training instances and placed in a local directory on each instance. However, if the dataset is huge, it takes a long time to copy objects from the bucket to the training instances’ storage, and the start of training is delayed until the data transfer is complete. In some cases, this might slow down the machine learning (ML) pipeline, and even slow down innovation or research speed.

You can also access the dataset stored in Amazon S3 directly through Pipe mode. Pipe mode creates a direct input pipe between the training instance and the S3 bucket, and allows the training process to access the objects directly without copying them all onto the training instances before training begins. To access a dataset at a given Amazon S3 URI in Pipe mode, you set the input mode to Pipe when you create an Amazon SageMaker estimator. See the following code:

from sagemaker.tensorflow import TensorFlow

tf_estimator = TensorFlow(entry_point='train.py',
                          role='SageMakerRole',
                          train_instance_type='ml.p3.2xlarge',
                          train_instance_count=2,
                          framework_version='2.1.0',
                          py_version='py3',
                          input_mode='Pipe')

With Pipe mode, the training data is available as a FIFO stream. There is an extension of a TensorFlow dataset that makes it easy to access a streamed dataset. For more information about Pipe mode and TensorFlow, see Accelerate model training using faster Pipe mode on Amazon SageMaker and the Amazon SageMaker TensorFlow extension GitHub repo.

Pipe mode with Horovod

Special care is needed when you use Horovod with Pipe mode, either for multi-GPU training on a single training instance or for distributed training on multiple training instances with multiple GPUs. The following diagram illustrates this architecture.

Pipe mode streams data from Amazon S3 into Unix Named Pipes or FIFOs in the training instances. A FIFO file supports only a single writer/reader pair, and there is one FIFO created for one channel per epoch. Normally, people define one channel for the training dataset and another for the validation or test dataset and pass these input channels to the training job as parameters of Amazon SageMaker estimator’s fit() function. See the following code:

from sagemaker.session import s3_input

input_channel = {'train': s3_input('s3://your-bucket-name/train-dataset/')}

tf_estimator.fit(inputs=input_channel)     

What does this mean in Horovod multi-GPU training? Processes launched by a multi-GPU training job using Horovod compete with each other for a single FIFO, which can’t be accessed simultaneously by multiple processes. Because only one worker process can access the FIFO at a time and it doesn’t release the handle until the training job is finished, none of the other workers can read data from the same FIFO, and the training stalls indefinitely in a deadlock-like state. If you see repeated messages similar to the following, this is the problem you are encountering:

[1,0]<stderr>:Stalled ranks:
[1,0]<stderr>:0: [training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_11_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_12_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_14_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_15_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_18_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_19_0 ...]
[1,0]<stderr>:2: [training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_11_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_12_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_14_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_15_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_18_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_19_0 ...]
[1,0]<stderr>:3: [training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_11_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_12_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_14_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_15_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_18_0, training/Adam/DistributedAdam_Allreduce/HorovodAllreduce_training_Adam_gradients_AddN_19_0 ...]

You should split the dataset in the S3 bucket into as many shards as the number of GPUs to be used for training. If you have 4,000 TensorFlow record files, and you train a model using one ml.p3.8xlarge with four GPUs, you can place 1,000 TensorFlow record files under each of four different prefixes, as in the following code:

s3://your-bucket-name/train/0/
s3://your-bucket-name/train/1/
s3://your-bucket-name/train/2/
s3://your-bucket-name/train/3/
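One way to arrive at this layout is to assign files to prefixes round-robin before uploading them. The following sketch only computes the mapping and does not touch Amazon S3; the function name assign_to_prefixes and the part-NNNN.tfrecord file names are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List


def assign_to_prefixes(file_names: List[str], num_shards: int,
                       base_prefix: str) -> Dict[str, List[str]]:
    """Map each file name to one of `num_shards` numbered S3 prefixes."""
    layout = defaultdict(list)
    for position, name in enumerate(sorted(file_names)):
        shard = position % num_shards  # round-robin keeps shard sizes even
        layout[f'{base_prefix}/{shard}/'].append(name)
    return dict(layout)


# 4,000 record files split across the four GPUs of one ml.p3.8xlarge.
files = [f'part-{i:04d}.tfrecord' for i in range(4000)]
layout = assign_to_prefixes(files, num_shards=4,
                            base_prefix='s3://your-bucket-name/train')
```

The resulting dictionary has one key per prefix shown above, each holding 1,000 file names that you could then upload with your tool of choice.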

For distributed training with multiple training instances, you can use the same number of channels as in the single-instance multi-GPU case with the help of ShardedByS3Key. You put multiple dataset files under each S3 prefix, and ShardedByS3Key distributes those files across the training instances. For example, assume there are two dataset files, s3://your-bucket-name/train/0/train-0.tfrecord and s3://your-bucket-name/train/0/train-1.tfrecord, under the s3://your-bucket-name/train/0/ prefix. The first training instance retrieves its data from s3://your-bucket-name/train/0/train-0.tfrecord and the second training instance retrieves its data from s3://your-bucket-name/train/0/train-1.tfrecord.

Therefore, you need as many data shards as the total number of GPUs in the Horovod cluster, while the number of input channels stays equal to the number of GPUs per instance.
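The relationship between instances, GPUs, channels, and shards can be written down as a small planning helper. This is only a sketch of the counting argument above; plan_shards and the keys of the dictionary it returns are illustrative names, not part of any SageMaker API.

```python
def plan_shards(instance_count: int, gpus_per_host: int) -> dict:
    """Summarize how many channels and shards a Horovod job needs."""
    return {
        # One input channel per GPU on a host, reused on every instance.
        'input_channels': [f'train_{i}' for i in range(gpus_per_host)],
        # Each instance needs its own file(s) under every prefix, so a
        # prefix must hold at least this many files.
        'min_files_per_prefix': instance_count,
        # One shard per GPU in the whole cluster.
        'total_shards': instance_count * gpus_per_host,
    }


plan = plan_shards(instance_count=2, gpus_per_host=4)
```

For two instances with four GPUs each, the plan lists four channels and eight shards, so every prefix needs at least two files.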

You then define four input channels for Amazon SageMaker training. See the following code:

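import sagemaker
from sagemaker.session import s3_input

gpus_per_host = 4                          # one ml.p3.8xlarge has four GPUs
dataset_location = 's3://your-bucket-name'

# 234 is the shuffle seed
shuffle_config = sagemaker.session.ShuffleConfig(234)
train_s3_uri_prefix = dataset_location

input_channels = {}

# One channel per GPU on a host: train_0, train_1, ...
for idx in range(gpus_per_host):
    train_s3_uri = f'{train_s3_uri_prefix}/train/{idx}/'
    train_s3_input = s3_input(train_s3_uri, shuffle_config=shuffle_config, distribution='ShardedByS3Key')
    input_channels[f'train_{idx}'] = train_s3_input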

ShuffleConfig makes sure that the order of the files under the Amazon S3 prefix is randomized for every epoch. For more information, see ShuffleConfig.

Use the following channel definition when you call the fit method on the Amazon SageMaker estimator:

tf_estimator.fit(input_channels)

For validation and test tasks, you only run these tasks on a single worker (normally the primary worker, rank 0), so you don’t need multiple validation or test channels. However, if you use the tf.keras.Model.fit() function for training, the training stalls if only one Horovod worker runs validation (for more information, see issue #600 on the Horovod GitHub repo). If you need validation with tf.keras.Model.fit(), you have to give every worker its own validation input channel, just as with the training input channels. See the following code:

validation_s3_uri = 's3://your-bucket-name/validation/'

for idx in range(4):
    validation_s3_input = s3_input(validation_s3_uri)
    input_channels[f'validation_{idx}'] = validation_s3_input
    
eval_s3_uri = 's3://your-bucket-name/eval/'
eval_s3_input = s3_input(eval_s3_uri)
input_channels['eval'] = eval_s3_input

Instead of using the prefix of the S3 bucket, you can use a plain ManifestFile that contains a list of object keys. For more information, see Input Data.
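As a rough illustration, a manifest is a JSON array whose first element names a common prefix and whose remaining elements are object keys relative to that prefix. The sketch below builds such a document with the standard library only; build_manifest is a hypothetical helper, and the file names reuse the earlier example.

```python
import json
from typing import List


def build_manifest(prefix: str, relative_keys: List[str]) -> str:
    """Return manifest JSON: a prefix entry followed by relative keys."""
    entries = [{'prefix': prefix}] + list(relative_keys)
    return json.dumps(entries, indent=2)


manifest = build_manifest('s3://your-bucket-name/train/0/',
                          ['train-0.tfrecord', 'train-1.tfrecord'])
print(manifest)
```

You would upload the resulting file to Amazon S3 and point the input channel at it instead of at a prefix. In the SDK version used in this post, that is done through the s3_data_type argument of s3_input, but confirm the exact call against the Input Data documentation.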

Using the data channel in training code

In the training script, you need to force each Horovod worker process to access its own shard so that no two workers access the same input channel. In our use case, the names of the input channels are defined using indexes starting from 0, so you can use the hvd.rank() function, which gives the cluster-wide unique rank index of the current worker process and also begins from 0 (see line 18 in the following code). For this post, we use the Amazon SageMaker TensorFlow extension PipeModeDataset. For other deep learning frameworks, read data from a FIFO named /opt/ml/input/data/[channel_name]_${epoch} for each epoch. For more examples, see the GitHub repo.

 1: import horovod.tensorflow.keras as hvd  # assumes the tf.keras binding of Horovod
 2: import tensorflow as tf
 3: from sagemaker_tensorflow import PipeModeDataset
 4:
 5: hvd.init()  # must be called before hvd.rank()
 6:
 7: features = {'data': tf.io.FixedLenFeature([], tf.string),
 8:             'labels': tf.io.FixedLenFeature([], tf.int64)}
 9:
10: def parse(record):
11:     # Decode one serialized TFRecord example into (features, label)
12:     parsed = tf.io.parse_single_example(record, features)
13:     return ({
14:         'data': tf.io.decode_raw(parsed['data'], tf.float64)
15:     }, parsed['labels'])
16:
17: # For Horovod and Pipe mode, use the input channel allocated to this worker using rank information
18: channel_name = 'train_{}'.format(hvd.rank())
19:
20: ds = PipeModeDataset(channel=channel_name, record_format='TFRecord')
21: ds = ds.map(parse)
22: ds = ds.batch(64)
23: ds = ds.prefetch(10)

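In a Horovod cluster with one or more instances, ranks are uniquely assigned from 0 to the total number of GPUs minus 1. On a single instance, you don’t need to worry about the rank number as long as you defined the input channel names using indexes from 0. If you train on multiple instances with only one channel per GPU on each host, as in the ShardedByS3Key setup described earlier, hvd.rank() runs past the largest channel index; in that case hvd.local_rank(), which restarts at 0 on every instance, is the index that matches the channel names. The order of the instances doesn’t matter in either case.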
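For frameworks without a Pipe mode dataset class, the FIFO naming convention mentioned earlier can be wrapped in two small helpers. This is a minimal sketch, assuming the /opt/ml/input/data/[channel_name]_${epoch} pattern given above; fifo_path and read_chunks are hypothetical names, and decoding the raw bytes into records is left to your framework.

```python
from typing import Iterator


def fifo_path(channel_name: str, epoch: int,
              base_dir: str = '/opt/ml/input/data') -> str:
    """Build the per-epoch FIFO path for one input channel."""
    return f'{base_dir}/{channel_name}_{epoch}'


def read_chunks(path: str, chunk_size: int = 1024 * 1024) -> Iterator[bytes]:
    """Yield raw byte chunks from a FIFO (or any file) until it is drained."""
    with open(path, 'rb') as stream:
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:  # empty read: the writer closed the stream
                break
            yield chunk
```

A worker would call something like read_chunks(fifo_path('train_0', epoch)) once per epoch, with the channel name chosen from its rank so that no two processes open the same FIFO.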

Monitoring with TensorBoard

For flexible monitoring of the training process, you can run TensorBoard from any remote compute instance by first uploading the logs to the S3 bucket at the end of each epoch. To do so, create a callback that pushes the local logs to an S3 bucket path, and register it only on the primary (rank 0) Horovod worker. See the following code:

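import subprocess
from datetime import datetime

import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard

class Sync2S3(tf.keras.callbacks.Callback):
    """Copy the local TensorBoard log directory to Amazon S3 after every epoch."""
    def __init__(self, logdir, s3logdir):
        super(Sync2S3, self).__init__()
        self.logdir = logdir
        self.s3logdir = s3logdir

    def on_epoch_end(self, epoch, logs=None):
        # Calls the AWS CLI, which must be available in the training container
        subprocess.run(['aws', 's3', 'sync', self.logdir, self.s3logdir])

...

# Only the primary worker writes and uploads TensorBoard logs
if hvd.rank() == 0:
    logdir = args.output_data_dir + '/' + datetime.now().strftime("%Y%m%d-%H%M%S")
    callbacks.append(TensorBoard(log_dir=logdir))
    callbacks.append(Sync2S3(logdir=logdir, s3logdir=tensorboard_logs))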

With the training logs stored in the S3 bucket, you can run TensorBoard from any server you like, including an EC2 instance, an Amazon SageMaker notebook instance, or even your local machine, as long as the server hosting TensorBoard has permission to read the log objects in Amazon S3. To launch TensorBoard, run the following shell commands in your terminal. To read log data directly from Amazon S3, TensorBoard must be version 1.14.0 or later. The following commands use logs located in an S3 bucket in us-east-1:

export S3_REGION=us-east-1
tensorboard --logdir s3://{bucket_name}/tensorboard_logs/

If you run the preceding commands in an Amazon SageMaker notebook instance, you can access the running TensorBoard UI at https://<SageMaker-notebook-instance-name>.notebook.<notebook-region>.sagemaker.aws/proxy/6006/.

Cleaning up

After you have explored the distributed training covered in this post, clean up resources that you’re no longer using to avoid additional costs, such as the S3 buckets, FSx for Lustre, and any Amazon SageMaker instances.

Conclusion

Horovod multi-GPU or distributed training on Amazon SageMaker with Pipe mode can perform large-scale training by creating a separate training channel for each shard and having each worker read only its own shard in the data pipeline. This benefits training on Amazon SageMaker with a large training dataset by reducing the time spent transferring the dataset to the training instances before actual training begins.

For the complete training example to run on Amazon SageMaker, where Pipe mode and Horovod are applied together, see the GitHub repo.


About the Authors

Muhyun Kim is a data scientist at Amazon Machine Learning Solutions Lab. He solves customers’ various business problems by applying machine learning and deep learning, and also helps them build their skills.

Jiyang Kang is a deep learning architect at Amazon Machine Learning Solutions Lab. With experience designing global enterprise workloads on AWS, he is responsible for designing and implementing ML solutions for customers’ new business problems.

Hussain Karimi is a data scientist at the Machine Learning Solutions Lab, where he works with customers across various verticals to initiate and build automated, algorithmic models that generate business value.