Simple keras model, Model.fit() does not learn unless experimental_run_tf_function=False at compile
Describe the current behavior
A simple model (see below) compiles without error. It is used in a reinforcement-learning scenario, i.e. fit() is called repeatedly to train the model iteratively.
Currently the model does not seem to improve with calls to fit() unless compile() was called with
compile(..., experimental_run_tf_function=False).
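To make the reinforcement setting concrete: in a Q-learning player, each fit() call typically trains the q_values output toward targets built from rewards observed during play. The helper below is a hypothetical sketch of one common way to build such a target (a one-step Bellman update). The function name, its parameters and the discount factor gamma are assumptions for illustration and are not taken from the linked repository.

```python
from typing import List


def q_learning_target(q_values: List[float], action: int, reward: float,
                      next_q_values: List[float], done: bool,
                      gamma: float = 0.95) -> List[float]:
    """Build a training target for one state (hypothetical helper).

    Only the entry for the action actually taken is changed; all other
    entries keep the network's current prediction, so their MSE error is 0.
    """
    target = list(q_values)  # copy so the prediction itself is not mutated
    if done:
        # Terminal move: no future reward to bootstrap from.
        target[action] = reward
    else:
        # Bellman update: immediate reward plus discounted best next value.
        target[action] = reward + gamma * max(next_q_values)
    return target


# Example: 3 possible moves, move 1 was taken and the game continued.
print(q_learning_target([0.1, 0.2, 0.3], action=1, reward=0.0,
                        next_q_values=[0.5, 1.0, 0.0], done=False, gamma=0.9))
```

In a training loop like the one described, a batch of such targets would be fed to the q_values output on each fit() call; if the model never moves toward them, the win rate stays flat.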
Describe the expected behavior
The model should train equally well whether or not experimental_run_tf_function=False was passed to model.compile().
Code to reproduce the issue
Example code can be found at https://github.com/fcarsten/tic-tac-toe/blob/tf-2.1-issue/test_nn_q_tf2.py. You need to check out the whole branch, as the script uses other code from this repository.
Either run it as is, i.e. with run_test(True), to see it fail, or run it with run_test(False) to see how it should behave.
When it runs as expected, "Player 1 win %" should increase and end up above 80%, usually around 90%. When it does not, "Player 1 win %" meanders randomly up and down.
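The pass/fail criterion above can be expressed as a small check. This is a hypothetical sketch: it assumes game results are recorded as 1 (player 1 won) or 0 (otherwise) and summarised per batch of games. win_percentages and looks_trained are illustrative names, not functions from the test script.

```python
from typing import List, Sequence


def win_percentages(results: Sequence[int], batch_size: int) -> List[float]:
    """Per-batch win percentage for player 1 (hypothetical helper).

    `results` holds one entry per game: 1 if player 1 won, 0 otherwise.
    A trailing batch with fewer than `batch_size` games is ignored.
    """
    percentages = []
    for start in range(0, len(results) - batch_size + 1, batch_size):
        batch = results[start:start + batch_size]
        percentages.append(100.0 * sum(batch) / batch_size)
    return percentages


def looks_trained(percentages: Sequence[float], threshold: float = 80.0) -> bool:
    """Rule of thumb from the issue: the win rate rises and ends above 80%."""
    if not percentages:
        return False
    return percentages[-1] > threshold and percentages[-1] > percentages[0]


history = win_percentages([0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1], batch_size=4)
print(history, looks_trained(history))
```

A run that only meanders (for example 50%, 60%, 45%) fails the check because its final batch never clears the threshold.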
The model is defined in file SimpleNNQPlayerTF2.py, line 29 ff.:
```python
input_layer = tf.keras.Input(shape=(BOARD_SIZE * 3,))
x = tf.keras.layers.Dense(BOARD_SIZE * 3 * 9, activation='relu')(input_layer)
x = tf.keras.layers.Dense(BOARD_SIZE * 3 * 100, activation='relu')(x)
x = tf.keras.layers.Dense(BOARD_SIZE * 3 * 9, activation='relu')(x)
q_values = tf.keras.layers.Dense(BOARD_SIZE, activation=None, name='q_values')(x)
probabilities = tf.keras.layers.Softmax(name='probabilities')(q_values)

self.model = tf.keras.Model(inputs=input_layer, outputs=[probabilities, q_values])

if run_tf_function:
    self.model.compile(optimizer='adam', loss=[None, tf.keras.losses.MeanSquaredError()])
else:
    self.model.compile(optimizer='adam', loss=[None, tf.keras.losses.MeanSquaredError()],
                       experimental_run_tf_function=False)
```
Other info / logs
Although the model is a very simple sequential model, note that it has two outputs and the output being trained (q_values) is not the final layer of the model. I am not sure whether this makes any difference.
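The structural point above can be made concrete without TensorFlow. With loss=[None, tf.keras.losses.MeanSquaredError()], Keras computes no loss for the probabilities output, so the gradient reaches the Dense layers only through q_values; the Softmax after it is evaluated but not trained against. The plain-Python helpers below are a hypothetical illustration of that configuration, not Keras code.

```python
import math
from typing import List, Tuple


def forward_heads(q_values: List[float]) -> Tuple[List[float], List[float]]:
    """Mimic the model's two outputs: (softmax probabilities, raw q_values)."""
    m = max(q_values)  # subtract the max for numerical stability
    exps = [math.exp(q - m) for q in q_values]
    total = sum(exps)
    return [e / total for e in exps], list(q_values)


def total_loss(q_values: List[float], q_target: List[float]) -> float:
    """Loss as configured by loss=[None, MeanSquaredError()].

    The first output (probabilities) has no loss, so it contributes nothing;
    only the mean squared error on q_values is minimised.
    """
    _probabilities, q = forward_heads(q_values)
    return sum((a - b) ** 2 for a, b in zip(q, q_target)) / len(q)


print(total_loss([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```

Because the probabilities are discarded in the loss, changing them cannot affect training; only how close q_values gets to its target matters.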
Answer (pavithrasv):
Thank you @fcarsten for taking a look. We are doing a lot of code refactoring internally, and I think this use case in eager mode will be fixed as part of that. I will update this thread once I know which tf-nightly release contains the fix.