diff options
-rw-r--r-- | tensorflow/python/keras/engine/sequential_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 2 |
2 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 8744503632..51db7ded1d 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -341,6 +341,18 @@ class TestSequentialEagerIntegration(test.TestCase): y = np.random.random((2, 5)) model.fit(x, y, epochs=1) + @tf_test_util.run_in_graph_and_eager_modes + def test_build_before_fit(self): + # Fix for b/112433577 + model = _get_small_mlp(4, 5) + model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3)) + + model.build((None, 6)) + + x = np.random.random((2, 6)) + y = np.random.random((2, 5)) + model.fit(x, y, epochs=1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 1f300e3a5f..ac6b8e295b 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -858,7 +858,7 @@ class Model(Network): all_inputs = [] is_build_called = False is_compile_called = False - if not self.built: + if not self.inputs: # We need to use `x` to set the model inputs. # We type-check that `x` and `y` are either single arrays # or lists of arrays. |