aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-09-09 20:42:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-09 20:46:40 -0700
commit17a34ab8f214cd1f07d63ea238eda4ba3bf052c5 (patch)
tree931b7064ab59a801d5ac2a1f4d392ca9cc3b54f5 /tensorflow/contrib/distribute
parent231f34e3d8634ae02dae00af89d0ceafb3ada588 (diff)
Add support for numpy arrays with DistributionStrategy in Keras.
PiperOrigin-RevId: 212210810
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index d46f0eb276..9e1762d92c 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -237,6 +237,40 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
+ def test_calling_model_with_numpy_arrays(self):
+ with self.cached_session():
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae']
+ strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
+ '/device:GPU:0'])
+ model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+
+ inputs = np.zeros((64, 3), dtype=np.float32)
+ targets = np.zeros((64, 4), dtype=np.float32)
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0,
+ validation_data=(inputs, targets))
+
+ # TODO(anjalisridhar): We need tests for when the batch size and steps are
+ # smaller and results in a 0 batch_size and steps value.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
@combinations.generate(all_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():