diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-09-09 20:42:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-09 20:46:40 -0700 |
commit | 17a34ab8f214cd1f07d63ea238eda4ba3bf052c5 (patch) | |
tree | 931b7064ab59a801d5ac2a1f4d392ca9cc3b54f5 /tensorflow/contrib/distribute | |
parent | 231f34e3d8634ae02dae00af89d0ceafb3ada588 (diff) |
Add support for numpy arrays with DistributionStrategy in Keras.
PiperOrigin-RevId: 212210810
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- | tensorflow/contrib/distribute/python/keras_test.py | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py index d46f0eb276..9e1762d92c 100644 --- a/tensorflow/contrib/distribute/python/keras_test.py +++ b/tensorflow/contrib/distribute/python/keras_test.py @@ -237,6 +237,40 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase): distributed_training_utils.validate_distributed_dataset_inputs( strategy, x, y) + def test_calling_model_with_numpy_arrays(self): + with self.cached_session(): + x = keras.layers.Input(shape=(3,), name='input') + y = keras.layers.Dense(4, name='dense')(x) + model = keras.Model(x, y) + + optimizer = gradient_descent.GradientDescentOptimizer(0.001) + loss = 'mse' + metrics = ['mae'] + strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1', + '/device:GPU:0']) + model.compile(optimizer, loss, metrics=metrics, distribute=strategy) + + inputs = np.zeros((64, 3), dtype=np.float32) + targets = np.zeros((64, 4), dtype=np.float32) + + # Call fit with validation data + model.fit(inputs, targets, epochs=1, batch_size=2, verbose=0, + validation_data=(inputs, targets)) + + # TODO(anjalisridhar): We need tests for when the batch size and steps are + # smaller and results in a 0 batch_size and steps value. + model.evaluate(inputs, targets) + # with steps + model.evaluate(inputs, targets, steps=2) + # with batch_size + model.evaluate(inputs, targets, batch_size=8) + + model.predict(inputs) + # with steps + model.predict(inputs, steps=2) + # with batch_size + model.predict(inputs, batch_size=8) + @combinations.generate(all_combinations()) def test_calling_model_on_same_dataset(self, distribution): with self.cached_session(): |