Media Pipe for TensorFlow ResNet
This section describes how to define a ResnetPipe class and implement a Media Pipe API dataloader, HabanaDataset, for TensorFlow ResNet.
Defining a ResnetPipe Class
The ResnetPipe class is derived from the MediaPipe class, as described in Creating and Executing Media Pipeline, and implements the Media Pipe API for ResNet. An implementation of the ResnetPipe class can be found in resnet_media_pipe.py on Habana's Model References GitHub page.
The following set of operations is performed on the data:

- Read image using ReadImageDatasetFromDir
- Decode image using ImageDecoder
  - For training data only:
    - Random crop as a part of ImageDecoder, with RANDOMIZED_AREA_AND_ASPECT_RATIO_CROP as random_crop_type
    - Resize as a part of ImageDecoder
  - For evaluation data only:
    - Resize as a part of ImageDecoder
    - Crop using Crop
- Random flip implemented using RandomFlip
- Subtract mean RGB values using Sub
- Transpose using Transpose
API functions are first defined in the ResnetPipe constructor.
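The snippet below is a minimal sketch of how such operator nodes can be declared in a MediaPipe constructor, following the generic pattern from Creating and Executing Media Pipeline. The class name and parameter values are illustrative, not the exact ones used in resnet_media_pipe.py:

from habana_frameworks.mediapipe import fn
from habana_frameworks.mediapipe.mediapipe import MediaPipe
from habana_frameworks.mediapipe.media_types import imgtype

class SketchPipe(MediaPipe):
    def __init__(self, device, queue_depth, batch_size, height, width,
                 is_training, data_dir):
        super().__init__(device, queue_depth, batch_size,
                         self.__class__.__name__)
        self.is_training = is_training
        # Reader node: loads JPEG files and labels from a directory tree
        self.input = fn.ReadImageDatasetFromDir(dir=data_dir, format="jpg",
                                                shuffle=is_training)
        # Decoder node: JPEG decode plus resize; for training, a random
        # crop is additionally configured via the random_crop_type argument
        self.decode = fn.ImageDecoder(device="hpu",
                                      output_format=imgtype.RGB_P,
                                      resize=[width, height])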
A sequence of operations is then set up in the definegraph method:
def definegraph(self):
    jpegs, data = self.input()
    if self.is_training == True:
        # Training: decode (random crop + resize), then random flip
        images = self.decode(jpegs)
        random_flip_input = self.random_flip_input()
        images = self.random_flip(images, random_flip_input)
    else:
        # Evaluation: decode (resize only), then crop
        images = self.decode(jpegs)
        images = self.crop(images)
    # Subtract mean RGB values, casting around the subtraction as needed
    mean = self.mean_node()
    images = self.cast_pre(images)
    images = self.sub(images, mean)
    if self.out_dtype != dtype.FLOAT32:
        images = self.cast_pst(images)
    # Transpose to the layout expected by the model
    images = self.pst_transp(images)
    return images, data
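For standalone inspection, a pipe defined this way can be built and executed directly using the generic flow from Creating and Executing Media Pipeline. The following is a sketch only: the constructor arguments mirror the habana_imagenet_dataset call shown below, the values are examples, and names such as randomCropType are assumptions based on the random_crop_type value mentioned above:

from habana_frameworks.mediapipe.media_types import dtype, randomCropType

# Illustrative standalone run; argument values are examples only
pipe = ResnetPipe("hpu", 3, 256, 3, 224, 224, True,
                  "/path/to/jpeg/train", dtype.FLOAT32, 1, 0,
                  randomCropType.RANDOMIZED_AREA_AND_ASPECT_RATIO_CROP)
pipe.build()                 # construct the pipeline graph
pipe.iter_init()             # initialize the dataset iterator
images, labels = pipe.run()  # fetch one preprocessed batch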
Implementing HabanaDataset
With ResnetPipe defined, a HabanaDataset instance can be created. HabanaDataset is a TensorFlow Dataset implementation that allows data loading via any media pipe implementation.
The ResnetPipe instance is passed to the HabanaDataset constructor along with the output data shapes and types. The class implementation encapsulates iterating over images and prefetching. The HabanaDataset instance is compatible with the native TensorFlow Dataset and can be used in TensorFlow ResNet, which allows reusing the TensorFlow ResNet model code without additional changes.
The relevant code is shown in imagenet_dataset.py:
def habana_imagenet_dataset(is_training,
                            jpeg_data_dir,
                            batch_size,
                            num_channels,
                            img_size,
                            data_type,
                            use_distributed_eval,
                            use_pytorch_style_crop=False):
    # setting parameters here, code removed for clarity
    from TensorFlow.computer_vision.common.resnet_media_pipe import ResnetPipe
    pipe = ResnetPipe("hpu", queue_depth, batch_size, num_channels,
                      img_size, img_size, is_training,
                      data_dir, media_data_type, num_slices, slice_index,
                      crop_type)
    from habana_frameworks.tensorflow.media.habana_dataset import HabanaDataset
    dataset = HabanaDataset(output_shapes=[(batch_size,
                                            img_size,
                                            img_size,
                                            num_channels),
                                           (batch_size,)],
                            output_types=[data_type, tf.float32],
                            pipeline=pipe)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    return dataset
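Because HabanaDataset behaves like a native tf.data.Dataset, the returned dataset can be consumed with standard TensorFlow idioms. Below is a minimal usage sketch; the data directory and parameter values are illustrative:

import tensorflow as tf

dataset = habana_imagenet_dataset(is_training=True,
                                  jpeg_data_dir="/data/imagenet/train",
                                  batch_size=256,
                                  num_channels=3,
                                  img_size=224,
                                  data_type=tf.float32,
                                  use_distributed_eval=False)
# Each element is an (images, labels) batch matching output_shapes above
for images, labels in dataset.take(1):
    print(images.shape, labels.shape)  # (256, 224, 224, 3) (256,)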
Selecting a DataLoader for TensorFlow ResNet
In Habana's TensorFlow ResNet, two DataLoaders are implemented in the model:

- HabanaDataset
- Native TensorFlow Dataset
HabanaDataset offers better performance than the native TensorFlow Dataset and can be used only if the following conditions are met:

- The habana_media_loader package is installed.
- The dataset is given in JPEG format. The current ResnetPipe implementation does not support the TFRecord format.
- The workload is being run on Gaudi2. First-gen Gaudi supports only the native TensorFlow Dataset.
- The FORCE_HABANA_IMAGENET_LOADER_FALLBACK environment variable is NOT set to 1.
The above conditions are covered by the media_loader_can_be_used function in imagenet_dataset.py.
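The sketch below illustrates the kind of checks such a function performs. It is a hypothetical re-creation based on the conditions listed above, not the actual implementation from imagenet_dataset.py; in particular, the Gaudi2 device-type check is elided:

import os

def media_loader_can_be_used_sketch(jpeg_data_dir):
    # habana_media_loader package must be installed
    try:
        import habana_frameworks.tensorflow.media.habana_dataset  # noqa: F401
    except ImportError:
        return False
    # Dataset must be provided as JPEG files (TFRecord is not supported)
    if jpeg_data_dir is None:
        return False
    # The workload must also run on Gaudi2; the real function queries the
    # device type, which is omitted in this sketch.
    # Fallback must not be forced via the environment
    if os.environ.get("FORCE_HABANA_IMAGENET_LOADER_FALLBACK") == "1":
        return False
    return True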
Choosing and creating a proper dataset is handled in input_fn, located in imagenet_preprocessing.py. See the following example (function parameters are skipped for readability):
def input_fn(...):
    if imagenet_dataset.media_loader_can_be_used(jpeg_data_dir) is True:
        return imagenet_dataset.habana_imagenet_dataset(...)
    else:
        return imagenet_dataset_fallback(...)
If any errors are encountered when building the HabanaDataset-based dataset, a fallback to the native TensorFlow Dataset (imagenet_dataset_fallback) will be attempted.
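Conceptually, this selection-with-fallback amounts to a guarded attempt, sketched below with parameters elided as in the example above. This is a hypothetical wrapper, not the literal repository code:

def input_fn_with_fallback(...):
    if imagenet_dataset.media_loader_can_be_used(jpeg_data_dir):
        try:
            return imagenet_dataset.habana_imagenet_dataset(...)
        except Exception:
            pass  # fall through to the native TensorFlow Dataset
    return imagenet_dataset_fallback(...)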