diff options
author | Akshay Modi <nareshmodi@google.com> | 2018-08-24 11:21:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-24 11:35:34 -0700 |
commit | 0eade17dc823ed26c0b82987de61f61c8f1886a7 (patch) | |
tree | 1edd2fbb71387d269f77b0ef2539ba5b48e8cead /tensorflow/contrib/eager | |
parent | 197309b5d56436b523b8b03ddf2a23555c37365e (diff) |
Add a helper to be able to connect to cloud TPUs easily in the colab env.
PiperOrigin-RevId: 210127772
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/BUILD | 14 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/remote.py | 73 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/remote_test.py | 13 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/tfe.py | 3 |
4 files changed, 103 insertions, 0 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index fa3f1bb7ad..84517b57c7 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -14,6 +14,7 @@ py_library( ":datasets", ":metrics", ":network", + ":remote", ":saver", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", @@ -223,11 +224,24 @@ py_test( ], ) +py_library( + name = "remote", + srcs = ["remote.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + "//tensorflow/python/eager:context", + ], +) + py_test( name = "remote_test", srcs = ["remote_test.py"], srcs_version = "PY2AND3", deps = [ + ":remote", "//tensorflow/contrib/eager/python:tfe", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/eager/python/remote.py b/tensorflow/contrib/eager/python/remote.py new file mode 100644 index 0000000000..b74cf394f6 --- /dev/null +++ b/tensorflow/contrib/eager/python/remote.py @@ -0,0 +1,73 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to connect to remote servers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.protobuf.cluster_pb2 import ClusterDef +from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef +from tensorflow.python.eager import context + + +def connect_to_remote_host(remote_host=None, job_name="worker"): + """Connects to a single machine to enable remote execution on it. + + Will make devices on the remote host available to use. Note that calling this + more than once will work, but will invalidate any tensor handles on the old + remote devices. + + Using the default job_name of worker, you can schedule ops to run remotely as + follows: + ```python + # Enable eager execution, and connect to the remote host. + tf.enable_eager_execution() + tf.contrib.eager.connect_to_remote_host("exampleaddr.com:9876") + + with ops.device("job:worker/replica:0/task:1/device:CPU:0"): + # The following tensors should be resident on the remote device, and the op + # will also execute remotely. + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + ``` + + Args: + remote_host: The addr of the remote server in host-port format. + job_name: The job name under which the new server will be accessible. + + Raises: + ValueError: if remote_host is None. + """ + if remote_host is None: + raise ValueError("Must provide an remote_host") + cluster_def = ClusterDef() + job_def = cluster_def.job.add() + job_def.name = job_name + job_def.tasks[0] = "127.0.0.1:0" + job_def.tasks[1] = remote_host + + server_def = ServerDef( + cluster=cluster_def, + job_name=job_name, + task_index=0, + protocol="grpc") + + # TODO(nareshmodi): Make this default since it works in more situations. + os.environ["TF_EAGER_REMOTE_USE_SEND_TENSOR_RPC"] = "1" + context.set_server_def(server_def) diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py index 76f48eeb1c..13029db975 100644 --- a/tensorflow/contrib/eager/python/remote_test.py +++ b/tensorflow/contrib/eager/python/remote_test.py @@ -23,6 +23,7 @@ import os import numpy as np +from tensorflow.contrib.eager.python import remote from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import tensorflow_server_pb2 from tensorflow.python.eager import backprop @@ -85,6 +86,7 @@ class RemoteExecutionTest(test.TestCase): self._cached_server1_target = self._cached_server1.target[len("grpc://"):] self._cached_server2_target = self._cached_server2.target[len("grpc://"):] + def setUp(self): # Start the local server. context.set_server_def( server_def=get_server_def( @@ -172,6 +174,17 @@ class RemoteExecutionTest(test.TestCase): y = math_ops.matmul(x1, x1) np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + @run_sync_and_async + def testConnectToRemoteServer(self): + """Basic server connection.""" + remote.connect_to_remote_host(self._cached_server1_target) + + with ops.device("job:worker/replica:0/task:1/device:CPU:0"): + x1 = array_ops.ones([2, 2]) + x2 = array_ops.ones([2, 2]) + y = math_ops.matmul(x1, x2) + np.testing.assert_array_equal([[2, 2], [2, 2]], y.numpy()) + if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 4dfd083443..fe7f1b72fc 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -74,6 +74,8 @@ To use, at program startup, call `tf.enable_eager_execution()`. @@TensorSpec +@@connect_to_cloud_tpu + @@DEVICE_PLACEMENT_EXPLICIT @@DEVICE_PLACEMENT_WARN @@DEVICE_PLACEMENT_SILENT @@ -94,6 +96,7 @@ from tensorflow.contrib.eager.python.network import Network from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint +from tensorflow.contrib.eager.python.remote import connect_to_remote_host from tensorflow.contrib.eager.python.saver import get_optimizer_variables from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver |