aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-09-09 19:49:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-09 19:53:47 -0700
commit231f34e3d8634ae02dae00af89d0ceafb3ada588 (patch)
treea41ab792ff500361a8bfdccb9ba62adff651a48b /tensorflow/contrib/distribute
parent551123bba011a50a925f27b5f22c49c898bcd978 (diff)
Add support for evaluate and predict in Keras with TPUStrategy. Also add unit tests and update the examples.
PiperOrigin-RevId: 212207760
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/BUILD21
-rw-r--r--tensorflow/contrib/distribute/python/combinations.py4
-rw-r--r--tensorflow/contrib/distribute/python/examples/keras_mnist.py1
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py142
4 files changed, 107 insertions, 61 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index c524d8b394..87f76eaa94 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -708,19 +708,32 @@ cuda_py_test(
],
)
-cuda_py_test(
- name = "keras_test",
+py_library(
+ name = "keras_test_lib",
+ testonly = 1,
srcs = ["keras_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
+ deps = [
+ ":combinations",
"//tensorflow/contrib/distribute/python:mirrored_strategy",
+ "//tensorflow/contrib/distribute/python:tpu_strategy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python/estimator:estimator_py",
"//tensorflow/python/keras",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+cuda_py_test(
+ name = "keras_test",
+ srcs = ["keras_test.py"],
+ additional_deps = [
+ ":keras_test_lib",
],
tags = [
"multi_and_single_gpu",
+ "no_pip",
"no_windows_gpu",
"notsan",
],
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 2301ba9233..1133be6d0b 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -328,6 +328,10 @@ tpu_strategy = NamedDistribution(
"TPU", lambda: tpu_lib.TPUStrategy(
TPUClusterResolver(""), steps_per_run=5),
required_tpu=True)
+tpu_strategy_one_step = NamedDistribution(  # TPUStrategy variant constructed with steps_per_run=1.
+ "TPU", lambda: tpu_lib.TPUStrategy(  # NOTE(review): same name string ("TPU") as tpu_strategy above -- confirm whether a distinct name is needed wherever this name is consumed.
+ TPUClusterResolver(""), steps_per_run=1),
+ required_tpu=True)
# Note that we disable prefetching for testing since prefetching makes
# the input non-deterministic.
mirrored_strategy_with_gpu_and_cpu = NamedDistribution(
diff --git a/tensorflow/contrib/distribute/python/examples/keras_mnist.py b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
index 0495134636..a84ef04196 100644
--- a/tensorflow/contrib/distribute/python/examples/keras_mnist.py
+++ b/tensorflow/contrib/distribute/python/examples/keras_mnist.py
@@ -63,7 +63,6 @@ def get_input_datasets():
# eval dataset
eval_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
eval_ds = eval_ds.repeat()
- eval_ds = eval_ds.shuffle(100)
eval_ds = eval_ds.batch(64, drop_remainder=True)
return train_ds, eval_ds, input_shape
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 3cee3e37a7..d46f0eb276 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -18,9 +18,12 @@ from __future__ import division
from __future__ import print_function
import os
+from absl.testing import parameterized
import numpy as np
+from tensorflow.contrib.distribute.python import combinations
from tensorflow.contrib.distribute.python import mirrored_strategy
+from tensorflow.contrib.distribute.python import tpu_strategy
from tensorflow.contrib.distribute.python import values
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
@@ -91,6 +94,25 @@ def get_ds_test_input_fn():
return dataset
+def batch_wrapper(dataset, batch_size, distribution):  # Batch `dataset` by `batch_size`, choosing batch options from the type of `distribution`.
+ # TPUs currently require fully defined input shapes; drop_remainder=True
+ # ensures the input will have fully defined shapes.
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ return dataset.batch(batch_size, drop_remainder=True)
+ else:
+ return dataset.batch(batch_size)  # Any non-TPUStrategy value gets plain batch(batch_size).
+
+
+def all_combinations():  # Strategy/mode grid passed to combinations.generate by the parameterized tests in this file.
+ return combinations.combine(
+ distribution=[combinations.default_strategy,
+ combinations.one_device_strategy,
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus,
+ combinations.tpu_strategy_one_step],  # TPUStrategy built with steps_per_run=1 (see the combinations.py hunk).
+ mode=['graph'])  # 'graph' is the only mode listed.
+
+
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
def setUp(self):
@@ -175,7 +197,7 @@ class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase):
gfile.DeleteRecursively(self._config.model_dir)
-class TestWithDistributionStrategy(test.TestCase):
+class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
def test_validating_dataset_input_tensors_with_shape_mismatch(self):
with self.cached_session():
@@ -215,7 +237,8 @@ class TestWithDistributionStrategy(test.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
- def test_calling_model_on_same_dataset(self):
+ @combinations.generate(all_combinations())
+ def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
@@ -224,15 +247,13 @@ class TestWithDistributionStrategy(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
# Call fit with validation data
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
@@ -241,6 +262,9 @@ class TestWithDistributionStrategy(test.TestCase):
validation_data=dataset, validation_steps=2)
model.predict(dataset, steps=2)
+ # TODO(priyag): Enable this test for TPU. Currently tuple/dict inputs don't
+ # work because clone_model's input_tensors argument only seems to accept a
+ # list, not a tuple or dict.
def test_fit_with_tuple_and_dict_dataset_inputs(self):
with self.cached_session():
a = keras.layers.Input(shape=(3,), name='input_a')
@@ -282,7 +306,8 @@ class TestWithDistributionStrategy(test.TestCase):
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
- def test_fit_eval_and_predict_methods_on_dataset(self):
+ @combinations.generate(all_combinations())
+ def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
x = keras.layers.Input(shape=(3,), name='input')
y = keras.layers.Dense(4, name='dense')(x)
@@ -291,16 +316,13 @@ class TestWithDistributionStrategy(test.TestCase):
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
metrics = ['mae']
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:0',
- '/device:CPU:0'])
-
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((10, 3), dtype=np.float32)
targets = np.zeros((10, 4), dtype=np.float32)
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.repeat(100)
- dataset = dataset.batch(10)
+ dataset = batch_wrapper(dataset, 10, distribution)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
@@ -496,6 +518,8 @@ class TestWithDistributionStrategy(test.TestCase):
class LossMaskingWithDistributionStrategyTest(test.TestCase):
+ # TODO(priyag): Enable all strategies for this test. Currently it does not
+ # work for TPU due to some invalid datatype.
def test_masking(self):
with self.cached_session():
np.random.seed(1337)
@@ -519,24 +543,25 @@ class LossMaskingWithDistributionStrategyTest(test.TestCase):
self.assertEqual(hist.history['loss'][0], 0)
-class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
+class NormalizationLayerWithDistributionStrategyTest(
+ test.TestCase, parameterized.TestCase):
- def test_batchnorm_correctness(self):
+ @combinations.generate(all_combinations())
+ def test_batchnorm_correctness(self, distribution):
with self.cached_session():
model = keras.models.Sequential()
norm = keras.layers.BatchNormalization(input_shape=(10,), momentum=0.8)
model.add(norm)
- strategy = mirrored_strategy.MirroredStrategy(['/device:CPU:0',
- '/device:GPU:0'])
model.compile(loss='mse',
optimizer=gradient_descent.GradientDescentOptimizer(0.01),
- distribute=strategy)
+ distribute=distribution)
# centered on 5.0, variance 10.0
x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10))
+ x = x.astype('float32')
dataset = dataset_ops.Dataset.from_tensor_slices((x, x))
dataset = dataset.repeat(100)
- dataset = dataset.batch(32)
+ dataset = batch_wrapper(dataset, 32, distribution)
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
out = model.predict(dataset, steps=2)
@@ -546,9 +571,11 @@ class NormalizationLayerWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)
-class CorrectnessWithDistributionStrategyTest(test.TestCase):
+class CorrectnessWithDistributionStrategyTest(test.TestCase,
+ parameterized.TestCase):
- def test_correctness(self):
+ @combinations.generate(all_combinations())
+ def test_correctness(self, distribution):
with self.cached_session():
keras.backend.set_image_data_format('channels_last')
num_samples = 10000
@@ -557,43 +584,43 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
x_train = x_train.astype('float32')
y_train = y_train.astype('float32')
- model = keras.Sequential()
- model.add(keras.layers.Dense(1, input_shape=(1,)))
-
- # With DistributionStrategy
- dataset_with = dataset_ops.Dataset.from_tensor_slices((x_train, y_train))
- dataset_with = dataset_with.batch(32)
- strategy = mirrored_strategy.MirroredStrategy(devices=['/device:CPU:0',
- '/device:GPU:0'])
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5),
- distribute=strategy)
- model.fit(x=dataset_with, epochs=1, steps_per_epoch=310)
- wts_with_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_with = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
- predict_dataset_with = predict_dataset_with.batch(2)
- predict_with_ds = model.predict(predict_dataset_with, steps=1)
- predict_with_ds = np.reshape(predict_with_ds, (4, 1))
-
- # Without DistributionStrategy
- dataset_without = dataset_ops.Dataset.from_tensor_slices((x_train,
+ def fit_and_predict(with_distribution=None):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(1, input_shape=(1,)))
+ model.compile(
+ loss=keras.losses.mean_squared_error,
+ optimizer=gradient_descent.GradientDescentOptimizer(0.5),
+ distribute=with_distribution)
+
+ batch_size = 64
+ if with_distribution:
+ batch_size //= with_distribution.num_towers
+ train_dataset = dataset_ops.Dataset.from_tensor_slices((x_train,
y_train))
- dataset_without = dataset_without.batch(64)
-
- model.compile(loss=keras.losses.mean_squared_error,
- optimizer=gradient_descent.GradientDescentOptimizer(0.5))
- model.fit(x=dataset_without, epochs=1, steps_per_epoch=310)
- wts_without_ds = model.get_weights()
-
- x_predict = [[1], [2], [3], [4]]
- predict_dataset_without = dataset_ops.Dataset.from_tensor_slices((
- x_predict, x_predict))
- predict_dataset_without = predict_dataset_without.batch(4)
- predict_without_ds = model.predict(predict_dataset_without, steps=1)
+ train_dataset = batch_wrapper(train_dataset, batch_size, distribution)
+ # Running only 100 steps instead of the full dataset to keep test
+ # duration small.
+ model.fit(x=train_dataset, epochs=1, steps_per_epoch=100)
+
+ weights = model.get_weights()
+
+ x_predict = [[1.], [2.], [3.], [4.]]
+ predict_batch_size = 4
+ if with_distribution:
+ predict_batch_size //= with_distribution.num_towers
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
+ x_predict))
+ predict_dataset = batch_wrapper(predict_dataset,
+ predict_batch_size, distribution)
+ predict_result = model.predict(predict_dataset, steps=1)
+ predict_result = np.reshape(predict_result, (4, 1))
+
+ return weights, predict_result
+
+ wts_with_ds, predict_with_ds = fit_and_predict(
+ with_distribution=distribution)
+ wts_without_ds, predict_without_ds = fit_and_predict(
+ with_distribution=None)
# Verify that the weights are the same within some limits of tolerance.
np.testing.assert_allclose(wts_with_ds[0], wts_without_ds[0], rtol=1e-3)
@@ -602,5 +629,8 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase):
np.testing.assert_allclose(predict_with_ds, predict_without_ds, rtol=1e-3)
+# TODO(priyag): Add a test for TPUStrategy with steps_per_run > 1.
+
+
if __name__ == '__main__':
test.main()