aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Akshay Modi <nareshmodi@google.com>2018-08-24 11:21:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-24 11:35:34 -0700
commit0eade17dc823ed26c0b82987de61f61c8f1886a7 (patch)
tree1edd2fbb71387d269f77b0ef2539ba5b48e8cead /tensorflow/contrib/eager
parent197309b5d56436b523b8b03ddf2a23555c37365e (diff)
Add a helper to be able to connect to cloud TPUs easily in the colab env.
PiperOrigin-RevId: 210127772
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/BUILD14
-rw-r--r--tensorflow/contrib/eager/python/remote.py73
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py13
-rw-r--r--tensorflow/contrib/eager/python/tfe.py3
4 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index fa3f1bb7ad..84517b57c7 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -14,6 +14,7 @@ py_library(
":datasets",
":metrics",
":network",
+ ":remote",
":saver",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
@@ -223,11 +224,24 @@ py_test(
],
)
+py_library(
+ name = "remote",
+ srcs = ["remote.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
py_test(
name = "remote_test",
srcs = ["remote_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":remote",
"//tensorflow/contrib/eager/python:tfe",
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py
new file mode 100644
index 0000000000..b74cf394f6
--- /dev/null
+++ b/tensorflow/contrib/eager/python/remote.py
@@ -0,0 +1,73 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helpers to connect to remote servers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.core.protobuf.cluster_pb2 import ClusterDef
+from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
+from tensorflow.python.eager import context
+
+
+def connect_to_remote_host(remote_host=None, job_name="worker"):
+  """Connects to a single machine to enable remote execution on it.
+
+  Will make devices on the remote host available to use. Note that calling this
+  more than once will work, but will invalidate any tensor handles on the old
+  remote devices.
+
+  Using the default job_name of worker, you can schedule ops to run remotely as
+  follows:
+  ```python
+  # Enable eager execution, and connect to the remote host.
+  tf.enable_eager_execution()
+  tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876")
+
+  with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+    # The following tensors should be resident on the remote device, and the op
+    # will also execute remotely.
+    x1 = array_ops.ones([2, 2])
+    x2 = array_ops.ones([2, 2])
+    y = math_ops.matmul(x1, x2)
+  ```
+
+  Args:
+    remote_host: The addr of the remote server in host-port format.
+    job_name: The job name under which the new server will be accessible.
+
+  Raises:
+    ValueError: if remote_host is None or empty.
+  """
+  if not remote_host:
+    raise ValueError("Must provide a remote_host")
+  cluster_def = ClusterDef()
+  job_def = cluster_def.job.add()
+  job_def.name = job_name
+  # Task 0 is this process (see task_index below); the remote host is task 1,
+  # which is why the example above addresses it as ".../task:1/...".
+  job_def.tasks[0] = "127.0.0.1:0"
+  job_def.tasks[1] = remote_host
+
+  server_def = ServerDef(
+      cluster=cluster_def, job_name=job_name, task_index=0, protocol="grpc")
+
+  # TODO(nareshmodi): Make this default since it works in more situations.
+  # Side effect: this process-wide env var stays set after the call returns.
+  os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1"
+  context.set_server_def(server_def)
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index 76f48eeb1c..13029db975 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -23,6 +23,7 @@ import os
import numpy as np
+from tensorflow.contrib.eager.python import remote
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.eager import backprop
@@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase):
self._cached_server1_target = self._cached_server1.target[len("grpc://"):]
self._cached_server2_target = self._cached_server2.target[len("grpc://"):]
+ def setUp(self):
# Start the local server.
context.set_server_def(
server_def=get_server_def(
@@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase):
y = math_ops.matmul(x1, x1)
np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+ @run_sync_and_async
+ def testConnectToRemoteServer(self):
+ """Runs a matmul on task 1 after remote.connect_to_remote_host."""
+ remote.connect_to_remote_host(self._cached_server1_target)
+
+ with ops.device("job:worker/replica:0/task:1/device:CPU:0"):
+ x1 = array_ops.ones([2, 2])
+ x2 = array_ops.ones([2, 2])
+ y = math_ops.matmul(x1, x2)
+ np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy())
+
if __name__ == "__main__":
ops.enable_eager_execution()
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 4dfd083443..fe7f1b72fc 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -74,6 +74,8 @@ To use, at program startup, call `tf.enable_eager_execution()`.
@@TensorSpec
+@@connect_to_remote_host
+
@@DEVICE_PLACEMENT_EXPLICIT
@@DEVICE_PLACEMENT_WARN
@@DEVICE_PLACEMENT_SILENT
@@ -94,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network
from tensorflow.contrib.eager.python.network import Sequential
from tensorflow.contrib.eager.python.network import save_network_checkpoint
from tensorflow.contrib.eager.python.network import restore_network_checkpoint
+from tensorflow.contrib.eager.python.remote import connect_to_remote_host
from tensorflow.contrib.eager.python.saver import get_optimizer_variables
from tensorflow.contrib.eager.python.saver import restore_variables_on_create
from tensorflow.contrib.eager.python.saver import Saver