aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-08-10 15:21:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 15:28:15 -0700
commit274e784585460a74b28297dc1e2febfb262457dd (patch)
tree1055ec3184fce0b4afd45575c755567e64db626e /tensorflow/contrib/eager
parent07e33d027f65c0c3e301ba9befbe82fc78cb3fa6 (diff)
Add tests for remote eager execution.
PiperOrigin-RevId: 208277185
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/BUILD14
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py178
2 files changed, 192 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 48b6f4acb2..f7933639a0 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -223,3 +223,17 @@ py_test(
"//tensorflow/python/eager:test",
],
)
+
+py_test(
+ name = "remote_test",
+ srcs = ["remote_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/eager:function",
+ ],
+)
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
new file mode 100644
index 0000000000..76f48eeb1c
--- /dev/null
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -0,0 +1,178 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for remote eager execution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import functools
+import os
+
+import numpy as np
+
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.core.protobuf import tensorflow_server_pb2
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import server_lib
+
+JOB_NAME = "remote_device"
+ALT_JOB_NAME = "alt_remote_device"
+
+
+def run_sync_and_async(f):
+ """Execute all test methods in the given class in sync and async modes."""
+
+ @functools.wraps(f)
+ def decorator(self, *args, **kwargs):
+ with context.execution_mode(context.ASYNC):
+ f(self, *args, **kwargs)
+
+ with context.execution_mode(context.SYNC):
+ f(self, *args, **kwargs)
+
+ return decorator
+
+
+def get_server_def(job_name, local_server_port, remote_server_addresses,
+ task_index):
+ """Returns a server def with a single job + multiple tasks."""
+ cluster_def = cluster_pb2.ClusterDef()
+ job_def = cluster_def.job.add()
+ job_def.name = job_name
+ job_def.tasks[0] = "localhost:%d" % local_server_port
+
+ for i, remote_server_address in enumerate(remote_server_addresses, start=1):
+ job_def.tasks[i] = remote_server_address
+
+ server_def = tensorflow_server_pb2.ServerDef(
+ cluster=cluster_def,
+ job_name=job_name,
+ task_index=task_index,
+ protocol="grpc")
+
+ return server_def
+
+
+class RemoteExecutionTest(test.TestCase):
+
+ def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
+ super(RemoteExecutionTest, self).__init__(methodName)
+ self._cached_server1 = server_lib.Server.create_local_server()
+ self._cached_server2 = server_lib.Server.create_local_server()
+
+ os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
+
+ self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
+ self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
+
+ # Start the local server.
+ context.set_server_def(
+ server_def=get_server_def(
+ JOB_NAME,
+ local_server_port=0,
+ remote_server_addresses=[
+ self._cached_server1_target, self._cached_server2_target
+ ],
+ task_index=0))
+
+ @run_sync_and_async
+ def testDefunMatmul(self):
+ """Basic remote eager execution with defun."""
+
+ mm_defun = function.defun(math_ops.matmul)
+ with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ x1 = array_ops.ones([2, 2])
+ with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME):
+ x2 = array_ops.ones([2, 2])
+ y = mm_defun(x1, x2)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+ @run_sync_and_async
+ def testSimpleMatmul(self):
+ """Basic remote eager execution."""
+
+ with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ x1 = array_ops.ones([2, 2])
+ with ops.device("job:%s/replica:0/task:2/device:CPU:0" % JOB_NAME):
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+ @run_sync_and_async
+ def testSimpleWeightRead(self):
+ """Basic remote eager weight read."""
+
+ with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ w = resource_variable_ops.ResourceVariable([[2.0]])
+ loss = w * w
+ np.testing.assert_array_equal([[4.0]], loss.numpy())
+
+ @run_sync_and_async
+ def testTapeWeightRead(self):
+ """Remote eager weight read in a tape."""
+
+ with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ w = resource_variable_ops.ResourceVariable([[3.0]])
+ with backprop.GradientTape() as tape:
+ loss = w * w
+
+ grad = tape.gradient(loss, w)
+ np.testing.assert_array_equal([[9.0]], loss.numpy())
+ np.testing.assert_array_equal([[6.0]], grad.numpy())
+
+ @run_sync_and_async
+ def testServerDefChanged(self):
+ """Update server def, and run ops on new cluster."""
+ context.set_server_def(
+ server_def=get_server_def(
+ ALT_JOB_NAME,
+ local_server_port=0,
+ remote_server_addresses=[
+ self._cached_server1_target, self._cached_server2_target
+ ],
+ task_index=0))
+
+ with ops.device("job:%s/replica:0/task:1/device:CPU:0" % ALT_JOB_NAME):
+ x1 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x1)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+ # Set the server def back to JOB_NAME
+ context.set_server_def(
+ server_def=get_server_def(
+ JOB_NAME,
+ local_server_port=0,
+ remote_server_addresses=[
+ self._cached_server1_target, self._cached_server2_target
+ ],
+ task_index=0))
+
+ with ops.device("job:%s/replica:0/task:1/device:CPU:0" % JOB_NAME):
+ x1 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x1)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
+
+if __name__ == "__main__":
+ ops.enable_eager_execution()
+ test.main()