aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/keras/engine/sequential_test.py12
-rw-r--r--tensorflow/python/keras/engine/training.py2
2 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py
index 8744503632..51db7ded1d 100644
--- a/tensorflow/python/keras/engine/sequential_test.py
+++ b/tensorflow/python/keras/engine/sequential_test.py
@@ -341,6 +341,18 @@ class TestSequentialEagerIntegration(test.TestCase):
y = np.random.random((2, 5))
model.fit(x, y, epochs=1)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_build_before_fit(self):
+ # Fix for b/112433577
+ model = _get_small_mlp(4, 5)
+ model.compile(loss='mse', optimizer=rmsprop.RMSPropOptimizer(1e-3))
+
+ model.build((None, 6))
+
+ x = np.random.random((2, 6))
+ y = np.random.random((2, 5))
+ model.fit(x, y, epochs=1)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 1f300e3a5f..ac6b8e295b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -858,7 +858,7 @@ class Model(Network):
all_inputs = []
is_build_called = False
is_compile_called = False
- if not self.built:
+ if not self.inputs:
# We need to use `x` to set the model inputs.
# We type-check that `x` and `y` are either single arrays
# or lists of arrays.