aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-10-02 14:30:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 14:35:06 -0700
commitc921e45bccac86ce0becc71cedc3da2c702d5c38 (patch)
tree0a460ab691dd66600bdfee5ecfd68c0666bb7095 /tensorflow/contrib/distribute
parente45c90f0e4d17ac22048a73f1e81bd9c7a7a5145 (diff)
Add support for multiple input/output numpy arrays when using Keras APIs.
PiperOrigin-RevId: 215459075
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/BUILD1
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py88
2 files changed, 75 insertions, 14 deletions
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index cfb9d42a6f..defa82f98a 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -728,6 +728,7 @@ cuda_py_test(
additional_deps = [
":keras_test_lib",
],
+ shard_count = 16,
tags = [
"multi_and_single_gpu",
"no_pip",
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 3aab2c521f..993cb2bac3 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -189,6 +189,14 @@ def get_dataset(distribution):
return dataset
+def get_predict_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
strategies = [combinations.default_strategy,
combinations.one_device_strategy,
combinations.mirrored_strategy_with_gpu_and_cpu,
@@ -387,16 +395,26 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
- def test_calling_model_with_numpy_arrays(self):
+ # TODO(anjalisridhar): Move this test along with other numpy related tests to
+ # its own class.
+ @combinations.generate(strategy_combinations())
+ def test_creating_var_with_numpy_arrays(self, distribution):
+ with self.cached_session():
+ x = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ var_x = distributed_training_utils.get_var_for_numpy(distribution, x)
+ val = self.evaluate(var_x.value())
+ # Verify that the numpy value is copied to the variable.
+ self.assertAllEqual(x, val)
+
+ @combinations.generate(strategy_combinations())
+ def test_calling_model_with_numpy_arrays(self, distribution):
with self.cached_session():
model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae', keras.metrics.CategoricalAccuracy()]
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)
@@ -420,6 +438,48 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.predict(inputs, batch_size=8)
@combinations.generate(strategy_combinations())
+ def test_calling_model_with_nested_numpy_arrays(self, distribution):
+ with self.cached_session():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ input_b_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ inputs = [input_a_np, input_b_np]
+
+ output_d_np = np.asarray(np.random.random((64, 4)), dtype=np.float32)
+ output_e_np = np.asarray(np.random.random((64, 4)), dtype=np.float32)
+ targets = [output_d_np, output_e_np]
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0)
+
+ # TODO(anjalisridhar): We need tests for when the batch size and steps are
+ # smaller and results in a 0 batch_size and steps value.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
model = get_model()
@@ -436,7 +496,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
validation_data=dataset, validation_steps=2)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
- model.predict(dataset, steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
# TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work
# as clone_model's input_tensors argument only seems to accept list and not
@@ -496,10 +556,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
- # Test with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
@combinations.generate(strategy_and_optimizer_combinations())
def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
@@ -513,7 +570,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
def test_unsupported_features(self):
with self.cached_session():
@@ -726,8 +783,12 @@ class NormalizationLayerWithDistributionStrategyTest(
dataset = dataset.repeat(100)
dataset = batch_wrapper(dataset, 32, distribution)
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(x)
+ predict_dataset = predict_dataset.repeat(100)
+ predict_dataset = batch_wrapper(predict_dataset, 32, distribution)
+
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
- out = model.predict(dataset, steps=2)
+ out = model.predict(predict_dataset, steps=2)
out -= keras.backend.eval(norm.beta)
out /= keras.backend.eval(norm.gamma)
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
@@ -811,8 +872,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
predict_batch_size = 4
if with_distribution:
predict_batch_size //= with_distribution.num_towers
- predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
predict_dataset = batch_wrapper(predict_dataset,
predict_batch_size, distribution)
predict_result = model.predict(predict_dataset, steps=1)