aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/multi_gpu_utils_test.py
diff options
context:
space:
mode:
authorGravatar avijit-nervana <avijit.chakraborty@intel.com>2018-08-28 16:12:57 -0700
committerGravatar avijit-nervana <avijit.chakraborty@intel.com>2018-08-28 16:12:57 -0700
commit757538bd14f24de3d7bf654a03c6543bb06a8e75 (patch)
tree4873885feca3a3e5787241477ef8d1333c494d1e /tensorflow/python/keras/utils/multi_gpu_utils_test.py
parent6b25c37daaa6a063b6b687252343db5453a84b8b (diff)
parent7f52de1a2b03568dc98ad51685b56661a5105da6 (diff)
Merge branch 'master' into avijit/add-cpu-backend
Diffstat (limited to 'tensorflow/python/keras/utils/multi_gpu_utils_test.py')
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 77792d14f5..c7e94998b4 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -180,6 +180,23 @@ class TestMultiGPUModel(test.TestCase):
target_tensors=[targets])
parallel_model.fit(epochs=1, steps_per_epoch=3)
+ def test_multi_gpu_with_multi_input_layers(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.test_session():
+ inputs = keras.Input((4, 3))
+ init_state = keras.Input((3,))
+ outputs = keras.layers.SimpleRNN(
+ 3, return_sequences=True)(inputs, initial_state=init_state)
+ x = [np.random.randn(2, 4, 3), np.random.randn(2, 3)]
+ y = np.random.randn(2, 4, 3)
+ model = keras.Model([inputs, init_state], outputs)
+ parallel_model = keras.utils.multi_gpu_model(model, gpus=gpus)
+ parallel_model.compile(loss='mean_squared_error', optimizer='adam')
+ parallel_model.train_on_batch(x, y)
if __name__ == '__main__':
test.main()