Using tf.data.Dataset has big overhead
Describe the current behavior
Dataset reduces performance by a small but significant amount, ~7% for ImageNet-like data
Describe the expected behavior
Dataset has no or only marginal performance impact
Standalone code to reproduce the issue
```python
import tensorflow as tf
from timeit import timeit

@tf.function
def train_step(x, y):
    model.train_on_batch(x, y)

for useData in (True, False):
    model = tf.keras.applications.ResNet50(weights=None, classes=1000)
    model.compile(
        loss=tf.losses.SparseCategoricalCrossentropy(),
        optimizer=tf.keras.optimizers.SGD(),
        metrics=['accuracy'],
        experimental_run_tf_function=True)
    if useData:
        x = tf.random.uniform([1, 32, 224, 224, 3])
        y = tf.random.uniform([1, 32, 1], minval=0, maxval=999, dtype=tf.int64)
        dataset = tf.data.Dataset.from_tensor_slices((x, y)).repeat()

        def train(steps):
            for x, y in dataset.take(steps):
                train_step(x, y)
    else:
        x = tf.random.uniform([32, 224, 224, 3])
        y = tf.random.uniform([32, 1], minval=0, maxval=999, dtype=tf.int64)

        def train(steps):
            for _ in range(steps):
                train_step(x, y)

    # warmup
    train(2)
    t = timeit(lambda: train(50), number=10)
    print('useData: %s -> %s' % (useData, t))
```
Sample output:

```
useData: True -> 89.92945478390902
useData: False -> 86.73652107780799
```
For more realistic training loops (e.g. including callbacks) the difference is even bigger. Some of my tests:
```
constant: total images/sec: 496.47 (calculation(497.53) + preprocessing(1.06))
dataset:  total images/sec: 465.09 (calculation(478.64) + preprocessing(13.55))
```
The first number is calculated from the training-loop execution time (after warmup); the second from the train step alone. The difference between them, which I called "preprocessing", is the time spent iterating over the dataset (the for loop calling next on the iterator). It would hence be dominated by preprocessing functions if any were present (none here), and includes the take Dataset adapter.
So two conclusions: getting elements from the iterator seems to be quite costly (preprocessing goes from ~1 to ~13.6), and even the train step itself gets slower (~498 -> ~479 images/sec).
This would be a reason to avoid the dataset API.
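The per-element iteration cost can be isolated without any model at all. The following is a minimal sketch (not from the original report; it uses small dummy tensors for speed) that times pure iteration over a Dataset versus reusing constant tensors in a plain Python loop:

```python
import tensorflow as tf
from timeit import timeit

# Small dummy batch; only iteration overhead is measured, no training.
x = tf.random.uniform([32, 8, 8, 3])
y = tf.random.uniform([32, 1], minval=0, maxval=999, dtype=tf.int64)

# Single-element dataset repeated forever, mirroring the repro above.
dataset = tf.data.Dataset.from_tensor_slices((x[None], y[None])).repeat()

def iterate_dataset(steps):
    for xb, yb in dataset.take(steps):
        pass  # no work: measures only the cost of pulling elements

def iterate_constants(steps):
    for _ in range(steps):
        xb, yb = x, y  # same tensors every step, no iterator involved

t_data = timeit(lambda: iterate_dataset(100), number=5)
t_const = timeit(lambda: iterate_constants(100), number=5)
print('dataset: %.4fs, constants: %.4fs' % (t_data, t_const))
```

On any machine the dataset loop should take measurably longer per step, which matches the "preprocessing" gap reported above.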
Answer (aaudiber):
This is a difficult case for tf.data.Dataset because there isn't any preprocessing. tf.data.Dataset usually does preprocessing on the CPU, then transfers the data to the GPU afterward. The tf.data.Dataset example is slower because it copies the tensors from GPU memory to CPU memory and back each time, while the non-Dataset example starts with the tensors on the GPU and doesn't need to move them at all, since there isn't any preprocessing.
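The placement difference can be observed directly via the `.device` attribute of eager tensors. A small sketch (not part of the original answer):

```python
import tensorflow as tf

x = tf.random.uniform([4, 4])
# On a machine with a GPU, eager ops place float tensors on the GPU by
# default, so this would report GPU:0; on a CPU-only machine, CPU:0.
print('source tensor on:', x.device)

# Elements produced by iterating a Dataset live on the host (CPU) by
# default, which is where the GPU<->CPU round trip comes from.
for elem in tf.data.Dataset.from_tensor_slices(x).take(1):
    print('dataset element on:', elem.device)
```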
Ideally we could use tf.data.experimental.prefetch_to_device to prefetch to the GPU and recover the performance, but there is currently an outstanding bug with prefetch_to_device. Once that gets fixed, the performance should be almost identical when using tf.data.Dataset.