Image augmentation using tf.keras.preprocessing.image.ImageDataGenerator and tf.datasets: fit() is running infinitely

What I need help with / What I was wondering

I am facing an issue while running the fit() function in TensorFlow (v2.2.0-rc4) with augmented images (from ImageDataGenerator) passed as a dataset. The fit() call runs indefinitely without stopping.

What I've tried so far

I tried the default code shared in the TensorFlow documentation.

Please find the code snippet below:

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Input, Dense

flowers = tf.keras.utils.get_file( 'flower_photos', '', untar=True)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255, rotation_range=20)

images, labels = next(img_gen.flow_from_directory(flowers))

print(images.dtype, images.shape)
print(labels.dtype, labels.shape)

train_data_gen = img_gen.flow_from_directory( batch_size=32, directory=flowers, shuffle=True, target_size=(256, 256), class_mode='categorical')

ds = train_data_gen, output_types=(tf.float32, tf.float32), output_shapes=([32, 256, 256, 3], [32, 5]) )

ds = ds.prefetch(

it = iter(ds)
batch = next(it)
print(batch)

def create_model():
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='relu', input_shape=images[0].shape))
    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.5))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Dropout(0.5))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(5, activation='softmax'))
    return model

model = create_model()
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=["accuracy"])
, verbose=1, batch_size= 32, epochs =1)

The last line of code, the fit() call, runs indefinitely without stopping. I also tried passing steps_per_epoch = total_no_of_train_records/batch_size.

It would be nice if...

I would like you to confirm whether this is a bug in the TensorFlow datasets package and, if so, in which release it will be fixed.

Environment information

  • System: Google Colaboratory
  • Python version: v3.6.9
  • TensorFlow version: v2.2.0-rc4

Answer from tomerk

The ImageDataGenerator returns an infinite number of values, so the epoch would never end unless you specify steps_per_epoch. The examples in ImageDataGenerator briefly mention this, but I think the documentation should be clearer about this. So I'll go ahead and update that.
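To make the point above concrete, here is a minimal pure-Python sketch of an endless batch source and a loop that stops after a fixed number of steps. The names endless_batches and run_epoch are hypothetical and are not part of TensorFlow; the wrap-around batching is a simplification and does not reproduce how the Keras iterator shuffles or sizes its batches.

```python
import itertools


def endless_batches(samples, batch_size):
    """Yield batches forever, wrapping around the data (simplified stand-in
    for an iterator that never signals the end of the data)."""
    i = 0
    while True:  # never raises StopIteration, so a plain for-loop never ends
        yield [samples[(i + k) % len(samples)] for k in range(batch_size)]
        i = (i + batch_size) % len(samples)


def run_epoch(batches, steps_per_epoch):
    """Consume exactly steps_per_epoch batches, then stop."""
    seen = 0
    for batch in itertools.islice(batches, steps_per_epoch):
        seen += len(batch)
    return seen


samples = list(range(10))
gen = endless_batches(samples, batch_size=4)
print(run_epoch(gen, steps_per_epoch=3))  # 3 batches of 4 samples each
```

Without the bound in run_epoch, the loop over gen would never terminate, which is the same situation fit() is in when it is given an endless source and no steps_per_epoch.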

I'm not sure why setting steps_per_epoch didn't work for you? When I set a value for steps_per_epoch in this colab it ends once it has completed that many steps. Perhaps it's just being slow? There are ~114 steps in one epoch over the dataset (3670 training images / 32 batch size). 114 * 7 sec / step (in this colab) is close to 15 minutes of runtime.
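The arithmetic in this answer can be sketched as follows, using only the figures quoted above (3670 images, batch size 32, roughly 7 seconds per step). The helper names are hypothetical, and the per-step time is whatever was observed in that Colab, so it will differ elsewhere.

```python
import math


def steps_for(num_images, batch_size):
    """Return (full_batches, steps_including_partial_batch)."""
    full_batches = num_images // batch_size
    all_batches = math.ceil(num_images / batch_size)
    return full_batches, all_batches


def estimated_minutes(steps, seconds_per_step):
    """Rough wall-clock estimate for one epoch."""
    return steps * seconds_per_step / 60


full, total = steps_for(num_images=3670, batch_size=32)
print(full, total)                               # full batches vs. all batches
print(round(estimated_minutes(full, 7), 1))      # minutes at ~7 s/step
```

The value would then be passed as steps_per_epoch in the fit() call. Because the dataset in the question declares a fixed batch dimension of 32, the final partial batch (3670 - 114 * 32 = 22 images) may not match that shape, so the number of full batches is the more conservative choice here.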

