aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-09-24 20:22:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 20:29:54 -0700
commit6ba60e051409a5346c2aab21160c9c311de1cb03 (patch)
tree955be96a46d13601582343a25ae3612ad53179d7 /tensorflow/contrib/distribute
parent4dc77744ff6a6854cf4aa2934eb4501bc22c3465 (diff)
Add validation that input shapes should be fully defined when using TPU strategy with keras.
PiperOrigin-RevId: 214376435
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py23
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py2
2 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 8165a70743..2e6cd43fd4 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -635,6 +635,29 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
'expected input to have shape'):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+ @combinations.generate(combinations.combine(
+ distribution=[combinations.tpu_strategy_one_step],
+ mode=['graph']))
+ def test_dataset_input_shape_fully_defined(self, distribution):
+ with self.cached_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = rmsprop.RMSPropOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ targets = np.zeros((10, 4), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ # Input shapes are not fully known. Batch dimension is unknown as we are
+ # not using the drop_remainder argument.
+ dataset = dataset.repeat(100).batch(10)
+
+ with self.assertRaisesRegexp(ValueError, 'requires fully defined shapes'):
+ model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0)
+
def test_learning_phase_value(self):
# TODO(anjalisridhar): Modify this test to use Lambdas since we can compare
# meaningful values. Currently we don't pass the learning phase if the
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index ba2cc2e806..a6762e5e87 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -158,7 +158,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
raise ValueError(
'TPU currently requires fully defined shapes. Either use '
'set_shape() on the input tensors or use '
- 'dataset.apply(map_and_batch(..., drop_remainder=True)).')
+ 'dataset.batch(..., drop_remainder=True).')
types = nest.flatten(iterator.output_types)
enqueue_ops = [