aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-02-25 20:10:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-25 20:15:52 -0800
commit00986d48bb646daab659503ad3a713919865f32d (patch)
tree3179208eda8426b346db591f7d98fd836a20f384
parentd27da251bcc4bab7da2f5aecc509b146f9fa1692 (diff)
Initial version of the open-source distributed TensorFlow runtime.
This includes a gRPC server (grpc_tensorflow_server) that can serve as both the master of a distributed TensorFlow computation, and an individual worker in the computation. The GrpcSession class is included to allow client programs (including Python clients) to interact with a server. See tensorflow/core/distributed_runtime/README.md for usage instructions. This change partially addresses issue #23. Change: 115634191
-rw-r--r--WORKSPACE31
-rw-r--r--tensorflow/core/BUILD59
-rw-r--r--tensorflow/core/distributed_runtime/BUILD306
-rw-r--r--tensorflow/core/distributed_runtime/README.md197
-rw-r--r--tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc318
-rw-r--r--tensorflow/core/distributed_runtime/base_rendezvous_mgr.h212
-rw-r--r--tensorflow/core/distributed_runtime/build_graph_options.cc38
-rw-r--r--tensorflow/core/distributed_runtime/build_graph_options.h38
-rw-r--r--tensorflow/core/distributed_runtime/call_options.cc44
-rw-r--r--tensorflow/core/distributed_runtime/call_options.h72
-rw-r--r--tensorflow/core/distributed_runtime/call_options_test.cc39
-rw-r--r--tensorflow/core/distributed_runtime/executor_test.cc407
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc368
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.h147
-rw-r--r--tensorflow/core/distributed_runtime/master.cc413
-rw-r--r--tensorflow/core/distributed_runtime/master.h98
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h66
-rw-r--r--tensorflow/core/distributed_runtime/master_interface.h52
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc942
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h38
-rw-r--r--tensorflow/core/distributed_runtime/master_session_interface.h76
-rw-r--r--tensorflow/core/distributed_runtime/master_test.cc423
-rw-r--r--tensorflow/core/distributed_runtime/process_util.cc69
-rw-r--r--tensorflow/core/distributed_runtime/process_util.h39
-rw-r--r--tensorflow/core/distributed_runtime/remote_device.cc91
-rw-r--r--tensorflow/core/distributed_runtime/remote_device.h48
-rw-r--r--tensorflow/core/distributed_runtime/remote_device_test.cc89
-rw-r--r--tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h79
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD341
-rw-r--r--tensorflow/core/distributed_runtime/rpc/async_service_interface.h37
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_call.h227
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.cc314
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h98
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc137
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h56
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc181
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.h33
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc79
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h27
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc203
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h38
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc116
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h53
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc233
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h97
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc750
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc98
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc123
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc84
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.h73
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc91
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc92
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_util.h48
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc85
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h28
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc415
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h34
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc196
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h57
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc172
-rw-r--r--tensorflow/core/distributed_runtime/simple_graph_execution_state.cc309
-rw-r--r--tensorflow/core/distributed_runtime/simple_graph_execution_state.h156
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache.h75
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_logger.cc110
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_logger.h81
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_partial.cc98
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_partial.h56
-rw-r--r--tensorflow/core/distributed_runtime/worker_env.h62
-rw-r--r--tensorflow/core/distributed_runtime/worker_interface.h129
-rw-r--r--tensorflow/core/framework/load_library.cc2
-rw-r--r--tensorflow/core/platform/default/build_config.bzl51
-rw-r--r--tensorflow/core/protobuf/master.proto190
-rw-r--r--tensorflow/core/protobuf/master_service.proto105
-rw-r--r--tensorflow/core/protobuf/worker.proto311
-rw-r--r--tensorflow/core/protobuf/worker_service.proto67
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/client/session_test.py5
-rw-r--r--tensorflow/python/ops/nn.py8
-rw-r--r--tensorflow/python/ops/nn_test.py28
79 files changed, 11222 insertions, 37 deletions
diff --git a/WORKSPACE b/WORKSPACE
index 2e1b018e14..26bfa1f15f 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -15,6 +15,37 @@
load("//tensorflow:workspace.bzl", "tf_workspace")
tf_workspace()
+# grpc expects //external:protobuf_clib and //external:protobuf_compiler
+# to point to the protobuf's compiler library.
+bind(
+ name = "protobuf_clib",
+ actual = "//google/protobuf:protoc_lib",
+)
+
+bind(
+ name = "protobuf_compiler",
+ actual = "//google/protobuf:protoc_lib",
+)
+
+git_repository(
+ name = "grpc",
+ commit = "73979f4",
+ init_submodules = True,
+ remote = "https://github.com/grpc/grpc.git",
+)
+
+# protobuf expects //external:grpc_cpp_plugin to point to grpc's
+# C++ plugin code generator.
+bind(
+ name = "grpc_cpp_plugin",
+ actual = "@grpc//:grpc_cpp_plugin",
+)
+
+bind(
+ name = "grpc_lib",
+ actual = "@grpc//:grpc++_unsecure",
+)
+
# TENSORBOARD_BOWER_AUTOGENERATED_BELOW_THIS_LINE_DO_NOT_EDIT
new_git_repository(
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 649b21edcb..54c3270640 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -61,6 +61,7 @@ load("//tensorflow:tensorflow.bzl", "tf_gpu_kernel_library")
load(
"//tensorflow/core:platform/default/build_config.bzl",
"tf_proto_library",
+ "tf_proto_library_cc",
"tf_additional_lib_srcs",
"tf_additional_stream_executor_srcs",
"tf_additional_test_deps",
@@ -77,7 +78,15 @@ load(
tf_proto_library(
name = "protos_all",
- srcs = glob(["**/*.proto"]),
+ srcs = glob(
+ ["**/*.proto"],
+ exclude = [
+ "protobuf/worker.proto",
+ "protobuf/worker_service.proto",
+ "protobuf/master.proto",
+ "protobuf/master_service.proto",
+ ],
+ ),
cc_api_version = 2,
go_api_version = 2,
java_api_version = 2,
@@ -85,6 +94,54 @@ tf_proto_library(
visibility = ["//visibility:public"],
)
+tf_proto_library_cc(
+ name = "worker_proto",
+ srcs = ["protobuf/worker.proto"],
+ cc_api_version = 2,
+ cc_libs = [":protos_all_cc"],
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "worker_service_proto",
+ srcs = ["protobuf/worker_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_grpc_version = 1,
+ cc_libs = [":worker_proto_cc"],
+ cc_stubby_versions = ["2"],
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "master_proto",
+ srcs = ["protobuf/master.proto"],
+ cc_api_version = 2,
+ cc_libs = [":protos_all_cc"],
+ py_api_version = 2,
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
+tf_proto_library_cc(
+ name = "master_service_proto",
+ srcs = ["protobuf/master_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_grpc_version = 1,
+ cc_libs = [":master_proto_cc"],
+ cc_stubby_versions = ["2"],
+ py_api_version = 2,
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
cc_library(
name = "lib",
hdrs = [
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
new file mode 100644
index 0000000000..00d97a6ef9
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -0,0 +1,306 @@
+# Description:
+# A distributed runtime for TensorFlow, which allows graph execution
+# to be distributed and performed in parallel across multiple
+# processes.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_tests")
+
+# For platform specific build config
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_kernel_tests_linkstatic",
+)
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+
+package(default_visibility = [
+ "//tensorflow:internal",
+])
+
+cc_library(
+ name = "worker_env",
+ hdrs = ["worker_env.h"],
+ deps = [],
+)
+
+cc_library(
+ name = "worker_interface",
+ hdrs = ["worker_interface.h"],
+ deps = [
+ ":call_options",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "call_options",
+ srcs = ["call_options.cc"],
+ hdrs = ["call_options.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "call_options_test",
+ size = "small",
+ srcs = ["call_options_test.cc"],
+ deps = [
+ ":call_options",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "worker_cache",
+ hdrs = ["worker_cache.h"],
+ deps = ["//tensorflow/core:protos_all_cc"],
+)
+
+cc_library(
+ name = "remote_device",
+ srcs = ["remote_device.cc"],
+ hdrs = ["remote_device.h"],
+ deps = [
+ ":process_util",
+ ":worker_cache",
+ ":worker_interface",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "master_interface",
+ hdrs = ["master_interface.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "master",
+ srcs = ["master.cc"],
+ hdrs = ["master.h"],
+ deps = [
+ ":call_options",
+ ":master_env",
+ ":master_session_interface",
+ ":process_util",
+ ":remote_device",
+ ":worker_cache",
+ ":worker_interface",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "master_session",
+ srcs = ["master_session.cc"],
+ hdrs = ["master_session.h"],
+ deps = [
+ ":master_env",
+ ":master_session_interface",
+ ":process_util",
+ ":simple_graph_execution_state",
+ ":worker_cache",
+ ":worker_interface",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+cc_library(
+ name = "build_graph_options",
+ srcs = ["build_graph_options.cc"],
+ hdrs = ["build_graph_options.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "simple_graph_execution_state",
+ srcs = ["simple_graph_execution_state.cc"],
+ hdrs = ["simple_graph_execution_state.h"],
+ deps = [
+ ":build_graph_options",
+ ":process_util",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "rendezvous_mgr_interface",
+ srcs = [],
+ hdrs = ["rendezvous_mgr_interface.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ ],
+)
+
+cc_library(
+ name = "master_session_interface",
+ srcs = [],
+ hdrs = ["master_session_interface.h"],
+ deps = ["//tensorflow/core:lib"],
+)
+
+cc_library(
+ name = "base_rendezvous_mgr",
+ srcs = ["base_rendezvous_mgr.cc"],
+ hdrs = ["base_rendezvous_mgr.h"],
+ deps = [
+ ":process_util",
+ ":rendezvous_mgr_interface",
+ ":worker_cache",
+ ":worker_env",
+ ":worker_interface",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ ],
+)
+
+cc_library(
+ name = "master_env",
+ hdrs = ["master_env.h"],
+)
+
+cc_library(
+ name = "graph_mgr",
+ srcs = ["graph_mgr.cc"],
+ hdrs = ["graph_mgr.h"],
+ deps = [
+ ":process_util",
+ ":rendezvous_mgr_interface",
+ ":worker_env",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "process_util",
+ srcs = ["process_util.cc"],
+ hdrs = ["process_util.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:tensorflow_opensource",
+ ],
+)
+
+cc_library(
+ name = "worker_cache_partial",
+ srcs = ["worker_cache_partial.cc"],
+ hdrs = ["worker_cache_partial.h"],
+ deps = [
+ ":process_util",
+ ":worker_cache",
+ ":worker_interface",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:worker_proto_cc",
+ ],
+)
+
+cc_library(
+ name = "worker_cache_logger",
+ srcs = ["worker_cache_logger.cc"],
+ hdrs = ["worker_cache_logger.h"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+# TODO(mrry): Move executor_test.cc to ../common_runtime once it no longer
+# depends on grpc_testlib.
+tf_cc_tests(
+ linkstatic = tf_kernel_tests_linkstatic(),
+ tags = tf_cuda_tests_tags(),
+ tests = [
+ "executor_test.cc",
+ "master_test.cc",
+ "remote_device_test.cc",
+ ],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":master",
+ ":process_util",
+ ":remote_device",
+ ":worker_interface",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:master_service_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_channel",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_testlib",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_util",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
+ ],
+)
diff --git a/tensorflow/core/distributed_runtime/README.md b/tensorflow/core/distributed_runtime/README.md
new file mode 100644
index 0000000000..66433e352a
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/README.md
@@ -0,0 +1,197 @@
+# Distributed TensorFlow
+
+This directory contains the initial open-source implementation of the
+distributed TensorFlow runtime, using [gRPC](http://grpc.io) for inter-process
+communication.
+
+## Quick start
+
+To get started, you will need to build the TensorFlow server binary
+(`grpc_tensorflow_server`) and a gRPC-based client. Currently this is only
+available using the source-based installation of TensorFlow, but it will be
+included in future binary releases. You can build the server binary using one of
+the following commands:
+
+```shell
+# CPU-only build.
+$ bazel build -c opt //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server
+
+# GPU build.
+$ bazel build -c opt --config=cuda //tensorflow/core/distributed_runtime/rpc:grpc_tensorflow_server
+```
+
+If you build the latest Python (PIP) package from source, it will contain a
+gRPC-based client. If you are using a previous binary release, you may need to
+rebuild and install an up-to-date PIP package by following
+[these installation instructions](https://www.tensorflow.org/versions/master/get_started/os_setup.html#create-the-pip-package-and-install).
+
+Once you have successfully built the distributed TensorFlow components, you can
+test your installation by starting a server as follows:
+
+```shell
+# Start a TensorFlow server as a single-process "cluster".
+$ bazel-bin/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server \
+ --cluster_spec='local|localhost:2222' --job_name=local --task_index=0 &
+```
+
+...then start a Python interpreter and create a remote session:
+
+```python
+$ python
+>>> import tensorflow as tf
+>>> c = tf.constant("Hello, distributed TensorFlow!")
+>>> sess = tf.Session("grpc://localhost:2222")
+>>> sess.run(c)
+'Hello, distributed TensorFlow!'
+```
+
+## Cluster definition
+
+The command-line arguments to `grpc_tensorflow_server` define the membership of a TensorFlow cluster. The `--cluster_spec` flag determines the set of processes in the cluster, as a list of *jobs*, each of which contains a list of *task* endpoints. All processes in the cluster must be started with the same `--cluster_spec`. Example values include:
+
+<table>
+ <tr><th><code>--cluster_spec='...'</code></th><th>Available tasks</th></tr>
+ <tr>
+ <td><code>local|localhost:2222</code></td><td><code>/job:local/task:0</code></td>
+ </tr>
+ <tr>
+ <td><code>local|localhost:2222;localhost:2223</code></td><td><code>/job:local/task:0</code><br/><code>/job:local/task:1</code></td>
+ </tr>
+ <tr>
+ <td><code>worker|worker0:2222;worker1:2222;worker2:2222,</code><br/><code>ps|ps0:2222;ps1:2222</code></td><td><code>/job:worker/task:0</code><br/><code>/job:worker/task:1</code><br/><code>/job:worker/task:2</code><br/><code>/job:ps/task:0</code><br/><code>/job:ps/task:1</code></td>
+ </tr>
+</table>
+
+The `--job_name` and `--task_index` flags indicate which task will run in this
+process, out of the jobs and tasks defined in `--cluster_spec`. For example,
+`--job_name=local --task_index=0` means that the process will be task
+`/job:local/task:0`, and TensorFlow devices in the process will have names
+starting with that prefix.
+
+**N.B.** Manually specifying these command lines can be tedious, especially for
+large clusters. We are working on tools for launching tasks programmatically,
+e.g. using a cluster manager like [Kubernetes](http://kubernetes.io). If there
+are particular cluster managers for which you'd like to see support, please
+raise a [GitHub issue](https://github.com/tensorflow/tensorflow/issues).
+
+## Specifying distributed devices in your model
+
+To place operations on a particular process, you can use the same
+[`tf.device()`](https://www.tensorflow.org/versions/master/api_docs/python/framework.html#device)
+function that is used to specify whether ops run on the CPU or GPU. For example:
+
+```python
+with tf.device("/job:ps/task:0"):
+ weights_1 = tf.Variable(...)
+ biases_1 = tf.Variable(...)
+
+with tf.device("/job:ps/task:1"):
+ weights_2 = tf.Variable(...)
+ biases_2 = tf.Variable(...)
+
+with tf.device("/job:worker/task:7"):
+ input, labels = ...
+ layer_1 = tf.nn.relu(tf.matmul(input, weights_1) + biases_1)
+ logits = tf.nn.relu(tf.matmul(layer_1, weights_2) + biases_2)
+ # ...
+ train_op = ...
+
+with tf.Session("grpc://worker7:2222") as sess:
+ for _ in range(10000):
+ sess.run(train_op)
+```
+
+In the above example, the variables are created on two tasks in the `ps` job,
+and the compute-intensive part of the model is created in the `worker`
+job. TensorFlow will insert the appropriate data transfers between the jobs
+(from `ps` to `worker` for the forward pass, and from `worker` to `ps` for
+applying gradients).
+
+## Replicated training
+
+A common training configuration ("data parallel training") involves multiple
+tasks in a `worker` job training the same model, using shared parameters hosted
+in one or more tasks in a `ps` job. Each task will typically run on a
+different machine. There are many ways to specify this structure in TensorFlow,
+and we are building libraries that will simplify the work of specifying a
+replicated model. Possible approaches include:
+
+* Building a single graph containing one set of parameters (in `tf.Variable`
+ nodes pinned to `/job:ps`), and multiple copies of the "model" pinned to
+ different tasks in `/job:worker`. Each copy of the model can have a different
+ `train_op`, and one or more client threads can call `sess.run(train_ops[i])`
+ for each worker `i`. This implements *asynchronous* training.
+
+ This approach uses a single `tf.Session` whose target is one of the workers in
+ the cluster.
+
+* As above, but where the gradients from all workers are averaged. See the
+ [CIFAR-10 multi-GPU trainer](https://www.tensorflow.org/code/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py)
+ for an example of this form of replication. This implements *synchronous* training.
+
+* The "distributed trainer" approach uses multiple graphs&mdash;one per
+ worker&mdash;where each graph contains one set of parameters (pinned to
+ `/job:ps`) and one copy of the model (pinned to a particular
+ `/job:worker/task:i`). The "container" mechanism is used to share variables
+ between different graphs: when each variable is constructed, the optional
+ `container` argument is specified with the same value in each copy of the
+ graph. For large models, this can be more efficient, because the overall graph
+ is smaller.
+
+ This approach uses multiple `tf.Session` objects: one per worker process,
+ where the `target` of each is the address of a different worker. The
+ `tf.Session` objects can all be created in a single Python client, or you can
+ use multiple Python clients to better distribute the trainer load.
+
+## Glossary
+
+<dl>
+ <dt>Client</dt>
+ <dd>
+ A client is typically a program that builds a TensorFlow graph and
+ constructs a <code>tensorflow::Session</code> to interact with a cluster. Clients are
+ typically written in Python or C++. A single client process can directly
+ interact with multiple TensorFlow servers (see "Replicated training" above),
+ and a single server can serve multiple clients.
+ </dd>
+ <dt>Cluster</dt>
+ <dd>
+ A TensorFlow cluster comprises one or more TensorFlow servers, divided into
+ a set of named jobs, which in turn comprise lists of tasks. A cluster is
+ typically dedicated to a particular high-level objective, such as training a
+ neural network, using many machines in parallel.
+ </dd>
+ <dt>Job</dt>
+ <dd>
+ A job comprises a list of "tasks", which typically serve a common
+ purpose. For example, a job named <code>ps</code> (for "parameter server") typically
+ hosts nodes that store and update variables; while a job named <code>worker</code>
+ typically hosts stateless nodes that perform compute-intensive tasks.
+ The tasks in a job typically run on different machines.
+ </dd>
+ <dt>Master service</dt>
+ <dd>
+ An RPC service that provides remote access to a set of distributed
+ devices. The master service implements the <code>tensorflow::Session</code>
+ interface, and is responsible for coordinating work across one or more
+ "worker services".
+ </dd>
+ <dt>Task</dt>
+ <dd>
+ A task typically corresponds to a single TensorFlow server process,
+ belonging to a particular "job" and with a particular index within that
+ job's list of tasks.
+ </dd>
+
+ <dt>TensorFlow server</dt>
+ <dd>
+ A process running the <code>grpc_tensorflow_server</code> binary, which is a
+ member of a cluster, and exports a "master service" and "worker service".
+ </dd>
+ <dt>Worker service</dt>
+ <dd>
+ An RPC service that executes parts of a TensorFlow graph using its local
+ devices. A worker service implements <a
+ href="../protobuf/worker_service.proto"><code>worker_service.proto</code></a>.
+ </dd>
+</dl>
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
new file mode 100644
index 0000000000..af5d127248
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.cc
@@ -0,0 +1,318 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Stores "env" in worker_env_. FindOrCreate() later passes it to Create()
+// whenever a rendezvous has to be constructed for a new step id.
+BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* env) : worker_env_(env) {}
+
+// Aborts every rendezvous still present in table_ with an Aborted("Shutdown")
+// status and then calls Unref() on it.
+BaseRendezvousMgr::~BaseRendezvousMgr() {
+  for (auto& entry : table_) {
+    BaseRemoteRendezvous* rendezvous = entry.second;
+    // Abort first, then Unref().
+    rendezvous->StartAbort(errors::Aborted("Shutdown"));
+    rendezvous->Unref();
+  }
+}
+
+// Returns the rendezvous for "step_id", creating it on first use. This simply
+// forwards to FindOrCreate(), which calls Ref() on the instance before
+// returning it.
+Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
+ return FindOrCreate(step_id);
+}
+
+// Looks up the rendezvous registered for "step_id" under mu_, constructing and
+// registering one via Create() if none exists yet. Ref() is called on the
+// instance before it is returned.
+BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
+  mutex_lock lock(mu_);
+  auto it = table_.find(step_id);
+  if (it == table_.end()) {
+    // First request for this step: build the rendezvous and remember it.
+    it = table_.insert({step_id, Create(step_id, worker_env_)}).first;
+  }
+  BaseRemoteRendezvous* rendezvous = it->second;
+  rendezvous->Ref();
+  return rendezvous;
+}
+
+// Asynchronously receives the tensor for "key" from the rendezvous of
+// "step_id". FindOrCreate() returns the rendezvous with a Ref() taken; the
+// wrapper callback below calls Unref() and then forwards every argument to
+// "done".
+void BaseRendezvousMgr::RecvLocalAsync(int64 step_id, const string& key,
+                                       Rendezvous::DoneCallback done) {
+  BaseRemoteRendezvous* rendezvous = FindOrCreate(step_id);
+  Rendezvous::DoneCallback unref_then_done =
+      [rendezvous, done](const Status& status,
+                         const Rendezvous::Args& send_args,
+                         const Rendezvous::Args& recv_args,
+                         const Tensor& value, bool is_dead) {
+        rendezvous->Unref();
+        done(status, send_args, recv_args, value, is_dead);
+      };
+  rendezvous->RecvLocalAsync(key, unref_then_done);
+}
+
+// Blocking counterpart of RecvLocalAsync(): issues the asynchronous receive
+// and waits on a Notification until its callback has run. On return "*val" and
+// "*is_dead" hold the values delivered to the callback, and the callback's
+// status is returned.
+Status BaseRendezvousMgr::RecvLocal(int64 step_id, const string& key,
+                                    Tensor* val, bool* is_dead) {
+  Status result;
+  Notification received;
+  // "result" and "received" are captured by reference; this is safe only
+  // because we wait for the notification before leaving this frame.
+  RecvLocalAsync(
+      step_id, key,
+      [&result, &received, val, is_dead](const Status& status,
+                                         const Rendezvous::Args& send_args,
+                                         const Rendezvous::Args& recv_args,
+                                         const Tensor& tensor,
+                                         const bool dead) {
+        result = status;
+        *val = tensor;
+        *is_dead = dead;
+        received.Notify();
+      });
+  received.WaitForNotification();
+  return result;
+}
+
+// Removes the rendezvous registered for "step_id" (if any) from table_, then
+// aborts it with Aborted("Cleanup <step_id>") and calls Unref() on it. The
+// abort and Unref() happen after mu_ has been released.
+void BaseRendezvousMgr::Cleanup(int64 step_id) {
+  Rendezvous* removed = nullptr;
+  {
+    mutex_lock lock(mu_);
+    auto it = table_.find(step_id);
+    // Nothing registered for this step: nothing to clean up.
+    if (it == table_.end()) return;
+    removed = it->second;
+    table_.erase(it);
+  }
+  removed->StartAbort(errors::Aborted("Cleanup ", step_id));
+  removed->Unref();
+}
+
+// Empties table_ under mu_, then aborts each removed rendezvous with
+// Aborted("Shutdown") and calls Unref() on it after the lock is released.
+void BaseRendezvousMgr::CleanupAll() {
+  std::vector<Rendezvous*> drained;
+  {
+    mutex_lock lock(mu_);
+    for (const auto& kv : table_) {
+      drained.push_back(kv.second);
+    }
+    table_.clear();
+  }
+  for (Rendezvous* rendezvous : drained) {
+    rendezvous->StartAbort(errors::Aborted("Shutdown"));
+    rendezvous->Unref();
+  }
+}
+
+// Records the worker environment, the step id and the tolerate_dup_recv flag,
+// and creates the local rendezvous (local_) via NewLocalRendezvous(),
+// forwarding tolerate_dup_recv to it.
+BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
+ bool tolerate_dup_recv)
+ : env_(env),
+ step_id_(step_id),
+ tolerate_dup_recv_(tolerate_dup_recv),
+ local_(NewLocalRendezvous(tolerate_dup_recv)) {}
+
+// CHECK-fails if any call is still registered in active_, then calls Unref()
+// on local_.
+BaseRemoteRendezvous::~BaseRemoteRendezvous() {
+ CHECK(active_.empty());
+ local_->Unref();
+}
+
+// Returns true if "device_name" is a valid full name of local device
+// of the "worker". This helper is purely based on the worker name
+// and device name and does no lookups in the worker->device_mgr.
+//
+// NOTE(review): this is a plain string-prefix test, so a worker name ending
+// in e.g. "task:1" would presumably also match a device on "task:10".
+// Confirm that the format of worker.worker_name rules this out.
+static bool IsLocalDevice(const WorkerEnv& worker,
+ const StringPiece device_name) {
+ return device_name.starts_with(worker.worker_name);
+}
+
+// Buffers "val" in the local rendezvous so that a later receive can pick it
+// up. The key is validated by ParseKey() with is_src == true, so the call
+// fails with:
+//   * the stored status, if status_ is no longer OK (see StartAbort());
+//   * the parse error, if "key" cannot be parsed;
+//   * InvalidArgument, if the key's source device is not on this worker.
+// Otherwise returns whatever local_->Send() returns.
+Status BaseRemoteRendezvous::Send(const string& key,
+                                  const Rendezvous::Args& args,
+                                  const Tensor& val, const bool is_dead) {
+  VLOG(1) << "BaseRemoteRendezvous Send " << this << " " << key;
+  // Reuse the status/parse/locality checks shared with the receive paths
+  // instead of keeping a second copy of them here.
+  Rendezvous::ParsedKey parsed;
+  TF_RETURN_IF_ERROR(ParseKey(key, true /* is_src */, &parsed));
+  // Buffers "val" and "device_context" in local_.
+  return local_->Send(key, args, val, is_dead);
+}
+
+// Parses "key" into "*parsed" and verifies that the relevant endpoint is local
+// to this worker: the source device when "is_src" is true, the destination
+// device otherwise. Returns status_ unchanged if it is no longer OK, the
+// parse error if Rendezvous::ParseKey() fails, and InvalidArgument if the
+// checked device is not local.
+Status BaseRemoteRendezvous::ParseKey(const string& key, bool is_src,
+                                      Rendezvous::ParsedKey* parsed) {
+  {
+    mutex_lock lock(mu_);
+    if (!status_.ok()) return status_;
+  }
+  TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, parsed));
+  if (is_src) {
+    if (!IsLocalDevice(*env_, parsed->src_device)) {
+      return errors::InvalidArgument("Invalid rendezvous key (src): ", key,
+                                     " @ ", env_->worker_name);
+    }
+  } else if (!IsLocalDevice(*env_, parsed->dst_device)) {
+    return errors::InvalidArgument("Invalid rendezvous key (dst): ", key,
+                                   " @ ", env_->worker_name);
+  }
+  return Status::OK();
+}
+
+// Produces "*out" from "in" for a receive whose source and destination are
+// described by "parsed", and reports completion through "done".
+//
+// If both sides are in host memory, "*out" is assigned from "in" and "done"
+// is invoked with OK right away. Otherwise both devices are looked up in
+// env_->device_mgr, "*out" is allocated from the destination device's
+// allocator, and CopyTensor::ViaDMA() is given "done" to signal completion.
+// A tensor that cannot use DMA, or a failed device lookup, is reported by
+// invoking "done" with the error.
+void BaseRemoteRendezvous::SameWorkerRecvDone(
+ const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args, const Tensor& in, Tensor* out,
+ StatusCallback done) {
+ // Do a quick copy (sharing the underlying buffer) if both tensors
+ // are on host memory.
+ const bool src_host =
+ (send_args.alloc_attrs.on_host() || parsed.src.type == "CPU");
+ const bool dst_host =
+ (recv_args.alloc_attrs.on_host() || parsed.dst.type == "CPU");
+ if (src_host && dst_host) {
+ *out = in;
+ done(Status::OK());
+ return;
+ }
+
+ // This copy must involve a GPU. Hence, "in" must support DMA
+ // (e.g., string tensors do not work on GPU).
+ if (!DMAHelper::CanUseDMA(&in)) {
+ done(errors::InvalidArgument("Non-DMA-safe ", DataTypeString(in.dtype()),
+ " tensor may not be copied from/to a GPU."));
+ return;
+ }
+
+ // Resolve both endpoints; a failed lookup is reported through "done".
+ Device* src_device;
+ Status s = env_->device_mgr->LookupDevice(parsed.src_device, &src_device);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+ Device* dst_device;
+ s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+ if (!s.ok()) {
+ done(s);
+ return;
+ }
+
+ // Start from the receiver's allocator attributes and mark the allocation
+ // gpu_compatible if either the sender or the receiver asked for that.
+ AllocatorAttributes attr = recv_args.alloc_attrs;
+ attr.set_gpu_compatible(send_args.alloc_attrs.gpu_compatible() ||
+ recv_args.alloc_attrs.gpu_compatible());
+ Allocator* out_allocator = dst_device->GetAllocator(attr);
+ Tensor copy(out_allocator, in.dtype(), in.shape());
+ *out = copy;
+
+ // The following function takes care of cpu->gpu, gpu->cpu, gpu->gpu copies,
+ // etc.
+ CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
+ recv_args.device_context, src_device, dst_device,
+ send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
+ done);
+}
+
+// Returns whether "src" and "dst" are in the same address space, as decided by
+// DeviceNameUtils::IsSameAddressSpace(). RecvAsync() uses this to choose
+// between the local and the remote receive path.
+bool BaseRemoteRendezvous::IsSameWorker(DeviceNameUtils::ParsedName src,
+ DeviceNameUtils::ParsedName dst) {
+ return DeviceNameUtils::IsSameAddressSpace(src, dst);
+}
+
+// Asynchronously receives the tensor identified by "key", whose destination
+// must be a device on this worker. Key validation errors are delivered to
+// "done" immediately. When the source is on a different worker the request
+// is forwarded to RecvFromRemoteAsync() (not defined in this file); otherwise
+// the tensor is taken from local_ and passed through SameWorkerRecvDone().
+void BaseRemoteRendezvous::RecvAsync(const string& key,
+                                     const Rendezvous::Args& recv_args,
+                                     DoneCallback done) {
+  VLOG(1) << "RemoteRendezvous Recv " << this << " " << key;
+
+  Rendezvous::ParsedKey parsed;
+  const Status parse_status = ParseKey(key, false /*!is_src*/, &parsed);
+  if (!parse_status.ok()) {
+    done(parse_status, Args(), recv_args, Tensor(), false);
+    return;
+  }
+
+  // Source and destination are on different workers.
+  if (!IsSameWorker(parsed.src, parsed.dst)) {
+    RecvFromRemoteAsync(key, parsed, recv_args, done);
+    return;
+  }
+
+  // Source and destination share this worker: recv the tensor from local_.
+  local_->RecvAsync(
+      key, recv_args,
+      [this, parsed, done](const Status& status,
+                           const Rendezvous::Args& send_args,
+                           const Rendezvous::Args& recv_args,
+                           const Tensor& in, bool is_dead) {
+        // "out" is heap-allocated because the copy may complete after this
+        // lambda returns; the completion callback deletes it.
+        Tensor* out = new Tensor;
+        StatusCallback finish = [done, send_args, recv_args, out,
+                                 is_dead](const Status& s) {
+          done(s, send_args, recv_args, *out, is_dead);
+          delete out;
+        };
+
+        if (!status.ok()) {
+          finish(status);
+          return;
+        }
+        SameWorkerRecvDone(parsed, send_args, recv_args, in, out, finish);
+      });
+}
+
+// Receives, from local_, a tensor whose source device is on this worker. The
+// key is validated with is_src == true; on failure "done" is invoked with the
+// error, default-constructed Args and Tensor, and is_dead == false.
+void BaseRemoteRendezvous::RecvLocalAsync(const string& key,
+                                          DoneCallback done) {
+  Rendezvous::ParsedKey parsed;
+  const Status key_status = ParseKey(key, true /* is_src */, &parsed);
+  if (key_status.ok()) {
+    local_->RecvAsync(key, Args(), done);
+    return;
+  }
+  done(key_status, Args(), Args(), Tensor(), false);
+}
+
+// Aborts this rendezvous with the non-OK status "s" (CHECK-enforced). local_
+// is always told to abort. The first abort additionally records "s" in
+// status_, aborts every call registered in active_ and clears the set; later
+// aborts leave status_ and active_ untouched.
+void BaseRemoteRendezvous::StartAbort(const Status& s) {
+  CHECK(!s.ok());
+  local_->StartAbort(s);
+
+  mutex_lock lock(mu_);
+  // A previous abort already recorded a status; keep it.
+  if (!status_.ok()) return;
+
+  status_ = s;
+  // Aborts all active RecvTensor calls.
+  for (BaseRecvTensorCall* pending : active_) {
+    pending->StartAbort(s);
+  }
+  active_.clear();
+}
+
+// Adds "call" to active_ so that StartAbort() can abort it. If status_ is
+// already non-OK, "call" is aborted with that status right away instead of
+// being registered. CHECK-fails if "call" is already in active_.
+void BaseRemoteRendezvous::RegisterCall(BaseRecvTensorCall* call) {
+  mutex_lock lock(mu_);
+  if (status_.ok()) {
+    CHECK(active_.insert(call).second);
+    return;
+  }
+  call->StartAbort(status_);
+}
+
+// Removes "call" from active_ while holding mu_. StartAbort() may already
+// have cleared active_; presumably erase() of an absent element is then a
+// no-op (assumes active_ is a standard set type; confirm in the header).
+void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
+ mutex_lock l(mu_);
+ active_.erase(call);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
new file mode 100644
index 0000000000..2674817426
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/base_rendezvous_mgr.h
@@ -0,0 +1,212 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+
+#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+class BaseRemoteRendezvous;
+class BaseRecvTensorCall;
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received. Each global unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+// fork execution of a graph executor using "rendez" on thread 1;
+// fork execution of another graph executor using "rendez" on thread 2;
+// ...
+// join threads 1 and 2;
+//
+// In the example above, execution in thread 1 and 2 communicates with
+// each other by send/recv operations through `rendez`.
+//
+// Tensors sent and received through a rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
+class BaseRendezvousMgr : public RendezvousMgrInterface {
+ public:
+ // "worker_env" is not owned by this object (see worker_env_ below).
+ explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
+ ~BaseRendezvousMgr() override;
+
+ // Returns Rendezvous supporting send and recv among workers in the
+ // "step_id". The caller takes ownership of one reference on the
+ // returned Rendezvous instance.
+ Rendezvous* Find(int64 step_id) override;
+
+ // Finds the local rendezvous instance for the "step_id". Runs
+ // "done" when the tensor for "key" is produced or an error occurs.
+ //
+ // This method is used by the rpc handler of RecvTensor.
+ void RecvLocalAsync(int64 step_id, const string& key,
+ Rendezvous::DoneCallback done) override;
+
+ // Synchronous wrapper for RecvLocalAsync.
+ Status RecvLocal(int64 step_id, const string& key, Tensor* val,
+ bool* is_dead) override;
+
+ // Removes rendezvous for "step_id".
+ //
+ // TODO(zhifengc): Have a background thread in worker that
+ // periodically calls CleanupAll().
+ void Cleanup(int64 step_id) override;
+
+ // Removes all rendezvous.
+ void CleanupAll() override;
+
+ protected:
+ // Factory hook that subclasses must implement: returns the
+ // BaseRemoteRendezvous to use for "step_id".
+ // NOTE(review): presumably invoked by FindOrCreate() when no table
+ // entry exists yet -- confirm against the .cc file.
+ virtual BaseRemoteRendezvous* Create(int64 step_id,
+ const WorkerEnv* worker_env) = 0;
+
+ private:
+ // Maps step_id to rendezvous.
+ typedef std::unordered_map<int64, BaseRemoteRendezvous*> Table;
+
+ // Not owned.
+ const WorkerEnv* const worker_env_;
+
+ mutex mu_;
+ Table table_ GUARDED_BY(mu_);
+
+ // NOTE(review): judging by the name, returns the table_ entry for
+ // "step_id", creating it via Create() if absent; the implementation is
+ // not part of this header -- confirm reference-count semantics there.
+ BaseRemoteRendezvous* FindOrCreate(int64 step_id);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
+};
+
+// RemoteRendezvous is a Rendezvous which can handle either
+// the producer or consumer being in a remote process.
+//
+// Buffering of Tensor values is delegated to a "local" Rendezvous
+// obtained from NewLocalRendezvous(). This class just adds
+// functionality to coordinate with remote workers.
+class BaseRemoteRendezvous : public Rendezvous {
+ public:
+ // "env" is not owned. "step_id" identifies the step this rendezvous
+ // serves.
+ // NOTE(review): "tolerate_dup_recv" is presumably forwarded to the local
+ // rendezvous to control duplicate-receive handling -- confirm in the .cc.
+ BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
+ bool tolerate_dup_recv);
+
+ // Forwards to local_, where the Tensor "val" will be buffered and
+ // any waiting callback stored.
+ Status Send(const string& key, const Rendezvous::Args& args,
+ const Tensor& val, const bool is_dead) override;
+
+ // This method is called only by the RecvOp. It tests to see
+ // whether the value will be produced by a local or remote device
+ // and handles accordingly. In the local case it forwards to
+ // local_, in the remote case it initiates an RPC request.
+ void RecvAsync(const string& key, const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ // Aborts local_ and, on the first abort only, every call currently
+ // registered through RegisterCall(). "status" must not be OK.
+ void StartAbort(const Status& status) override;
+
+ // This method is called only by the local Worker, forwarded through
+ // the same method on RendezvousMgr. This occurs when the Worker
+ // has received a RecvTensor request, either locally or over the
+ // network. In either case it needs to retrieve a locally buffered
+ // value from local_, and give it to its caller.
+ //
+ // Runs "done" as soon as the tensor for "key" is available or an error
+ // is detected.
+ //
+ // REQUIRES: "key" is one that will be sent into the local rendezvous.
+ void RecvLocalAsync(const string& key, DoneCallback done);
+
+ protected:
+ // Transport-specific hook: subclasses implement the receive path used
+ // by RecvAsync() when the source and destination of "parsed" are not in
+ // the same worker. "done" must eventually be run with the outcome.
+ virtual void RecvFromRemoteAsync(const string& key,
+ const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& args,
+ DoneCallback done) = 0;
+
+ // Returns true if "src" and "dst" are located in the same worker,
+ // and hence may use a local rendezvous.
+ virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
+ DeviceNameUtils::ParsedName dst);
+
+ // If aborted, aborts "call". Otherwise, adds "call" into active_.
+ void RegisterCall(BaseRecvTensorCall* call);
+
+ // Removes "call" from active_ if "call" is in active_.
+ void DeregisterCall(BaseRecvTensorCall* call);
+
+ // Protected, so outside code cannot delete an instance directly.
+ // NOTE(review): presumably lifetime is managed through the Rendezvous
+ // reference count (Ref/Unref) -- confirm against the base class.
+ ~BaseRemoteRendezvous() override;
+
+ const WorkerEnv* const env_; // Not owned.
+ const int64 step_id_;
+
+ private:
+ const bool tolerate_dup_recv_;
+ Rendezvous* local_; // Owns a Ref on this object.
+
+ mutable mutex mu_;
+
+ // Status given by StartAbort() if any.
+ Status status_ GUARDED_BY(mu_);
+
+ // Active outstanding RecvTensor calls.
+ std::unordered_set<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
+
+ // Parses "key" into "parsed". If "is_src" is true, checks that the
+ // rendezvous key's source is in this process. If "is_src" is false,
+ // checks that the rendezvous key's destination is in this process.
+ Status ParseKey(const string& key, bool is_src,
+ Rendezvous::ParsedKey* parsed);
+
+ // Callback handling the case when a rendezvous has been
+ // accomplished in local_ and the consumer is local to this process.
+ // Tensor "in" will be copied into "out". The key "parsed" encodes
+ // the src and dst devices.
+ void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& in_args,
+ const Rendezvous::Args& out_args, const Tensor& in,
+ Tensor* out, StatusCallback done);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
+};
+
+// Abstract handle for one outstanding RecvTensor call. BaseRemoteRendezvous
+// keeps pointers to such calls in its active set and invokes StartAbort()
+// on them when the rendezvous is aborted.
+class BaseRecvTensorCall {
+ public:
+ BaseRecvTensorCall() {}
+ virtual ~BaseRecvTensorCall() {}
+
+ // NOTE(review): presumably begins the call and runs "recv_done" once it
+ // has finished -- confirm with the concrete implementations.
+ virtual void Start(std::function<void()> recv_done) = 0;
+
+ // Requests that the call be aborted with the non-OK status "s".
+ virtual void StartAbort(const Status& s) = 0;
+
+ // Returns the status of the call.
+ // NOTE(review): whether this is meaningful before completion is not
+ // specified here -- confirm with the concrete implementations.
+ virtual Status status() const = 0;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BASE_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/build_graph_options.cc b/tensorflow/core/distributed_runtime/build_graph_options.cc
new file mode 100644
index 0000000000..05c42e89ba
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/build_graph_options.cc
@@ -0,0 +1,38 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/build_graph_options.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+
+// Renders the three endpoint lists as human-readable text, one labelled
+// line per list. Every name is followed by ", ", including the last one.
+string BuildGraphOptions::DebugString() const {
+  string text = "Feed endpoints: ";
+  const auto append_names = [&text](const std::vector<string>& names) {
+    for (const string& name : names) {
+      strings::StrAppend(&text, name, ", ");
+    }
+  };
+
+  append_names(feed_endpoints);
+  strings::StrAppend(&text, "\nFetch endpoints: ");
+  append_names(fetch_endpoints);
+  strings::StrAppend(&text, "\nTarget nodes: ");
+  append_names(target_nodes);
+  return text;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/build_graph_options.h b/tensorflow/core/distributed_runtime/build_graph_options.h
new file mode 100644
index 0000000000..438912642d
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/build_graph_options.h
@@ -0,0 +1,38 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
+
+#include <vector>
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Plain bundle of the endpoint and node names that parameterize building a
+// graph. DebugString() prints the three lists for logging.
+struct BuildGraphOptions {
+ // Names of the endpoints that are fed.
+ std::vector<string> feed_endpoints;
+ // Names of the endpoints that are fetched.
+ std::vector<string> fetch_endpoints;
+
+ // TODO(vrv): Remove this when we unify target_nodes and fetch_endpoint,
+ // the former via "ref" fetch_endpoints.
+ std::vector<string> target_nodes;
+
+ // Returns a multi-line, human-readable listing of the fields above.
+ string DebugString() const;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_BUILD_GRAPH_OPTIONS_H_
diff --git a/tensorflow/core/distributed_runtime/call_options.cc b/tensorflow/core/distributed_runtime/call_options.cc
new file mode 100644
index 0000000000..b9d583b754
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/call_options.cc
@@ -0,0 +1,44 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/call_options.h"
+
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+// No cancellation callback is installed initially.
+CallOptions::CallOptions() = default;
+
+// Invokes the currently installed cancellation callback, if any. The
+// callback stays installed afterwards, so repeated calls invoke it again.
+void CallOptions::StartCancel() {
+  mutex_lock lock(mu_);
+  if (cancel_func_ == nullptr) {
+    return;
+  }
+  // NOTE: The callback is deliberately run while mu_ is held, so that
+  // ClearCancelCallback() cannot race with StartCancel().
+  cancel_func_();
+  // NOTE: cancel_func_ could be cleared here if that is ever needed.
+}
+
+// Installs "cancel_func" as the callback run by StartCancel(), replacing
+// any previously installed callback.
+void CallOptions::SetCancelCallback(CancelFunction cancel_func) {
+  mutex_lock lock(mu_);
+  cancel_func_ = cancel_func;
+}
+
+// Removes the installed callback; later StartCancel() calls do nothing
+// until a new callback is set.
+void CallOptions::ClearCancelCallback() {
+  mutex_lock lock(mu_);
+  cancel_func_ = nullptr;
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/call_options.h b/tensorflow/core/distributed_runtime/call_options.h
new file mode 100644
index 0000000000..de0b85f692
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/call_options.h
@@ -0,0 +1,72 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
+
+#include <functional>
+
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Options passed to interface calls. This class provides portable
+// functionality across different RPC systems on top of
+// platform-specific mechanisms (for client and server contexts,
+// cancellation, etc.).
+//
+// TODO(zhifengc): Maybe change all RPC methods to take CallOptions.
+class CallOptions {
+ public:
+ CallOptions();
+
+ // Cancellation.
+ //
+ // The caller may call StartCancel() anytime as long as this
+ // CallOptions object is alive. The callee may or may not receive
+ // the cancellation notification depending on the rpc layer
+ // implementation.
+ void StartCancel();
+
+ // The callee (the rpc layer implementation) must set a cancellation
+ // notifier before its blocking operation and clear the notifier
+ // before the call returns.
+ //
+ // "cancel_func" may be called zero, one, or more times. Therefore, it
+ // should _not_ be responsible for memory management of any objects.
+ //
+ // "cancel_func" must be very light-weight. It should not block on
+ // IO or locking. Typically, it just calls the rpc implementation
+ // layer's specific cancellation mechanism and does nothing else.
+ //
+ // NOTE: "cancel_func" itself is pass-by-value. Therefore, we do not
+ // worry about its ownership here.
+ typedef std::function<void()> CancelFunction;
+ void SetCancelCallback(CancelFunction cancel_func);
+ void ClearCancelCallback();
+
+ private:
+ // Guards cancel_func_.
+ mutex mu_;
+ // The installed cancellation callback; empty when none is set.
+ CancelFunction cancel_func_ GUARDED_BY(mu_);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(CallOptions);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_CALL_OPTIONS_H_
diff --git a/tensorflow/core/distributed_runtime/call_options_test.cc b/tensorflow/core/distributed_runtime/call_options_test.cc
new file mode 100644
index 0000000000..62fe21341c
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/call_options_test.cc
@@ -0,0 +1,39 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/call_options.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+// Verifies that StartCancel() is a no-op without a callback, runs the
+// callback once per call while one is installed, and is a no-op again
+// after the callback has been cleared.
+TEST(CallOptions, Cancel) {
+ int num_calls = 0;
+ CallOptions opts;
+ // No callback installed yet: nothing is invoked.
+ opts.StartCancel();
+ EXPECT_EQ(num_calls, 0);
+ // Installing the callback does not invoke it.
+ opts.SetCancelCallback([&num_calls]() { num_calls++; });
+ EXPECT_EQ(num_calls, 0);
+ // Each StartCancel() invokes the installed callback exactly once.
+ opts.StartCancel();
+ EXPECT_EQ(num_calls, 1);
+ opts.StartCancel();
+ EXPECT_EQ(num_calls, 2);
+ // After clearing, StartCancel() no longer invokes anything.
+ opts.ClearCancelCallback();
+ EXPECT_EQ(num_calls, 2);
+ opts.StartCancel();
+ EXPECT_EQ(num_calls, 2);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc
new file mode 100644
index 0000000000..be46c73aa2
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/executor_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// Test fixture that owns a CPU device, a local executor built from a
+// test-supplied graph, and a local rendezvous used to feed and fetch
+// tensors.
+class ExecutorTest : public ::testing::Test {
+ protected:
+ ExecutorTest()
+ : device_(DeviceFactory::NewDevice("CPU", {},
+ "/job:localhost/replica:0/task:0")),
+
+ step_stats_collector_(&step_stats_) {
+ SessionOptions options;
+ thread_pool_ = ComputePool(options);
+ }
+
+ ~ExecutorTest() override {
+ // There should always be exactly one Ref left on the Rendezvous
+ // when the test completes.
+ // NOTE(review): rendez_ is only assigned in Create(), so this
+ // dereferences nullptr if a test never calls Create().
+ CHECK(rendez_->Unref());
+ delete exec_;
+ delete device_;
+ }
+
+ // Resets exec_ with a new executor based on 'graph' and installs a
+ // fresh local rendezvous in rendez_.
+ // NOTE(review): the previous rendez_ is overwritten without Unref(), so
+ // calling Create() twice in one test would drop that reference.
+ void Create(const Graph* graph) {
+ const int version = graph->versions().producer();
+ LocalExecutorParams params;
+ params.device = device_;
+ // Kernels are created on demand and never cached.
+ params.create_kernel = [this, version](const NodeDef& ndef,
+ OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ };
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ delete exec_;
+ TF_CHECK_OK(NewLocalExecutor(params, graph, &exec_));
+ // Closures handed to the executor's runner go to the shared thread pool.
+ runner_ = [this](std::function<void()> fn) { thread_pool_->Schedule(fn); };
+ rendez_ = NewLocalRendezvous();
+ }
+
+ // Runs exec_ to completion using "rendez" for Send/Recv and returns the
+ // resulting status.
+ Status Run(Rendezvous* rendez) {
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.stats_collector = &step_stats_collector_;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ // Obtained from ComputePool(); not deleted by this fixture.
+ thread::ThreadPool* thread_pool_ = nullptr;
+ // Deleted in the destructor.
+ Device* device_ = nullptr;
+ // Replaced by Create() and deleted in the destructor.
+ Executor* exec_ = nullptr;
+ // Holds a pointer to step_stats_, which is declared after it.
+ StepStatsCollector step_stats_collector_;
+ StepStats step_stats_;
+ Executor::Args::Runner runner_;
+ // Set by Create(); the destructor releases the final reference.
+ Rendezvous* rendez_ = nullptr;
+};
+
+// Wraps a float in a scalar DT_FLOAT Tensor.
+Tensor V(const float val) {
+  Tensor scalar(DT_FLOAT, TensorShape({}));
+  auto slot = scalar.scalar<float>();
+  slot() = val;
+  return scalar;
+}
+
+// Wraps an int32 in a scalar DT_INT32 Tensor.
+Tensor VI(const int32 val) {
+  Tensor scalar(DT_INT32, TensorShape({}));
+  auto slot = scalar.scalar<int32>();
+  slot() = val;
+  return scalar;
+}
+
+// Wraps a bool in a scalar DT_BOOL Tensor.
+Tensor VB(const bool val) {
+  Tensor scalar(DT_BOOL, TensorShape({}));
+  auto slot = scalar.scalar<bool>();
+  slot() = val;
+  return scalar;
+}
+
+// Wraps a double in a scalar DT_DOUBLE Tensor.
+Tensor VD(const double val) {
+  Tensor scalar(DT_DOUBLE, TensorShape({}));
+  auto slot = scalar.scalar<double>();
+  slot() = val;
+  return scalar;
+}
+
+// Extracts the float held by a scalar DT_FLOAT Tensor; CHECK-fails if the
+// tensor has another dtype or is not a scalar.
+float V(const Tensor& tensor) {
+  CHECK_EQ(tensor.dtype(), DT_FLOAT);
+  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+  const float value = tensor.scalar<float>()();
+  return value;
+}
+
+static uint64 kIncarnation = 1; // Used in the following tests.
+
+// Builds the rendezvous key for tensor "name" travelling from "sender"
+// (with the given incarnation) to "receiver", using frame 0, iteration 0.
+string Key(const string& sender, const uint64 incarnation,
+           const string& receiver, const string& name) {
+  const FrameAndIter frame_iter(0, 0);
+  return Rendezvous::CreateKey(sender, incarnation, receiver, name,
+                               frame_iter);
+}
+
+// Device names of the two endpoints used by the tests below; both are on
+// the same job/replica/task and differ only in the device (cpu:0 vs gpu:0).
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+// Sends a = b = 1.0 into the rendezvous, runs the graph c = a + b, and
+// checks that the tensor received for "c" equals 2.0.
+TEST_F(ExecutorTest, SimpleAdd) {
+ // c = a + b
+ Graph* g = new Graph(OpRegistry::Global());
+ auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+ auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
+ auto tmp = test::graph::Add(g, in0, in1);
+ test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
+ Create(g);
+ Rendezvous::Args args;
+ // Both inputs are sent before Run() so the Recv nodes find them buffered.
+ TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
+ false)); // in0 = 1.0
+ TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"), args, V(1.0),
+ false)); // in1 = 1.0
+ TF_ASSERT_OK(Run(rendez_));
+ // Seeded with -1 so a stale value cannot pass the check below.
+ Tensor out = V(-1);
+ bool is_dead = false;
+ TF_ASSERT_OK(
+ rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
+ EXPECT_EQ(2.0, V(out)); // out = 1.0 + 1.0 = 2.0
+}
+
+// Doubles the input ten times in a chain of Add nodes and checks that
+// a = 1.0 comes out as 1024.0.
+TEST_F(ExecutorTest, SelfAdd) {
+ // v0 <- a
+ // v1 = v0 + v0
+ // v2 = v1 + v1
+ // ... ...
+ // v10 = v9 + v9
+ //
+ // b <- v10
+ // All nodes are executed by one thread.
+ Graph* g = new Graph(OpRegistry::Global());
+ auto v = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+ const int N = 10;
+ // Each iteration replaces v with v + v, i.e. doubles it.
+ for (int i = 1; i <= N; ++i) {
+ v = test::graph::Add(g, v, v);
+ }
+ // out <- v10
+ test::graph::Send(g, v, "b", BOB, 1, ALICE);
+ Create(g);
+ Rendezvous::Args args;
+ // a = 1.0
+ TF_ASSERT_OK(
+ rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
+ TF_ASSERT_OK(Run(rendez_));
+ Tensor out = V(-1);
+ bool is_dead = false;
+ TF_ASSERT_OK(
+ rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
+ EXPECT_EQ(1024.0, V(out)); // b=v10=2*v9=4*v8=...=1024*a=1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+// a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+// a + ((a + a) + a)
+// (a + a) + (a + a)
+// ((a + a) + a) + a
+// are all possibly generated.
+//
+// The input is received as "a" (ALICE -> BOB) and the sum is sent as "b"
+// (BOB -> ALICE). REQUIRES: N > 1.
+void BuildTree(int N, Graph* g) {
+ CHECK_GT(N, 1);
+ // A single input node "in".
+ auto in = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+ std::vector<Node*> nodes;
+ int i = 0;
+ // Duplicate "in" N times. Each copy is named as l0, l1, l2, ....
+ for (; i < N; ++i) {
+ nodes.push_back(test::graph::Identity(g, in, 0));
+ }
+ random::PhiloxRandom philox(testing::RandomSeed(), 17);
+ random::SimplePhilox rnd(&philox);
+ while (nodes.size() > 1) {
+ // Randomly pick two from nodes and add them. The resulting node
+ // is named like n10, n11, .... and is put back into "nodes".
+ int x = rnd.Uniform(nodes.size());
+ auto in0 = nodes[x];
+ // Remove the first pick by overwriting it with the last element.
+ nodes[x] = nodes.back();
+ nodes.resize(nodes.size() - 1);
+ x = rnd.Uniform(nodes.size());
+ auto in1 = nodes[x];
+ // node = in0 + in1.
+ nodes[x] = test::graph::Add(g, in0, in1);
+ }
+ // The final output node "out".
+ test::graph::Send(g, nodes.back(), "b", BOB, 1, ALICE);
+}
+
+// Sums 4096 copies of a = 1.0 through a randomly shaped tree of Add nodes
+// and checks that the result is 4096.0.
+TEST_F(ExecutorTest, RandomTree) {
+ Graph* g = new Graph(OpRegistry::Global());
+ BuildTree(4096, g);
+ Create(g);
+ Rendezvous::Args args;
+ TF_ASSERT_OK(
+ rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false));
+ TF_ASSERT_OK(Run(rendez_));
+ Tensor out = V(-1);
+ bool is_dead = false;
+ TF_ASSERT_OK(
+ rendez_->Recv(Key(BOB, kIncarnation, ALICE, "b"), args, &out, &is_dead));
+ EXPECT_EQ(4096.0, V(out));
+}
+
+// Builds a graph in which a scalar variable is initialized to 1.0 and then
+// updated by 1024 independent "v = v + 1" computations; the variable is
+// sent as "out" (ALICE -> BOB) after all of the assignments.
+void BuildConcurrentAddAssign(Graph* g) {
+ auto one = test::graph::Constant(g, V(1.0));
+ // A variable holds one float.
+ auto var = test::graph::Var(g, DT_FLOAT, TensorShape({}));
+ // Initialize the variable with 1.0.
+ auto init = test::graph::Assign(g, var, one);
+ // Output
+ auto out = test::graph::Send(g, var, "out", ALICE, kIncarnation, BOB);
+ // Have many concurrent computation. Each does v = v + 1.
+ for (int i = 0; i < 1024; ++i) {
+ auto add = test::graph::Add(g, var, one);
+ g->AddControlEdge(init, add); // Ensures run after init.
+ auto assign = test::graph::Assign(g, var, add);
+ // The Send waits for every assignment via a control edge.
+ g->AddControlEdge(assign, out);
+ }
+}
+
+// Compiled out when THREAD_SANITIZER is defined.
+// NOTE(review): presumably because the unsynchronized concurrent updates
+// of the variable are intentional data races -- confirm.
+#ifndef THREAD_SANITIZER
+// Runs the concurrent add/assign graph 16 times, each with its own
+// rendezvous, and checks only an upper bound on the result: 1.0 plus at
+// most 1024 increments.
+TEST_F(ExecutorTest, ConcurrentAddAssign) {
+ Graph* g = new Graph(OpRegistry::Global());
+ BuildConcurrentAddAssign(g);
+ Create(g);
+ for (int iters = 0; iters < 16; ++iters) {
+ // A per-iteration rendezvous, released at the end of the iteration.
+ Rendezvous* rendez = NewLocalRendezvous();
+ TF_ASSERT_OK(Run(rendez));
+ Rendezvous::Args args;
+ Tensor out;
+ bool is_dead;
+ TF_ASSERT_OK(rendez->Recv(Key(ALICE, kIncarnation, BOB, "out"), args, &out,
+ &is_dead));
+ VLOG(1) << "Get " << V(out);
+ EXPECT_LE(V(out), 1025.0);
+ rendez->Unref();
+ }
+}
+#endif
+
+// Routes "a" through a Switch whose predicate is the constant false and
+// checks that the tensor sent as "c" is live and carries the input value.
+TEST_F(ExecutorTest, SimpleSwitchLive) {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+ auto in1 = test::graph::Constant(g, VB(false));
+ auto tmp = test::graph::Switch(g, in0, in1);
+ test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
+ Create(g);
+ Rendezvous::Args args;
+ TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
+ false)); // in0 = 1.0
+ TF_ASSERT_OK(Run(rendez_));
+ Tensor out = V(-1);
+ bool is_dead = false;
+ TF_ASSERT_OK(
+ rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
+ EXPECT_EQ(1.0, V(out)); // out = 1.0
+ EXPECT_FALSE(is_dead);
+}
+
+// Same graph as SimpleSwitchLive but with the predicate set to true; the
+// tensor sent as "c" is then expected to be flagged dead.
+TEST_F(ExecutorTest, SimpleSwitchDead) {
+ Graph* g = new Graph(OpRegistry::Global());
+ auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+ auto in1 = test::graph::Constant(g, VB(true));
+ auto tmp = test::graph::Switch(g, in0, in1);
+ test::graph::Send(g, tmp, "c", BOB, 1, ALICE);
+ Create(g);
+ Rendezvous::Args args;
+ TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0),
+ false)); // in0 = 1.0
+ TF_ASSERT_OK(Run(rendez_));
+ Tensor out = V(-1);
+ bool is_dead = false;
+ TF_ASSERT_OK(
+ rendez_->Recv(Key(BOB, kIncarnation, ALICE, "c"), args, &out, &is_dead));
+ // Only deadness is checked; the value of "out" is not inspected.
+ EXPECT_TRUE(is_dead);
+}
+
+// Runs a graph that needs four inputs while only three are ever sent; a
+// fourth closure aborts the rendezvous instead. Both Run() and a later
+// Recv() are expected to fail with an Aborted status.
+TEST_F(ExecutorTest, Abort) {
+ // e = a + b + c + d
+ Graph* g = new Graph(OpRegistry::Global());
+ auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB);
+ auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB);
+ auto in2 = test::graph::Recv(g, "c", "float", ALICE, 1, BOB);
+ auto in3 = test::graph::Recv(g, "d", "float", ALICE, 1, BOB);
+ auto add0 = test::graph::Add(g, in0, in1);
+ auto add1 = test::graph::Add(g, in2, in3);
+ auto add2 = test::graph::Add(g, add0, add1);
+ test::graph::Send(g, add2, "e", BOB, 1, ALICE);
+ Create(g);
+
+ // Needs 4 inputs (recv). One of them is aborted.
+ // Each closure below holds its own Ref on rendez_ and releases it when
+ // done; the Send statuses are deliberately ignored because the sends
+ // may race with the abort.
+ rendez_->Ref();
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(100 * 1000);
+ Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"),
+ Rendezvous::Args(), V(1.0), false);
+ rendez_->Unref();
+ });
+ rendez_->Ref();
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(100 * 1000);
+ Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "b"),
+ Rendezvous::Args(), V(1.0), false);
+ rendez_->Unref();
+ });
+ rendez_->Ref();
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(100 * 1000);
+ Status s = rendez_->Send(Key(ALICE, kIncarnation, BOB, "c"),
+ Rendezvous::Args(), V(1.0), false);
+ rendez_->Unref();
+ });
+ // Instead of sending "d", the fourth closure aborts the rendezvous.
+ rendez_->Ref();
+ SchedClosure([this]() {
+ Env::Default()->SleepForMicroseconds(100 * 1000);
+ rendez_->StartAbort(errors::Aborted(""));
+ rendez_->Unref();
+ });
+ EXPECT_TRUE(errors::IsAborted(Run(rendez_)));
+ Tensor out = V(-1);
+ bool is_dead = false;
+ // NOTE(review): the graph sends its result as "e", but this Recv asks
+ // for "c" (BOB -> ALICE), a key nothing sends. Presumably any Recv on an
+ // aborted rendezvous returns Aborted, so the check still holds --
+ // confirm whether "e" was intended.
+ EXPECT_TRUE(errors::IsAborted(rendez_->Recv(
+ Key(BOB, kIncarnation, ALICE, "c"), Rendezvous::Args(), &out, &is_dead)));
+ // At this point there can still be pending (albeit Aborted) Send
+ // closures holding Refs on rendez_. We need to wait for them, or
+ // else there can be a memory leak at termination.
+ // Busy-waits until only the fixture's own reference remains.
+ while (!rendez_->RefCountIsOne())
+ ;
+}
+
+// Sends a double where the graph's Recv node expects a float and checks
+// that both Run() and the subsequent Recv() report an Internal error.
+// Uses its own rendezvous rather than the fixture's rendez_.
+TEST_F(ExecutorTest, RecvInvalidDtype) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // An input vector of type float of size 1.
+ auto one = test::graph::Recv(g, "one", "float", ALICE, 1, BOB);
+ // A floating point variable vector of size 1.
+ auto var = test::graph::Var(g, DT_FLOAT, TensorShape({1}));
+ // Initialize the variable with input.
+ auto init = test::graph::Assign(g, var, one);
+ // Output
+ auto* two = test::graph::Send(g, var, "two", BOB, 1, ALICE);
+ g->AddControlEdge(init, two); // Ensures run after init.
+ Create(g);
+ Rendezvous* rendez = NewLocalRendezvous();
+ // Send a double instead of float.
+ TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(),
+ VD(1.0), false));
+ // Fails due to invalid dtype.
+ EXPECT_TRUE(errors::IsInternal(Run(rendez)));
+ Tensor output;
+ bool is_dead;
+ EXPECT_TRUE(errors::IsInternal(rendez->Recv(
+ Key(BOB, 1, ALICE, "two"), Rendezvous::Args(), &output, &is_dead)));
+ rendez->Unref();
+}
+
+// Sends the output of a node built with mismatching dtypes (DT_FLOAT vs
+// DT_DOUBLE) and checks that both Run() and the subsequent Recv() report
+// an Internal error. Uses its own rendezvous rather than rendez_.
+TEST_F(ExecutorTest, RecvInvalidRefDtype) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // A var that always produces an invalid dtype.
+ auto var = test::graph::InvalidRefType(g, DT_FLOAT, DT_DOUBLE);
+ test::graph::Send(g, var, "out", BOB, 1, ALICE);
+ Create(g);
+ Rendezvous* rendez = NewLocalRendezvous();
+ EXPECT_TRUE(errors::IsInternal(Run(rendez)));
+ Tensor output;
+ bool is_dead;
+ EXPECT_TRUE(errors::IsInternal(rendez->Recv(
+ Key(BOB, 1, ALICE, "out"), Rendezvous::Args(), &output, &is_dead)));
+ rendez->Unref();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
new file mode 100644
index 0000000000..f1bcbf3956
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -0,0 +1,368 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/graph_mgr.h"
+
+#include <vector>
+
+#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/graph_optimizer.h"
+#include "tensorflow/core/common_runtime/memory_types.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/config.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/graph/validate.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+// Remembers the worker environment (not owned) and constructs table_ with
+// an initial bucket-count hint of 5.
+GraphMgr::GraphMgr(const WorkerEnv* worker_env)
+ : worker_env_(worker_env), table_(5) {}
+
+// Drops the reference that table_ holds on every graph still registered.
+GraphMgr::~GraphMgr() {
+  // Iterate by reference: taking each entry by value would copy its string
+  // key for no benefit.
+  for (const auto& entry : table_) entry.second->Unref();
+}
+
+// Releases what an Item owns: for every execution unit its root executor,
+// its function library runtime and one op-segment hold for the session;
+// finally the shared function library definition.
+GraphMgr::Item::~Item() {
+  for (size_t i = 0; i < units.size(); ++i) {
+    const auto& unit = units[i];
+    CHECK_NOTNULL(unit.device);
+    delete unit.root;
+    delete unit.lib;
+    unit.device->op_segment()->RemoveHold(session);
+  }
+  delete lib_def;
+}
+
+// Partitioning key of a node: the name of the device it is assigned to.
+//
+// NOTE: node->device_name() is not set by GraphConstructor. We expect
+// every NodeDef in a GraphDef given to workers to fully specify its
+// device name.
+static string SplitByDevice(const Node* node) {
+  string location = node->assigned_device_name();
+  return location;
+}
+
+// Checks that the device string of every node in "gdef" parses as a full
+// device name. Returns InvalidArgument describing the first node that does
+// not, OK otherwise.
+static Status ValidateGraphDefForDevices(const GraphDef& gdef) {
+  for (int i = 0; i < gdef.node_size(); ++i) {
+    const NodeDef& node_def = gdef.node(i);
+    // Only the success of the parse matters; the parsed fields are unused.
+    DeviceNameUtils::ParsedName ignored;
+    if (DeviceNameUtils::ParseFullName(node_def.device(), &ignored)) continue;
+    return errors::InvalidArgument("Missing device name in: ",
+                                   SummarizeNodeDef(node_def));
+  }
+  return Status::OK();
+}
+
+// Fills in "item" for the graph definition "gdef" registered by "session".
+//
+// The graph is validated, converted to a Graph and partitioned by assigned
+// device name. For every partition one ExecutionUnit is appended to
+// item->units, holding the device, a function library runtime and the root
+// executor of that partition.
+//
+// If a node in "gdef" is shared by other graphs in "session", the
+// same op kernel is reused. E.g., typically a params node is shared
+// by multiple graphs in a session.
+//
+// If "gdef" is assigned to multiple devices, extra nodes (e.g.,
+// send/recv nodes) may be added. The extra nodes' names are generated
+// by popts.new_name below.
+//
+// Returns the first error encountered. "item" may then be partially
+// initialized, but every unit left in item->units has a valid device and a
+// matching op-segment hold, which is what ~Item() relies on.
+Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
+                          const GraphOptions& graph_options, Item* item) {
+  item->session = session;
+  item->lib_def = new FunctionLibraryDefinition(gdef.library());
+
+  TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));
+
+  if (gdef.versions().producer() >= 5) {
+    // Validate the graph: we assume that merging two valid graphs
+    // should maintain graph validity.
+    TF_RETURN_IF_ERROR(graph::ValidateGraphDef(gdef, *item->lib_def));
+  }
+
+  // Constructs the graph out of "gdef".
+  Graph graph(item->lib_def);
+  GraphConstructorOptions opts;
+  opts.allow_internal_ops = true;
+  opts.expect_device_spec = true;
+  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));
+
+  // Splits "graph" into multiple subgraphs by device names.
+  std::unordered_map<string, GraphDef> partitions;
+  PartitionOptions popts;
+  popts.node_to_loc = SplitByDevice;
+  popts.new_name = [this](const string& prefix) {
+    mutex_lock l(mu_);
+    return strings::StrCat(prefix, "_G", next_id_++);
+  };
+  popts.get_incarnation = [this](const string& name) {
+    Device* device = nullptr;
+    Status s = worker_env_->device_mgr->LookupDevice(name, &device);
+    if (s.ok()) {
+      return device->attributes().incarnation();
+    } else {
+      return PartitionOptions::kIllegalIncarnation;
+    }
+  };
+  popts.control_flow_added = true;
+  popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
+  TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
+  if (popts.scheduling_for_recvs) {
+    TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
+  }
+
+  thread::ThreadPool* pool = worker_env_->compute_pool;
+  auto runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
+
+  LocalExecutorParams params;
+
+  Status s;
+  item->units.reserve(partitions.size());
+  const auto& optimizer_opts = graph_options.optimizer_options();
+  GraphOptimizer optimizer(optimizer_opts);
+  for (auto&& p : partitions) {
+    const string& device_name = p.first;
+    GraphDef* def = &p.second;
+    item->units.resize(item->units.size() + 1);
+    ExecutionUnit* unit = &(item->units.back());
+
+    // Find the device.
+    s = worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
+    if (!s.ok()) {
+      // ~Item() CHECKs that every unit has a device. Remove the unit that
+      // never got one so that destroying the item does not abort.
+      item->units.pop_back();
+      break;
+    }
+
+    // Construct the subgraph.
+    Graph* subgraph = new Graph(item->lib_def);
+    // Give the device an opportunity to rewrite its subgraph.
+    unit->device->MaybeRewriteGraph(gdef.library(), def);
+    s = ConvertGraphDefToGraph(opts, *def, subgraph);
+    if (!s.ok()) {
+      delete subgraph;
+      // No op-segment hold has been added for this unit yet. Leaving it in
+      // the list would make ~Item() release a hold it never took.
+      item->units.pop_back();
+      break;
+    }
+    // Top-level nodes in the graph uses the op segment to cache
+    // kernels. Therefore, as long as the executor is alive, we need
+    // to ensure the kernels cached for the session are alive.
+    auto opseg = unit->device->op_segment();
+    opseg->AddHold(session);
+
+    // Function library runtime.
+    unit->lib = NewFunctionLibraryRuntime(
+        unit->device, runner, def->versions().producer(), item->lib_def,
+        graph_options.optimizer_options());
+
+    // Construct the root executor for the subgraph.
+    params.device = unit->device;
+    auto lib = unit->lib;
+    params.function_library = lib;
+    params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
+                                                 OpKernel** kernel) {
+      // Caches the kernel only if the node is stateful.
+      if (!lib->IsStateful(ndef.op())) {
+        return lib->CreateKernel(ndef, kernel);
+      }
+      auto create_fn = [lib, &ndef](OpKernel** kernel) {
+        return lib->CreateKernel(ndef, kernel);
+      };
+      // Kernels created for subgraph nodes need to be cached.  On
+      // cache miss, create_fn() is invoked to create a kernel based
+      // on the function library here + global op registry.
+      return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
+    };
+    params.delete_kernel = [lib](OpKernel* kernel) {
+      // If the node is stateful, opseg owns it. Otherwise, delete it.
+      if (kernel && !lib->IsStateful(kernel->type_string())) {
+        delete kernel;
+      }
+    };
+
+    optimizer.Optimize(lib, &subgraph);
+    s = ValidateMemoryTypes(DeviceType(unit->device->device_type()), subgraph);
+    if (!s.ok()) {
+      // The unit stays in the list: it already has a hold and a library
+      // runtime, both of which ~Item() releases.
+      delete subgraph;
+      break;
+    }
+    s = NewLocalExecutor(params, subgraph, &unit->root);
+    if (!s.ok()) {
+      break;
+    }
+  }
+  return s;
+}
+
+// Builds an Item for "gdef" and, if that succeeds, publishes it in table_
+// under a newly generated handle that is also returned through "*handle".
+// On failure the item is released and the initialization error is returned.
+Status GraphMgr::Register(const string& session, const GraphDef& gdef,
+                          const GraphOptions& graph_options, string* handle) {
+  Item* item = new Item;
+  const Status init_status = InitItem(session, gdef, graph_options, item);
+  if (init_status.ok()) {
+    // Handle generation and insertion happen under the same lock.
+    mutex_lock l(mu_);
+    *handle = strings::Printf("%016llx", ++next_id_);
+    item->handle = *handle;
+    CHECK(table_.insert({*handle, item}).second);
+    return Status::OK();
+  }
+  // Initialization failed: give up the reference created above.
+  item->Unref();
+  return init_status;
+}
+
+// Removes the graph registered under "handle" from table_ and drops the
+// table's reference on it. Returns Aborted if the handle is unknown.
+Status GraphMgr::Deregister(const string& handle) {
+  bool found = false;
+  Item* removed = nullptr;
+  {
+    // Only the table manipulation happens under the lock.
+    mutex_lock l(mu_);
+    auto pos = table_.find(handle);
+    if (pos != table_.end()) {
+      found = true;
+      removed = pos->second;
+      table_.erase(pos);
+    }
+  }
+  if (!found) {
+    return errors::Aborted("Graph handle is not found: ", handle,
+                           ". Possibly, this worker just restarted.");
+  }
+  removed->Unref();
+  return Status::OK();
+}
+
+// Empties table_ and drops the table's reference on every graph that was
+// registered. Always returns OK.
+Status GraphMgr::DeregisterAll() {
+  // Detach the whole table under the lock, release the items outside it.
+  std::unordered_map<string, Item*> detached;
+  {
+    mutex_lock l(mu_);
+    detached.swap(table_);
+  }
+  for (const auto& entry : detached) {
+    entry.second->Unref();
+  }
+  return Status::OK();
+}
+
+// Synchronous wrapper around ExecuteAsync(): blocks the calling thread
+// until the completion callback has run and returns the status it received.
+Status GraphMgr::Execute(const string& handle, const int64 step_id,
+                         const ExecutorOpts& opts,
+                         StepStatsCollector* collector,
+                         CancellationManager* cancellation_manager,
+                         const NamedTensors& in, NamedTensors* out) {
+  Status result;
+  Notification finished;
+  StatusCallback on_done = [&result, &finished](const Status& s) {
+    result = s;
+    finished.Notify();
+  };
+  ExecuteAsync(handle, step_id, opts, collector, cancellation_manager, in, out,
+               on_done);
+  finished.WaitForNotification();
+  return result;
+}
+
+// Asynchronously runs one step of the graph registered under "handle".
+//
+// Looks the item up in table_, sends every tensor of "in" into the
+// rendezvous found for "step_id", then starts the root executor of every
+// execution unit. "done" is called with Aborted if the handle is unknown,
+// with the Send error if feeding an input fails, and otherwise from
+// RunAllDone(), which the ExecutorBarrier is constructed to call.
+//
+// NOTE: "opts" is not read anywhere in this function.
+void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
+ const ExecutorOpts& opts,
+ StepStatsCollector* collector,
+ CancellationManager* cancellation_manager,
+ const NamedTensors& in, NamedTensors* out,
+ StatusCallback done) {
+ // Lookup an item. Holds one ref while executing.
+ Item* item = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto iter = table_.find(handle);
+ if (iter != table_.end()) {
+ item = iter->second;
+ item->Ref();
+ }
+ }
+
+ if (item == nullptr) {
+ done(errors::Aborted("Graph handle is not found: ", handle));
+ return;
+ }
+
+ const int num_units = item->units.size();
+ CHECK_GE(num_units, 1);
+
+ // The rendezvous is Unref'ed on every exit path below, so Find()
+ // presumably returns it with a reference held for the caller.
+ Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
+
+ // Sends values specified by the caller.
+ for (const auto& p : in) {
+ const string& key = p.first;
+ const Tensor& val = p.second;
+ const Status s = rendezvous->Send(key, Rendezvous::Args(), val, false);
+ if (!s.ok()) {
+ // Report the failure, then give back both references taken above.
+ done(s);
+ item->Unref();
+ rendezvous->Unref();
+ return;
+ }
+ }
+
+ // Starts parallel Executors.
+ //
+ // NOTE: Transfer one ref of rendezvous and one ref of item to
+ // RunAllDone.
+ // NOTE(review): nothing in this function deletes "barrier"; presumably
+ // ExecutorBarrier deletes itself once all executors have reported.
+ ExecutorBarrier* barrier = new ExecutorBarrier(
+ num_units, rendezvous, std::bind(&ME::RunAllDone, this, item, rendezvous,
+ out, done, std::placeholders::_1));
+ Executor::Args args;
+ {
+ // The executors run under a fresh id drawn from next_id_, not under the
+ // caller's "step_id" (which only selects the rendezvous above).
+ mutex_lock l(mu_);
+ args.step_id = ++next_id_;
+ }
+ args.rendezvous = rendezvous;
+ args.cancellation_manager = cancellation_manager;
+ args.stats_collector = collector;
+ VLOG(1) << "Step " << args.step_id << " is for handle " << handle
+ << ", graph-local step " << step_id;
+ thread::ThreadPool* pool = worker_env_->compute_pool;
+ args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
+ for (const auto& unit : item->units) {
+ unit.root->RunAsync(args, barrier->Get());
+ }
+}
+
+// Completion handler bound in ExecuteAsync(); receives the combined status
+// "s" of all executors of one step.
+//
+// If the step succeeded, fetches every tensor keyed in "*out" from
+// "rendezvous"; a dead tensor is reported as InvalidArgument. Then invokes
+// "done" with the final status and releases the references on "rendezvous"
+// and "item" that ExecuteAsync() handed over.
+void GraphMgr::RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
+                          StatusCallback done, Status s) {
+  // ExecuteAsync() documents "out" as optional, so a null "out" simply
+  // means there is nothing to receive.
+  if (s.ok() && out != nullptr) {
+    // Receives values requested by the caller.
+    for (auto& p : *out) {
+      const string& key = p.first;
+      Tensor* val = &p.second;
+      bool is_dead = false;
+      s = rendezvous->Recv(key, Rendezvous::Args(), val, &is_dead);
+      if (is_dead) {
+        s = errors::InvalidArgument("The tensor returned for ", key,
+                                    " was not valid.");
+      }
+      if (!s.ok()) break;
+    }
+  }
+  done(s);
+  rendezvous->Unref();
+  item->Unref();
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
new file mode 100644
index 0000000000..4300dbe305
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -0,0 +1,147 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
+
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/config.pb.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class ExecutorOpts;
+class StepStatsCollector;
+
+// GraphMgr keeps track of a set of graphs that are registered with a
+// TensorFlow worker. Each registered graph is identified by a handle
+// that is generated by GraphMgr and returned to the caller.
+//
+// After a successful registration, the caller executes a graph using
+// the graph handle. Each execution is distinguished from others by a
+// caller generated global unique id "step_id". Multiple executions
+// can use the same graph concurrently and independently as long as
+// "step_id" used are different.
+//
+// Multiple threads can call GraphMgr methods concurrently.
+//
+// E.g. (pseudo-code: the graph options argument of Register() and several
+// arguments of Execute() are omitted for brevity),
+// GraphMgr gmgr(worker_env);
+// string handle;
+// TF_CHECK_OK(gmgr.Register("session", { graph computes c = a + b },
+// &handle));
+// GraphMgr::NamedTensors in = { { "a", Tensor({1, 2}) },
+// { "b", Tensor({3, 4}) } };
+// GraphMgr::NamedTensors out = { { "c", Tensor() } };
+// TF_CHECK_OK(gmgr.Execute(handle, 0x0001, in, &out));
+// EXPECT_EQ(out["c"], Tensor({4, 6}));
+class GraphMgr {
+ public:
+ // "worker_env" is not owned by the GraphMgr.
+ explicit GraphMgr(const WorkerEnv* worker_env);
+ ~GraphMgr();
+
+ // Registers a graph. Fills in "handle"
+ Status Register(const string& session, const GraphDef& gdef,
+ const GraphOptions& graph_options, string* handle);
+
+ // Executes one step of a registered graph "handle".
+ //
+ // If "out" is not nullptr, "out" specifies all keys the execution
+ // should receive upon finish.
+ //
+ // "in" maps rendezvous keys to the tensors to feed; "done" is invoked
+ // with the final status of the step.
+ typedef std::map<string, Tensor> NamedTensors;
+ typedef std::function<void(const Status&)> StatusCallback;
+ void ExecuteAsync(const string& handle, const int64 step_id,
+ const ExecutorOpts& opts, StepStatsCollector* collector,
+ CancellationManager* cancellation_manager,
+ const NamedTensors& in, NamedTensors* out,
+ StatusCallback done);
+
+ // Synchronous wrapper.
+ Status Execute(const string& handle, const int64 step_id,
+ const ExecutorOpts& opts,
+ StepStatsCollector* step_stats_collector,
+ CancellationManager* cancellation_manager,
+ const NamedTensors& in, NamedTensors* out);
+
+ // Deregisters a graph.
+ Status Deregister(const string& handle);
+
+ // Deregister all graphs.
+ Status DeregisterAll();
+
+ private:
+ typedef GraphMgr ME;
+
+ // One partition of a registered graph: the device it runs on, the root
+ // executor for that partition and the function library runtime it uses.
+ struct ExecutionUnit {
+ Device* device = nullptr;
+ Executor* root = nullptr;
+ FunctionLibraryRuntime* lib = nullptr;
+ };
+
+ // A registered graph. Reference counted: table_ holds one reference and
+ // each in-flight execution holds another.
+ struct Item : public core::RefCounted {
+ // TODO(zhifengc): Keeps a copy of the original graph if the need arises.
+ // TODO(zhifengc): Stats, updated by multiple runs potentially.
+ // TODO(zhifengc): Dup-detection. Ensure step_id only run once.
+ ~Item() override;
+
+ // Session handle.
+ string session;
+
+ // Graph handle.
+ string handle;
+
+ // The definition of the library is shared by all partitions.
+ FunctionLibraryDefinition* lib_def = nullptr;
+
+ // A graph is partitioned over multiple devices. Each partition
+ // has a root executor which may call into the runtime library.
+ std::vector<ExecutionUnit> units;
+ };
+
+ // Not owned.
+ const WorkerEnv* worker_env_;
+
+ // Owned.
+ mutex mu_;
+ // Source of graph handles, generated node-name suffixes and executor
+ // step ids.
+ int64 next_id_ GUARDED_BY(mu_) = 0;
+
+ // Table mapping graph handles to registered graphs.
+ //
+ // TODO(zhifengc): If the client does not call Deregister, we'll
+ // lose memory over time. We should implement a timeout-based
+ // mechanism to gc these graphs.
+ std::unordered_map<string, Item*> table_;
+
+ // Completion handler of one step: fetches the requested outputs, calls
+ // "done" and releases the references on "item" and "rendezvous".
+ void RunAllDone(Item* item, Rendezvous* rendezvous, NamedTensors* out,
+ StatusCallback done, Status run_status);
+
+ // Populates "item" (library definition and one ExecutionUnit per device
+ // partition) from "gdef".
+ Status InitItem(const string& session, const GraphDef& gdef,
+ const GraphOptions& graph_options, Item* item);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_GRAPH_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
new file mode 100644
index 0000000000..2e8d6a1878
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -0,0 +1,413 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Master implements the service MasterService.
+//
+// A Master maintains the state of live graph computation
+// sessions, each session orchestrates both local and remote devices
+// to carry out the graph computation.
+//
+// A Master knows ahead of time local devices available as
+// client devices.
+//
+// A Master discovers remote devices on-demand and keeps track of
+// statistics of those remote devices.
+//
+// Each session analyses the graph, places nodes across available
+// devices, and ultimately drives the graph computation by initiating
+// RunGraph on the workers.
+
+#include "tensorflow/core/distributed_runtime/master.h"
+
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/remote_device.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// Creates a master over "env". When "session_gc_seconds" is positive a
+// background closure is scheduled that closes sessions idle for longer
+// than that many seconds.
+Master::Master(MasterEnv* env, double session_gc_seconds)
+    : env_(env),
+      last_1000_steps_(1000),
+      step_count_(0),
+      session_gc_seconds_(session_gc_seconds) {
+  // For now the master service has to be co-located with at least one
+  // device; fetches do not work otherwise.
+  CHECK(!env->local_devices.empty());
+
+  // Session garbage collection is disabled unless the timeout is positive.
+  if (!(session_gc_seconds_ > 0.0)) return;
+  SchedClosure([this]() { GC(); });
+}
+
+// Signals the GC closure (if one was started) to stop and waits for it to
+// finish before the members it uses are destroyed.
+Master::~Master() {
+  {
+    mutex_lock l(mu_);
+    shutdown_ = true;
+    shutdown_cv_.notify_all();
+  }
+  // The constructor schedules GC() only when session_gc_seconds_ is
+  // positive, and GC() is the only code that notifies gc_stopped_. Waiting
+  // unconditionally would block forever when GC is disabled.
+  if (session_gc_seconds_ > 0.0) {
+    gc_stopped_.WaitForNotification();
+  }
+}
+
+// Body of the session garbage-collection closure scheduled by the
+// constructor. Until shutdown_ is set it wakes up on shutdown_cv_ or after
+// a 10 second timeout, removes from sessions_ every session whose last
+// access is more than session_gc_seconds_ old and schedules its Close() on
+// a separate closure. Notifies gc_stopped_ on exit; the destructor waits
+// for that notification.
+void Master::GC() {
+ Env* env = Env::Default();
+ while (true) {
+ // mu_ is held for the whole iteration; WaitForMilliseconds presumably
+ // releases it while blocked.
+ mutex_lock l(mu_);
+ const int kTimeoutMilliseconds = 10 * 1000; // 10 seconds.
+ WaitForMilliseconds(&l, &shutdown_cv_, kTimeoutMilliseconds);
+ if (shutdown_) {
+ break;
+ }
+ // Handles are collected first and erased after the scan so that
+ // sessions_ is not modified while it is being iterated.
+ std::vector<string> handles;
+ const int64 num_micros = static_cast<int64>(session_gc_seconds_ * 1000000);
+ for (const auto& entry : sessions_) {
+ auto lat = entry.second->last_access_time_usec();
+ if (env->NowMicros() - lat > num_micros) {
+ handles.push_back(entry.first);
+ auto* sess = entry.second;
+ // Close() runs on its own closure, outside this loop and its lock.
+ SchedClosure([this, sess]() {
+ LOG(WARNING) << "GC session " << sess->handle() << " after "
+ << session_gc_seconds_ << " seconds. "
+ << "Note that if you are starting multiple replicas "
+ << "on a staggered delay, session_gc_seconds may need "
+ << "to be raised.";
+ sess->Close();
+ });
+ }
+ }
+ for (const auto& handle : handles) sessions_.erase(handle);
+ }
+ gc_stopped_.Notify();
+}
+
+// Discovers the devices of the workers listed by env_->worker_cache,
+// optionally restricted by a set of device-name filters.
+//
+// Usage, as in Master::CreateSession() and Master::ListDevices():
+// Start() issues one asynchronous query per matching worker, Wait() blocks
+// until all of them have answered, GetRemoteDevices() hands the result to
+// the caller.
+class DeviceFinder {
+ public:
+ // Parses every entry of "device_filters" into filters_.
+ // NOTE(review): an unparsable filter hits LOG(FATAL) although the message
+ // says "Skipping" - confirm whether aborting is intended.
+ explicit DeviceFinder(
+ const protobuf::RepeatedPtrField<string>& device_filters, MasterEnv* env)
+ : env_(env) {
+ auto process_filter = [this](const string& filter) {
+ DeviceNameUtils::ParsedName parsed;
+ if (DeviceNameUtils::ParseFullName(filter, &parsed)) {
+ filters_.push_back(parsed);
+ } else {
+ LOG(FATAL) << "Skipping invalid filter: " << filter;
+ }
+ };
+ for (const string& filter : device_filters) {
+ process_filter(filter);
+ }
+ }
+
+ // Deletes the devices still held in found_, i.e. those that were not
+ // handed over (or deleted) by GetRemoteDevices().
+ ~DeviceFinder() {
+ for (Device* dev : found_) delete dev;
+ }
+
+ // Selects the workers to query (all of them when there are no filters,
+ // otherwise those whose name matches a filter), records their number in
+ // num_pending_ and asks each one for its devices. Answers arrive through
+ // WhenFound().
+ void Start() {
+ // Enumerates all known workers' target. A target name is a
+ // prefix of a device name. E.g., /job:mnist/replica:0/task:10.
+ std::vector<string> workers;
+ env_->worker_cache->ListWorkers(&workers);
+ std::vector<string> targets;
+ if (filters_.empty()) {
+ swap(workers, targets);
+ } else {
+ for (const string& name : workers) {
+ if (MatchFilters(name)) {
+ targets.push_back(name);
+ }
+ }
+ }
+ {
+ // num_pending_ has no initializer; this is where it gets its first
+ // value.
+ mutex_lock l(mu_);
+ num_pending_ = targets.size();
+ if (num_pending_ == 0) {
+ pending_zero_.notify_all();
+ }
+ }
+ // Talk to all workers to get the list of available devices.
+ using std::placeholders::_1;
+ using std::placeholders::_2;
+ for (size_t i = 0; i < targets.size(); ++i) {
+ NewRemoteDevices(env_->env, env_->worker_cache, targets[i],
+ std::bind(&ME::WhenFound, this, _1, _2));
+ }
+ }
+
+ // Blocks until every query issued by Start() has been answered.
+ void Wait() {
+ mutex_lock l(mu_);
+ while (num_pending_ != 0) {
+ pending_zero_.wait(l);
+ }
+ }
+
+ // The caller takes the ownership of returned remote devices.
+ //
+ // A found device is appended to "remote" if its name differs from every
+ // device in "local", has not been seen before in this call, and matches
+ // the filters; every other found device is deleted. found_ is empty
+ // afterwards.
+ void GetRemoteDevices(const std::vector<Device*>& local,
+ std::vector<Device*>* remote) {
+ std::unordered_set<string> names(local.size());
+ for (Device* dev : local) names.insert(dev->name());
+ mutex_lock l(mu_);
+ for (Device* dev : found_) {
+ const string& name = dev->name();
+ if (names.insert(name).second && MatchFilters(name)) {
+ remote->push_back(dev);
+ } else {
+ delete dev;
+ }
+ }
+ found_.clear();
+ }
+
+ private:
+ typedef DeviceFinder ME;
+ const MasterEnv* env_;
+ // Parsed device filters; empty means "accept everything".
+ std::vector<DeviceNameUtils::ParsedName> filters_;
+
+ mutex mu_;
+ // Number of worker queries that have not been answered yet.
+ int num_pending_ GUARDED_BY(mu_);
+ // Signalled when num_pending_ reaches zero.
+ condition_variable pending_zero_;
+ // Devices reported so far; owned until GetRemoteDevices() is called.
+ std::vector<Device*> found_ GUARDED_BY(mu_);
+
+ // Callback for one worker query. On success moves the reported devices
+ // into found_ (emptying "*devices"); on failure only logs the error.
+ // Either way the query counts as answered.
+ void WhenFound(const Status& s, std::vector<Device*>* devices) {
+ mutex_lock l(mu_);
+ if (!s.ok()) {
+ LOG(ERROR) << "Master init: " << s;
+ } else {
+ found_.insert(found_.end(), devices->begin(), devices->end());
+ devices->clear();
+ }
+ --num_pending_;
+ if (num_pending_ == 0) {
+ pending_zero_.notify_all();
+ }
+ }
+
+ // Returns true iff the set of devices allowed by 'x' intersects
+ // with the set of devices allowed by 'y'.
+ // A field that is unset on either side places no constraint.
+ bool Intersects(const DeviceNameUtils::ParsedName& x,
+ const DeviceNameUtils::ParsedName& y) {
+ return (!x.has_job || !y.has_job || x.job == y.job) &&
+ (!x.has_replica || !y.has_replica || x.replica == y.replica) &&
+ (!x.has_task || !y.has_task || x.task == y.task) &&
+ (!x.has_type || !y.has_type || x.type == y.type) &&
+ (!x.has_id || !y.has_id || x.id == y.id);
+ }
+
+ // Returns true iff 'name' matches one of the filters_.
+ // With no filters every name matches; a name that cannot be parsed
+ // matches nothing.
+ bool MatchFilters(const string& name) {
+ if (filters_.empty()) return true;
+ DeviceNameUtils::ParsedName x;
+ if (DeviceNameUtils::ParseFullName(name, &x)) {
+ for (const auto& filter : filters_) {
+ if (Intersects(x, filter)) return true;
+ }
+ }
+ return false;
+ }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(DeviceFinder);
+};
+
+// Creates a new session for the graph in "req" on a scheduled closure.
+//
+// The graph syntax is validated, the devices of all (filtered) workers are
+// discovered, a session is built through env_->master_session_factory and
+// initialized with the graph. On success the session handle is stored in
+// "resp" and the session is added to sessions_. "done" receives the
+// validation error, the Create() error, or OK.
+void Master::CreateSession(const CreateSessionRequest* req,
+                           CreateSessionResponse* resp, MyClosure done) {
+  SchedClosure([this, req, resp, done]() {
+    Status status = ValidateExternalGraphDefSyntax(req->graph_def());
+    if (!status.ok()) {
+      done(status);
+      return;
+    }
+    // Ping all the workers and build the list of devices that the
+    // session will use.
+    DeviceFinder finder(req->config().device_filters(), env_);
+    finder.Start();
+    finder.Wait();
+    std::vector<Device*> remote_devices;
+    finder.GetRemoteDevices(env_->local_devices, &remote_devices);
+    SessionOptions options;
+    options.config = req->config();
+    MasterSessionInterface* session =
+        env_->master_session_factory(options, env_, &remote_devices);
+    // Create() takes a mutable graph, hence the const_cast on the request.
+    GraphDef* gdef =
+        const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
+    Status create_status = session->Create(gdef);
+    if (!create_status.ok()) {
+      // The session never reaches sessions_, so no later CloseSession,
+      // Reset or GC would release it. Close() is the disposal path used
+      // for sessions everywhere else in this file.
+      session->Close();
+      done(create_status);
+      return;
+    }
+    resp->set_session_handle(session->handle());
+    // Insert into the session map.
+    {
+      mutex_lock l(mu_);
+      CHECK(sessions_.insert({session->handle(), session}).second);
+    }
+    done(status);
+  });
+}
+
+// Extends the graph of an existing session on a scheduled closure. Calls
+// "done" with Aborted right away if the session handle is unknown;
+// otherwise "done" receives the result of validating and extending.
+void Master::ExtendSession(const ExtendSessionRequest* req,
+                           ExtendSessionResponse* resp, MyClosure done) {
+  bool scheduled = false;
+  {
+    // The lookup and the scheduling both happen under mu_.
+    mutex_lock l(mu_);
+    MasterSessionInterface* session =
+        gtl::FindPtrOrNull(sessions_, req->session_handle());
+    if (session != nullptr) {
+      scheduled = true;
+      SchedClosure([session, req, resp, done]() {
+        Status status = ValidateExternalGraphDefSyntax(req->graph_def());
+        if (status.ok()) {
+          status = session->Extend(req, resp);
+        }
+        done(status);
+      });
+    }
+  }
+  if (!scheduled) {
+    // Reported after the lock has been released.
+    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+  }
+}
+
+// Runs one step of an existing session on a scheduled closure and records
+// its duration in last_1000_steps_. Calls "done" with Aborted right away
+// if the session handle is unknown.
+void Master::RunStep(CallOptions* opts, const RunStepRequest* req,
+                     RunStepResponse* resp, MyClosure done) {
+  bool scheduled = false;
+  {
+    // The lookup and the scheduling both happen under mu_.
+    mutex_lock l(mu_);
+    const uint64 start_time = env_->env->NowMicros();
+    MasterSessionInterface* session =
+        gtl::FindPtrOrNull(sessions_, req->session_handle());
+    if (session != nullptr) {
+      scheduled = true;
+      SchedClosure([this, start_time, session, opts, req, resp, done]() {
+        Status status = session->Run(opts, req, resp);
+        uint64 done_time = env_->env->NowMicros();
+        done(status);
+        // Statistics are updated after the caller has been notified.
+        mutex_lock stats_lock(mu_);
+        // NOTE(review): both timestamps are in microseconds, so dividing
+        // by 1e9 does not yield seconds - confirm the intended unit.
+        last_1000_steps_.AddValue((done_time - start_time) / 1e9);
+        ++step_count_;
+      });
+    }
+  }
+  if (!scheduled) {
+    // Reported after the lock has been released.
+    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+  }
+}
+
+// Removes the session named in "req" from sessions_ and closes it on a
+// scheduled closure; "done" receives the status of Close(). Calls "done"
+// with Aborted right away if the session handle is unknown.
+void Master::CloseSession(const CloseSessionRequest* req,
+                          CloseSessionResponse* resp, MyClosure done) {
+  bool found = false;
+  MasterSessionInterface* session = nullptr;
+  {
+    mutex_lock l(mu_);
+    auto pos = sessions_.find(req->session_handle());
+    if (pos != sessions_.end()) {
+      found = true;
+      session = pos->second;
+      sessions_.erase(pos);
+    }
+  }
+  if (!found) {
+    done(errors::Aborted(
+        "Session ", req->session_handle(),
+        " is not found. Possibly, this master has restarted."));
+    return;
+  }
+
+  // Session Close() blocks on thread shutdown. Therefore, we need to
+  // delete it in non-critical thread.
+  SchedClosure([session, done]() {
+    Status close_status = session->Close();
+    done(close_status);
+  });
+}
+
+// Fills "resp" with the attributes of all local devices and of all remote
+// devices discovered (without filters) on the known workers. Runs on a
+// scheduled closure and always reports OK through "done".
+void Master::ListDevices(const ListDevicesRequest* req,
+                         ListDevicesResponse* resp, MyClosure done) {
+  SchedClosure([this, req, resp, done]() {
+    DeviceFinder finder({}, env_);
+    finder.Start();
+    finder.Wait();
+    std::vector<Device*> discovered;
+    finder.GetRemoteDevices(env_->local_devices, &discovered);
+
+    for (Device* local : env_->local_devices) {
+      *(resp->add_local_device()) = local->attributes();
+    }
+    // The discovered devices are owned here; each one is deleted as soon
+    // as its attributes have been copied.
+    for (Device* remote : discovered) {
+      *(resp->add_remote_device()) = remote->attributes();
+      delete remote;
+    }
+    done(Status::OK());
+  });
+}
+
+// Sends a CleanupAll request for the containers named in "reset" to every
+// known worker and blocks until each of them has answered. A worker that
+// cannot be created is treated as already done.
+void Master::CleanupWorkers(const ResetRequest& reset) {
+  std::vector<string> worker_names;
+  env_->worker_cache->ListWorkers(&worker_names);
+  if (worker_names.empty()) return;
+
+  const int num_workers = worker_names.size();
+  // One notification and one response slot per worker.
+  std::vector<Notification> finished(num_workers);
+  std::vector<CleanupAllResponse> resp(num_workers);
+  CleanupAllRequest req;
+  (*req.mutable_container()) = reset.container();
+
+  for (int i = 0; i < num_workers; ++i) {
+    auto worker = env_->worker_cache->CreateWorker(worker_names[i]);
+    if (!worker) {
+      finished[i].Notify();
+      continue;
+    }
+    // "finished" is captured by reference; it outlives the callbacks
+    // because this function waits for all of them below.
+    worker->CleanupAllAsync(&req, &resp[i], [&finished, worker, i](Status s) {
+      TF_CHECK_OK(s);
+      delete worker;
+      finished[i].Notify();
+    });
+  }
+  for (int i = 0; i < num_workers; ++i) {
+    finished[i].WaitForNotification();
+  }
+}
+
+// Removes every session from sessions_, cleans up all workers, then closes
+// the removed sessions on a scheduled closure. "done" receives the
+// accumulated status of the Close() calls.
+void Master::Reset(const ResetRequest* req, ResetResponse* resp,
+                   MyClosure done) {
+  // Snapshot of the session pointers that were in sessions_.
+  std::vector<MasterSessionInterface*> to_close;
+  {
+    mutex_lock l(mu_);
+    for (auto it = sessions_.begin(); it != sessions_.end(); ++it) {
+      to_close.push_back(it->second);
+    }
+    sessions_.clear();
+  }
+
+  CleanupWorkers(*req);
+
+  SchedClosure([to_close, done]() {
+    Status combined;
+    for (MasterSessionInterface* session : to_close) {
+      combined.Update(session->Close());
+    }
+    done(combined);
+  });
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h
new file mode 100644
index 0000000000..16e2c1a866
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master.h
@@ -0,0 +1,98 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
+
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/distributed_runtime/master_env.h"
+#include "tensorflow/core/distributed_runtime/master_session_interface.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+// Implements the master service: keeps the table of live sessions and
+// dispatches the session-level requests to them. Every request method
+// reports its result through the "done" closure.
+class Master {
+ public:
+ // "env" is not owned. A positive "session_gc_seconds" enables automatic
+ // closing of sessions that have been idle for that long.
+ explicit Master(MasterEnv* env, double session_gc_seconds);
+ virtual ~Master();
+
+ // Convenient typedef for a closure passing a Status.
+ typedef std::function<void(const Status&)> MyClosure;
+
+ // Creates a session for the graph in "req" and returns its handle in
+ // "resp".
+ void CreateSession(const CreateSessionRequest* req,
+ CreateSessionResponse* resp, MyClosure done);
+
+ // Extends the graph of the session named in "req".
+ void ExtendSession(const ExtendSessionRequest* req,
+ ExtendSessionResponse* resp, MyClosure done);
+
+ // Runs one step of the session named in "req".
+ void RunStep(CallOptions* opts, const RunStepRequest* req,
+ RunStepResponse* resp, MyClosure done);
+
+ // Removes the session named in "req" and closes it.
+ void CloseSession(const CloseSessionRequest* req, CloseSessionResponse* resp,
+ MyClosure done);
+
+ // Lists the local devices and the devices of all known workers.
+ void ListDevices(const ListDevicesRequest* req, ListDevicesResponse* resp,
+ MyClosure done);
+
+ // Closes all sessions and cleans up all workers.
+ void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done);
+
+ private:
+ typedef Master ME;
+
+ // Not owned.
+ MasterEnv* env_ = nullptr;
+
+ // Owned.
+ mutex mu_;
+
+ // shutdown_ is set to true by the dtor.
+ condition_variable shutdown_cv_;
+ bool shutdown_ GUARDED_BY(mu_) = false;
+ // Notified by GC() when its loop exits.
+ Notification gc_stopped_;
+
+ // Maps session handles to sessions.
+ std::unordered_map<string, MasterSessionInterface*> sessions_ GUARDED_BY(mu_);
+
+ // Moving average of step times.
+ MovingAverage last_1000_steps_ GUARDED_BY(mu_);
+
+ // Cumulative number of steps executed.
+ int64 step_count_ GUARDED_BY(mu_);
+
+ // If a session is not active for this many seconds, it will be
+ // closed automatically.
+ const double session_gc_seconds_;
+
+ // Call CleanupAll on all workers.
+ void CleanupWorkers(const ResetRequest& reset);
+
+ // Cleanup unused session.
+ void GC();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Master);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
new file mode 100644
index 0000000000..513442b7e6
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -0,0 +1,66 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
+
+#include <functional>
+#include <vector>
+
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+class Device;
+class Env;
+class MasterSessionInterface;
+class OpRegistryInterface;
+class WorkerCacheInterface;
+
+// The master environment class, which holds a bag of pointers to
+// per-master state.
+//
+// MasterEnv does not own its member pointers.
+struct MasterEnv {
+ // Platform environment. Not owned.
+ Env* env = nullptr;
+
+ // Object from which WorkerInterface instances can be obtained.
+ WorkerCacheInterface* worker_cache = nullptr;
+
+ // The operation definitions to use. Must be filled before use.
+ const OpRegistryInterface* ops = nullptr;
+
+ // Local devices co-located with this master. Devices are not owned
+ // by the master service.
+ //
+ // REQUIRES: !local_devices.empty().
+ std::vector<Device*> local_devices;
+
+ // Factory for creating master sessions, given session options and a
+ // vector of devices.
+ //
+ // The caller of the function takes ownership of the returned
+ // `MasterSessionInterface`, which may not be null. Ownership of the
+ // `MasterEnv*` is retained by the caller. The callee takes
+ // ownership of the `std::vector<Device*>*` argument, but does not
+ // take ownership of the `Device*` objects in the vector.
+ //
+ // NOTE(review): the MasterSession implementation in master_session.cc
+ // swaps the vector's contents into a member it marks "Owned" and deletes
+ // every `Device*` in its destructor, which contradicts the last sentence
+ // above. Confirm which contract is intended.
+ std::function<MasterSessionInterface*(const SessionOptions&, MasterEnv*,
+ std::vector<Device*>*)>
+ master_session_factory;
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h
new file mode 100644
index 0000000000..602cfbd8a3
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master_interface.h
@@ -0,0 +1,52 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+
+namespace tensorflow {
+
+// Pure virtual interface for communicating with the TensorFlow Master service.
+//
+// This interface is intended to support in-process master
+// implementations that do not require an RPC roundtrip.
+// Every method takes a request proto and a response proto to fill in, and
+// reports its outcome through the returned Status rather than through a
+// completion callback.
+class MasterInterface {
+ public:
+ virtual ~MasterInterface() {}
+
+ // Handles a CreateSessionRequest.
+ virtual Status CreateSession(const CreateSessionRequest* request,
+ CreateSessionResponse* response) = 0;
+
+ // Handles an ExtendSessionRequest.
+ virtual Status ExtendSession(const ExtendSessionRequest* request,
+ ExtendSessionResponse* response) = 0;
+
+ // Handles a RunStepRequest.
+ virtual Status RunStep(const RunStepRequest* request,
+ RunStepResponse* response) = 0;
+
+ // Handles a CloseSessionRequest.
+ virtual Status CloseSession(const CloseSessionRequest* request,
+ CloseSessionResponse* response) = 0;
+
+ // Handles a ListDevicesRequest.
+ virtual Status ListDevices(const ListDevicesRequest* request,
+ ListDevicesResponse* response) = 0;
+
+ // Handles a ResetRequest.
+ virtual Status Reset(const ResetRequest* request,
+ ResetResponse* response) = 0;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_INTERFACE_H_
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
new file mode 100644
index 0000000000..9535f5db47
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -0,0 +1,942 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/master_session.h"
+
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/distributed_runtime/master_env.h"
+#include "tensorflow/core/distributed_runtime/master_session_interface.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/graph/tensor_id.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+namespace {
+// A little bit of per-step state.
+struct PerStepState {
+ // Set from Env::Default()->NowMicros() when DoRunWithLocalExecution()
+ // begins handling a step.
+ Microseconds start_micros = Microseconds(0);
+ // Set from Env::Default()->NowMicros() after RunPartitions() returns.
+ Microseconds end_micros = Microseconds(0);
+ // RunPartitions() stores the stats returned by partition i at index i.
+ // NOTE(review): nothing visible in this file sizes this vector before it
+ // is indexed -- confirm it is resized somewhere.
+ std::vector<StepStats> step_stats; // per partition
+};
+
+// A session encapsulates a graph computation (resource allocation,
+// placement, execution, etc.).
+class MasterSession : public MasterSessionInterface {
+ public:
+ // This session encapsulates the graph computation for a graph.
+ //
+ // The session places nodes on devices in "remote_devs" and executes
+ // operations on these devices.
+ //
+ // The session takes ownership of all remote devices: the contents of
+ // "*remote_devs" are swapped into the session (leaving "*remote_devs"
+ // empty) and the devices are deleted by the destructor.
+ MasterSession(const SessionOptions& options, const MasterEnv* env,
+ std::vector<Device*>* remote_devs);
+
+ // Initialize the Session for "def". Must be called before Extend(),
+ // Run(), or Close().
+ //
+ // The callee may clear "def".
+ Status Create(GraphDef* def) override;
+
+ // Returns the session handle.
+ const string& handle() const override { return handle_; }
+
+ // Returns the last access time (the number of micro-seconds since
+ // some fixed point in time) of this session.
+ uint64 last_access_time_usec() const override {
+ return last_access_time_usec_.load();
+ }
+
+ // Attempt to extend the graph according to the given "req".
+ // (See master.proto for details of valid extensions.)
+ //
+ // PRECONDITION: The current version of this session's graph
+ // is "req->current_graph_version".
+ //
+ // POSTCONDITION: The current version of this session's graph
+ // is "resp->new_graph_version".
+ //
+ // Extend() may block the caller thread for a long time.
+ Status Extend(const ExtendSessionRequest* req,
+ ExtendSessionResponse* resp) override;
+
+ // Run one step.
+ Status Run(CallOptions* opts, const RunStepRequest* req,
+ RunStepResponse* resp) override;
+
+ // Close this session and delete "*this". Returns OK if all known
+ // states are cleaned up successfully.
+ //
+ // Close() may block the caller thread for a long time.
+ Status Close() override;
+
+ private:
+ SessionOptions session_opts_;
+
+ // Not owned.
+ const MasterEnv* env_;
+
+ // The opaque session handle.
+ const string handle_;
+
+ // Owned.
+ std::vector<Device*> remote_devs_;
+
+ // The device set used by this session.
+ DeviceSet devices_;
+
+ // TODO(zhifengc): Support Extend().
+ //
+ // 'func_def_lib_' is a copy of the initial graph def's library.
+ // 'flib_def_' is an index structure of 'func_def_lib_' keyed by
+ // function names.
+ FunctionDefLibrary func_def_lib_;
+ FunctionLibraryDefinition* flib_def_ = nullptr;
+
+ // Written by UpdateLastAccessTime() and read by last_access_time_usec().
+ // NOTE(review): std::atomic_ulong wraps `unsigned long`, which is narrower
+ // than the uint64 returned by the accessor on LLP64 platforms -- confirm
+ // this is acceptable for all supported targets.
+ std::atomic_ulong last_access_time_usec_;
+
+ mutex mu_;
+ std::unique_ptr<SimpleGraphExecutionState> execution_state_;
+ // Incremented by each successful Extend().
+ int64 graph_version_;
+
+ int32 steps_since_last_scheduling_ GUARDED_BY(mu_) = 0;
+ int32 scheduling_period_steps_ GUARDED_BY(mu_) = 10;
+
+ // We keep a map from a signature of a run request to the
+ // ReffedClientGraph that can execute it. We keep up to one old copy
+ // of each ReffedClientGraph around because if it gets deallocated
+ // before a new substitute has been created, Variables can go out of
+ // scope and lose their state.
+ class ReffedClientGraph;
+ typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
+ RCGMap runs_ GUARDED_BY(mu_);
+ RCGMap obsolete_ GUARDED_BY(mu_);
+
+ // Active RunStep calls.
+ condition_variable num_running_is_zero_;
+ int32 num_running_ GUARDED_BY(mu_) = 0;
+
+ // Keyed by the same request hash as runs_; counts how many times each
+ // subgraph has been started in this session.
+ std::unordered_map<uint64, int64> subgraph_execution_counts_ GUARDED_BY(mu_);
+
+ // We need to ensure that certain nodes added (e.g., send and recv
+ // nodes) are unique across all sub-graphs within this session.
+ int64 next_node_id_ GUARDED_BY(mu_) = 0;
+
+ // Private dtor. The client must call Close().
+ virtual ~MasterSession();
+
+ // Looks up (or builds and caches) the ReffedClientGraph for "req" and
+ // returns it in "*graph" with an extra reference held for the caller.
+ Status StartStep(const RunStepRequest& req, BuildGraphOptions* opts,
+ int64* count, ReffedClientGraph** graph);
+ // Empties "*rcg_map". Each graph is either appended to "*to_unref" (when
+ // non-null) or unreffed immediately.
+ void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
+ RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequest* req,
+ RunStepResponse* resp);
+ // Stores the current time in last_access_time_usec_.
+ void UpdateLastAccessTime();
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MasterSession);
+};
+
+// Session wraps ClientGraph in a reference counted object. This way,
+// Session can clear up the cache mapping Run requests to compiled
+// graphs while the compiled graph is still being used.
+//
+// TODO(zhifengc): Cleanup this class. It's becoming messy.
+class MasterSession::ReffedClientGraph : public core::RefCounted {
+ public:
+  // Wraps "cg" (ownership is taken; it is deleted by the destructor) for
+  // the session identified by "handle". "bopts" and "graph_opts" are
+  // copied.
+  ReffedClientGraph(const string& handle, const BuildGraphOptions& bopts,
+                    ClientGraph* cg, const GraphOptions& graph_opts)
+      : session_handle_(handle),
+        client_graph_(cg),
+        bopts_(bopts),
+        graph_opts_(graph_opts) {
+    VLOG(1) << "Created ReffedClientGraph for node with "
+            << client_graph_->graph.num_node_ids();
+    // TODO(mrry): Publish information about the graph (such as
+    // timelines, the pruned graph, statistics, etc.).
+  }
+
+  // Deletes the client graph and asynchronously asks every worker to
+  // deregister its partition.
+  ~ReffedClientGraph() override {
+    delete client_graph_;
+    DeregisterPartitions();
+  }
+
+  const ClientGraph* client_graph() { return client_graph_; }
+
+  // Local execution methods.
+
+  // Partitions the graph into subgraphs and registers them on
+  // workers. The work is done at most once; every caller gets the
+  // result of that single attempt.
+  Status RegisterPartitions(const MasterEnv* env, const PartitionOptions& popts,
+                            const FunctionDefLibrary& func_def_lib);
+
+  // Runs one step of all partitions.
+  Status RunPartitions(const MasterEnv* env, int64 step_id,
+                       int64 execution_count,
+                       SimpleGraphExecutionState* execution_state,
+                       PerStepState* pss, CallOptions* opts,
+                       const RunStepRequest& req, RunStepResponse* resp);
+
+  // Calls workers to cleanup states for the step "step_id".  Waits
+  // till all cleanup rpcs complete.
+  Status CleanupPartitions(int64 step_id);
+
+  // TODO(mrry): Runtime statistics collection.
+
+ private:
+  const string session_handle_;
+  // Owned.
+  ClientGraph* const client_graph_ = nullptr;
+  std::unordered_set<const Node*> nodes_needing_input_mapping_;
+  BuildGraphOptions bopts_;
+  const GraphOptions graph_opts_;
+
+  // Graph partitioned into per-location subgraphs.
+  struct Part {
+    // Worker name.
+    string name;
+
+    // Graph definition.
+    GraphDef gdef;
+
+    // Maps feed names to rendezvous keys. Empty most of the time.
+    std::unordered_map<string, string> feed_key;
+
+    // Maps rendezvous keys to fetch names. Empty most of the time.
+    std::unordered_map<string, string> key_fetch;
+
+    // The interface to the worker. Owned.
+    WorkerInterface* worker = nullptr;
+
+    // After registration with the worker, graph_handle identifies
+    // this partition on the worker.
+    string graph_handle;
+
+    Part() : feed_key(3), key_fetch(3) {}
+  };
+
+  // partitions_ is immutable after RegisterPartitions() call
+  // finishes. RunPartitions() can access partitions_ safely without
+  // acquiring locks.
+  std::vector<Part> partitions_;
+
+  mutable mutex mu_;
+
+  // Partition initialization and registration only needs to happen
+  // once. init_started_ && !init_done_ indicates the initialization
+  // is on going.
+  bool init_started_ GUARDED_BY(mu_) = false;
+  Notification init_done_;
+
+  // init_result_ remembers the initialization error if any.
+  Status init_result_ GUARDED_BY(mu_);
+
+  // Send/Recv nodes that are the result of client-added
+  // feeds and fetches must be tracked so that the tensors
+  // can be added to the local rendezvous.
+  static void TrackFeedsAndFetches(Part* part, const PartitionOptions& popts);
+
+  // The actual graph partitioning and registration implementation.
+  Status DoRegisterPartitions(const MasterEnv* env,
+                              const PartitionOptions& popts,
+                              const FunctionDefLibrary& func_def_lib);
+
+  // Deregisters the partitions on the workers.  Called in the
+  // destructor and does not wait for the rpc completion.
+  void DeregisterPartitions();
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ReffedClientGraph);
+};
+
+// Ensures DoRegisterPartitions() runs exactly once for this graph. The
+// first caller performs the registration without holding mu_; every other
+// caller blocks on init_done_ and then returns the recorded result.
+Status MasterSession::ReffedClientGraph::RegisterPartitions(
+    const MasterEnv* env, const PartitionOptions& popts,
+    const FunctionDefLibrary& func_def_lib) {
+  // Decide under the lock whether this caller is the one that initializes.
+  bool is_initializer = false;
+  {
+    mutex_lock l(mu_);
+    if (!init_started_) {
+      init_started_ = true;
+      is_initializer = true;
+    }
+  }
+
+  if (is_initializer) {
+    const Status registered = DoRegisterPartitions(env, popts, func_def_lib);
+    mutex_lock l(mu_);
+    init_result_ = registered;
+    init_done_.Notify();
+    return init_result_;
+  }
+
+  // Somebody else is (or was) initializing: wait for them to finish.
+  init_done_.WaitForNotification();
+  mutex_lock l(mu_);
+  return init_result_;
+}
+
+// Returns the task portion of "node"'s assigned device name, as produced by
+// DeviceNameUtils::SplitDeviceName(). CHECK-fails, logging the node name
+// and its assigned device, if the device name cannot be split.
+static string SplitByWorker(const Node* node) {
+ string task;
+ // Filled in by SplitDeviceName() but not used here.
+ string device;
+ CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
+ &device))
+ << "node: " << node->name() << " dev: " << node->assigned_device_name();
+ return task;
+}
+
+// Scans the partition's graph for _Send/_Recv nodes and records the
+// client-terminated ones in "part": a _Recv becomes a feed-name ->
+// rendezvous-key entry, a _Send becomes a rendezvous-key -> fetch-name
+// entry. "popts" is not used.
+void MasterSession::ReffedClientGraph::TrackFeedsAndFetches(
+    Part* part, const PartitionOptions& popts) {
+  for (NodeDef& node : *part->gdef.mutable_node()) {
+    NodeDef* ndef = &node;
+    const bool is_recv = ndef->op() == "_Recv";
+    const bool is_send = ndef->op() == "_Send";
+    if (!is_recv && !is_send) {
+      continue;
+    }
+
+    // Every attribute is required; a missing one is a fatal error.
+    string name;
+    TF_CHECK_OK(GetNodeAttr(*ndef, "tensor_name", &name));
+    string send_device;
+    TF_CHECK_OK(GetNodeAttr(*ndef, "send_device", &send_device));
+    string recv_device;
+    TF_CHECK_OK(GetNodeAttr(*ndef, "recv_device", &recv_device));
+    uint64 send_device_incarnation;
+    TF_CHECK_OK(
+        GetNodeAttr(*ndef, "send_device_incarnation",
+                    reinterpret_cast<int64*>(&send_device_incarnation)));
+    const string& key =
+        Rendezvous::CreateKey(send_device, send_device_incarnation,
+                              recv_device, name, FrameAndIter(0, 0));
+
+    // Send/recv nodes that are not client-terminated only move data
+    // between partitions / memory spaces, so they are not tracked.
+    bool client_terminated;
+    TF_CHECK_OK(GetNodeAttr(*ndef, "client_terminated", &client_terminated));
+    if (!client_terminated) {
+      continue;
+    }
+    if (is_recv) {
+      part->feed_key.insert({name, key});
+    } else {
+      part->key_fetch.insert({key, name});
+    }
+  }
+}
+
+// Partitions the client graph by worker, creates one worker stub per
+// partition and registers every partition's graph with its worker in
+// parallel.
+//
+// Returns the partitioning error, NotFound if a worker stub cannot be
+// created, or the combined status of the RegisterGraph calls. When a worker
+// stub cannot be created, partitions_ is left empty.
+Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
+    const MasterEnv* env, const PartitionOptions& popts,
+    const FunctionDefLibrary& func_def_lib) {
+  // Partition the graph.
+  Status s;
+  std::unordered_map<string, GraphDef> graph_partitions;
+  s = Partition(popts, &client_graph_->graph, &graph_partitions);
+  if (!s.ok()) return s;
+  partitions_.reserve(graph_partitions.size());
+  for (auto& name_def : graph_partitions) {
+    partitions_.resize(partitions_.size() + 1);
+    Part* part = &partitions_.back();
+    part->name = name_def.first;
+    part->gdef.Swap(&name_def.second);
+    // For simplicity, we ship the library completely to every worker.
+    *(part->gdef.mutable_library()) = func_def_lib;
+    TrackFeedsAndFetches(part, popts);
+    part->worker = env->worker_cache->CreateWorker(part->name);
+    if (part->worker == nullptr) {
+      s = errors::NotFound("worker ", part->name);
+      break;
+    }
+  }
+  if (!s.ok()) {
+    for (Part& part : partitions_) {
+      delete part.worker;
+    }
+    // The destructor's DeregisterPartitions() walks partitions_ and calls
+    // into every part.worker. Drop the half-built entries so it never sees
+    // the pointers that were just deleted (or the null one that caused
+    // this failure).
+    partitions_.clear();
+    return s;
+  }
+  struct Call {
+    RegisterGraphRequest req;
+    RegisterGraphResponse resp;
+    Status status;
+    Notification done;
+  };
+  const int num = partitions_.size();
+  gtl::InlinedVector<Call, 4> calls(num);
+  for (int i = 0; i < num; ++i) {
+    const Part& part = partitions_[i];
+    Call* c = &calls[i];
+    c->req.set_session_handle(session_handle_);
+    *c->req.mutable_graph_def() = part.gdef;
+    *c->req.mutable_graph_options() = graph_opts_;
+    VLOG(2) << "Register " << part.gdef.DebugString();
+    auto cb = [c](const Status& s) {
+      c->status = s;
+      c->done.Notify();
+    };
+    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
+  }
+  // Wait for every call, even after a failure: the callbacks write into
+  // "calls", which lives on this stack frame.
+  for (int i = num - 1; i >= 0; --i) {
+    Call* c = &calls[i];
+    c->done.WaitForNotification();
+    s.Update(c->status);
+    partitions_[i].graph_handle = c->resp.graph_handle();
+  }
+  return s;
+}
+
+// Moves or re-encodes the tensor in "in" into "out". Returns false only if
+// "in" carries tensor_content that cannot be parsed into a Tensor.
+static bool CopyIfNeeded(TensorProto* in, TensorProto* out) {
+  // A proto that is not encoded in tensor_content (or holds 0 elements)
+  // can be handed to the client as is.
+  if (in->tensor_content().empty()) {
+    out->Swap(in);
+    return true;
+  }
+
+  // Otherwise round-trip through a Tensor to re-encode it.
+  Tensor parsed(in->dtype());
+  if (!parsed.FromProto(cpu_allocator(), *in)) {
+    return false;
+  }
+  parsed.AsProtoField(out);
+  return true;
+}
+
+// Helper class to manage "num" parallel RunGraph calls.
+class RunManyGraphs {
+ public:
+  // Prepares "num" call slots, all of them initially pending.
+  explicit RunManyGraphs(int num) : calls_(num), num_pending_(num) {}
+
+  ~RunManyGraphs() {}
+
+  // Per-call state: options, request and response of one RunGraph call.
+  struct Call {
+    CallOptions opts;
+    RunGraphRequest req;
+    RunGraphResponse resp;
+  };
+
+  // Returns the index-th call.
+  Call* get(int index) { return &calls_[index]; }
+
+  // Completion callback for the index-th call: records the first failure
+  // (which also cancels the other calls) and wakes up Wait().
+  void WhenDone(int index, const Status& s) {
+    TRACEPRINTF("Partition %d %s", index, s.ToString().c_str());
+    mutex_lock l(mu_);
+    if (!s.ok()) {
+      UpdateStatusLocked(s);
+    }
+    num_pending_ -= 1;
+    cv_pending_.notify_all();
+  }
+
+  // Marks the whole batch as cancelled unless it has already failed.
+  void StartCancel() {
+    mutex_lock l(mu_);
+    UpdateStatusLocked(errors::Cancelled("RunManyGraphs"));
+  }
+
+  // Blocks until WhenDone() has been invoked for every call.
+  void Wait() {
+    mutex_lock l(mu_);
+    while (num_pending_ > 0) {
+      cv_pending_.wait(l);
+    }
+  }
+
+  // Returns the first error recorded so far, or OK.
+  Status status() const {
+    mutex_lock l(mu_);
+    return status_;
+  }
+
+ private:
+  gtl::InlinedVector<Call, 4> calls_;
+
+  // TODO(jeff,sanjay): Replace bookkeeping state here with a
+  // BlockingCounter abstraction that we define in
+  // tensorflow/core/lib/core.
+  mutable mutex mu_;
+  condition_variable cv_pending_;
+  int num_pending_;
+  Status status_ GUARDED_BY(mu_);
+
+  // Keeps only the first non-OK status; when it is recorded, asks every
+  // call to cancel.
+  void UpdateStatusLocked(const Status& s) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+    if (!status_.ok()) {
+      return;
+    }
+    status_ = s;
+    for (Call& pending_call : calls_) {
+      pending_call.opts.StartCancel();
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
+};
+
+// Runs one step on every registered partition in parallel.
+//
+// Feeds from "req" are attached to the RunGraph request of the partition
+// that consumes them; fetched tensors are appended to "resp". If a worker
+// returns step stats they are stored in "pss->step_stats" at the
+// partition's index. "call_opts" is used to propagate cancellation to the
+// in-flight RunGraph calls. "env" and "execution_state" are not used.
+//
+// Returns InvalidArgument for duplicated or missing feeds, the first error
+// reported by any RunGraph call, or Internal if a worker returns an
+// unexpected key or an unparseable tensor.
+Status MasterSession::ReffedClientGraph::RunPartitions(
+    const MasterEnv* env, int64 step_id, int64 execution_count,
+    SimpleGraphExecutionState* execution_state, PerStepState* pss,
+    CallOptions* call_opts, const RunStepRequest& req, RunStepResponse* resp) {
+  VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+          << execution_count;
+  // Builds an index for feeds provided by the client.
+  std::unordered_map<StringPiece, const TensorProto*, StringPiece::Hasher>
+      feeds(3);
+
+  for (const auto& feed : req.feed()) {
+    if (!feeds.insert({feed.name(), &feed.tensor()}).second) {
+      return errors::InvalidArgument("Duplicated feeds: ", feed.name());
+    }
+  }
+
+  // Prepares a number of calls to workers. One call per partition.
+  ExecutorOpts exec_opts;
+  const int num = partitions_.size();
+  RunManyGraphs calls(num);
+
+  for (int i = 0; i < num; ++i) {
+    const Part& part = partitions_[i];
+    RunManyGraphs::Call* c = calls.get(i);
+    c->req.set_graph_handle(part.graph_handle);
+    c->req.set_step_id(step_id);
+    *c->req.mutable_exec_opts() = exec_opts;
+    // If any feeds are provided, send the feed values together
+    // in the RunGraph request.
+    for (const auto& feed_key : part.feed_key) {
+      const string& feed = feed_key.first;
+      const string& key = feed_key.second;
+      // Use find() rather than operator[] so that a missing feed does not
+      // insert a null entry into the index.
+      const auto feed_iter = feeds.find(feed);
+      if (feed_iter == feeds.end()) {
+        // No RunGraph call has been issued yet, so returning here is safe.
+        return errors::InvalidArgument("No feed is provided for feed=", feed,
+                                       ", key=", key);
+      }
+      auto* send = c->req.add_send();
+      send->set_key(key);
+      // TODO(mrry): make it faster if needed.
+      *(send->mutable_val()) = *feed_iter->second;
+    }
+    for (const auto& key_fetch : part.key_fetch) {
+      const string& key = key_fetch.first;
+      c->req.add_recv_key(key);
+    }
+  }
+
+  // Issues RunGraph calls.
+  for (int i = 0; i < num; ++i) {
+    const Part& part = partitions_[i];
+    RunManyGraphs::Call* call = calls.get(i);
+    TRACEPRINTF("Partition %d %s", i, part.name.c_str());
+    part.worker->RunGraphAsync(
+        &call->opts, &call->req, &call->resp,
+        std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
+  }
+
+  // Waits for the RunGraph calls. "calls" lives on this stack frame, so
+  // the cancel callback must be cleared before returning.
+  call_opts->SetCancelCallback([&calls]() { calls.StartCancel(); });
+  calls.Wait();
+  call_opts->ClearCancelCallback();
+
+  // Collects fetches.
+  Status status = calls.status();
+  if (status.ok()) {
+    for (int i = 0; i < num; ++i) {
+      const Part& part = partitions_[i];
+      RunManyGraphs::Call* call = calls.get(i);
+      for (auto& recv : *(call->resp.mutable_recv())) {
+        // Validate the key before touching the response so that an
+        // unexpected key does not leave an empty tensor entry behind.
+        auto iter = part.key_fetch.find(recv.key());
+        if (iter == part.key_fetch.end()) {
+          status.Update(errors::Internal("Unexpected fetch key: ", recv.key()));
+          break;
+        }
+        auto* ret = resp->add_tensor();
+        ret->set_name(iter->second);
+        if (!CopyIfNeeded(recv.mutable_val(), ret->mutable_tensor())) {
+          status.Update(
+              errors::Internal("Unexpected unparseable tensor: ", recv.key()));
+          break;
+        }
+      }
+      if (call->resp.has_step_stats()) {
+        // The caller hands in a default-constructed PerStepState, so make
+        // sure there is one slot per partition before indexing.
+        if (pss->step_stats.size() < static_cast<size_t>(num)) {
+          pss->step_stats.resize(num);
+        }
+        pss->step_stats[i].Swap(call->resp.mutable_step_stats());
+      }
+    }
+  }
+  return status;
+}
+
+// Issues one CleanupGraph call per partition for "step_id", waits for all
+// of them and returns their combined status.
+Status MasterSession::ReffedClientGraph::CleanupPartitions(int64 step_id) {
+  struct Call {
+    CleanupGraphRequest req;
+    CleanupGraphResponse resp;
+    Notification done;
+    Status status;
+  };
+  const int num_parts = partitions_.size();
+  gtl::InlinedVector<Call, 4> pending(num_parts);
+
+  // Fan out.
+  for (int idx = 0; idx < num_parts; ++idx) {
+    Call* call = &pending[idx];
+    call->req.set_step_id(step_id);
+    auto on_done = [call](const Status& rpc_status) {
+      call->status = rpc_status;
+      call->done.Notify();
+    };
+    partitions_[idx].worker->CleanupGraphAsync(&call->req, &call->resp,
+                                               on_done);
+  }
+
+  // Fan in: every call must finish before "pending" goes out of scope.
+  Status combined;
+  for (int idx = num_parts - 1; idx >= 0; --idx) {
+    Call* call = &pending[idx];
+    call->done.WaitForNotification();
+    combined.Update(call->status);
+  }
+  return combined;
+}
+
+// Asks every worker to deregister its partition's graph. The calls are
+// issued asynchronously and this function does not wait for them to
+// complete.
+void MasterSession::ReffedClientGraph::DeregisterPartitions() {
+  struct Call {
+    DeregisterGraphRequest req;
+    DeregisterGraphResponse resp;
+  };
+  for (Part& part : partitions_) {
+    WorkerInterface* worker = part.worker;
+    Call* call = new Call;
+    call->req.set_graph_handle(part.graph_handle);
+    // The completion callback frees both the call state and the worker
+    // interface.
+    worker->DeregisterGraphAsync(
+        &call->req, &call->resp, [call, worker](const Status& s) {
+          if (!s.ok()) {
+            // This error is potentially benign, so we don't log at the
+            // error level.
+            LOG(INFO) << "DeregisterGraph error: " << s;
+          }
+          delete call;
+          delete worker;
+        });
+  }
+}
+
+// Appends the feed names, fetch names and target names of "req" to "opts",
+// then sorts each of the three lists.
+void BuildBuildGraphOptions(const RunStepRequest& req,
+                            BuildGraphOptions* opts) {
+  auto& feed_names = opts->feed_endpoints;
+  for (const auto& feed : req.feed()) {
+    feed_names.push_back(feed.name());
+  }
+  std::sort(feed_names.begin(), feed_names.end());
+
+  auto& target_names = opts->target_nodes;
+  for (const auto& target : req.target()) {
+    target_names.push_back(target);
+  }
+  std::sort(target_names.begin(), target_names.end());
+
+  auto& fetch_names = opts->fetch_endpoints;
+  for (const auto& fetch : req.fetch()) {
+    // TODO(touts): handle ref:
+    fetch_names.push_back(fetch);
+  }
+  std::sort(fetch_names.begin(), fetch_names.end());
+}
+
+// Folds every feed, target and fetch name of "opts" (in that order) into a
+// single 64-bit hash, starting from a fixed seed.
+uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
+  uint64 digest = 0x2b992ddfa23249d6ull;
+  // Chains the running hash through Hash64 as the seed for the next name.
+  const auto mix = [&digest](const string& text) {
+    digest = Hash64(text.c_str(), text.size(), digest);
+  };
+  for (const string& feed : opts.feed_endpoints) mix(feed);
+  for (const string& target : opts.target_nodes) mix(target);
+  for (const string& fetch : opts.fetch_endpoints) mix(fetch);
+  return digest;
+}
+
+// Renders "opts" for logging as three newline-terminated lines: the feed
+// endpoints, the target nodes and the fetch endpoints, each entry prefixed
+// with a short label.
+string BuildGraphOptionsString(const BuildGraphOptions& opts) {
+  string rendered;
+  const auto add_entry = [&rendered](const char* label, const string& name) {
+    strings::StrAppend(&rendered, label, name);
+  };
+  const auto end_line = [&rendered]() { strings::StrAppend(&rendered, "\n"); };
+
+  for (const string& feed : opts.feed_endpoints) add_entry(" FdE: ", feed);
+  end_line();
+  for (const string& target : opts.target_nodes) add_entry(" TN: ", target);
+  end_line();
+  for (const string& fetch : opts.fetch_endpoints) add_entry(" FeE: ", fetch);
+  end_line();
+  return rendered;
+}
+
+// Creates a session with a freshly generated random handle. The contents of
+// "*remote_devs" are moved into the session; the remote devices and the
+// master's local devices together form the session's device set, and the
+// first local device is used as the client device.
+MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
+                             std::vector<Device*>* remote_devs)
+    : session_opts_(opt),
+      env_(env),
+      handle_(strings::FpToString(random::New64())),
+      graph_version_(0),
+      runs_(5) {
+  UpdateLastAccessTime();
+
+  remote_devs_.swap(*remote_devs);
+  VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
+          << " #remote " << remote_devs_.size();
+  for (Device* remote : remote_devs_) {
+    devices_.AddDevice(remote);
+  }
+
+  bool client_device_chosen = false;
+  for (Device* local : env->local_devices) {
+    devices_.AddDevice(local);
+    if (!client_device_chosen) {
+      // Uses the first local device as the client device.
+      devices_.set_client_device(local);
+      client_device_chosen = true;
+    }
+  }
+}
+
+// Drops the session's references to all cached graphs, then frees the
+// function library index and the remote devices.
+MasterSession::~MasterSession() {
+  for (const auto& cached : runs_) {
+    cached.second->Unref();
+  }
+  for (const auto& stale : obsolete_) {
+    stale.second->Unref();
+  }
+  delete flib_def_;
+  for (Device* remote : remote_devs_) {
+    delete remote;
+  }
+}
+
+// Records the current time (from Env::Default()) as the session's last
+// access time.
+void MasterSession::UpdateLastAccessTime() {
+  Env* const clock = Env::Default();
+  last_access_time_usec_.store(clock->NowMicros());
+}
+
+// Initializes the session from "graph_def". The graph's function library is
+// swapped out of "graph_def" and kept by the session; the rest of the graph
+// is handed to a new SimpleGraphExecutionState. Returns that state's
+// Create() status.
+Status MasterSession::Create(GraphDef* graph_def) {
+  // flib_def_ indexes the retained library and serves as the
+  // OpRegistryInterface used by the SimpleGraphExecutionState to construct
+  // the pre-partitioned graphs during DoRunWithLocalExecution().
+  func_def_lib_.Swap(graph_def->mutable_library());
+  flib_def_ = new FunctionLibraryDefinition(func_def_lib_);
+
+  SimpleGraphExecutionStateOptions state_options;
+  state_options.device_set = &devices_;
+  state_options.session_options = &session_opts_;
+  execution_state_.reset(
+      new SimpleGraphExecutionState(flib_def_, state_options));
+
+  return execution_state_->Create(graph_def);
+}
+
+// Extends the session's graph with "req->graph_def()". Waits until no step
+// is running, then fails with Aborted if the caller's idea of the current
+// graph version is stale. On success the version is incremented and
+// reported in "resp".
+Status MasterSession::Extend(const ExtendSessionRequest* req,
+                             ExtendSessionResponse* resp) {
+  UpdateLastAccessTime();
+  // Declared before the lock so that the replaced state is destroyed only
+  // after mu_ has been released.
+  std::unique_ptr<SimpleGraphExecutionState> old_execution_state;
+
+  mutex_lock l(mu_);
+  // TODO(mrry): Redesign the locking with reader/writer locks to prevent
+  //   starvation due to concurrent steps being issued. This is not
+  //   immediately important because we expect Extend to be used in
+  //   development/interactive exploration, and not during high-throughput
+  //   training.
+  while (num_running_ != 0) {
+    num_running_is_zero_.wait(l);
+  }
+
+  if (graph_version_ != req->current_graph_version()) {
+    return errors::Aborted("Current version is ", graph_version_,
+                           " but caller expected ",
+                           req->current_graph_version(), ".");
+  }
+
+  CHECK(execution_state_);
+  SimpleGraphExecutionState* extended_execution_state = nullptr;
+  Status s =
+      execution_state_->Extend(req->graph_def(), &extended_execution_state);
+  if (!s.ok()) {
+    return s;
+  }
+
+  CHECK(extended_execution_state);
+  old_execution_state = std::move(execution_state_);
+  execution_state_.reset(extended_execution_state);
+  graph_version_ += 1;
+  resp->set_new_graph_version(graph_version_);
+  return s;
+}
+
+// Prepares a step for "req": fills "*opts" from the request, hashes it and
+// looks up the cached ReffedClientGraph for that hash, building and caching
+// a new one on a miss. On success "*rcg" holds the graph with one extra
+// reference for the caller and "*count" holds the number of times this
+// subgraph had been started before this call. Returns the BuildGraph()
+// error if a new graph cannot be built.
+Status MasterSession::StartStep(const RunStepRequest& req,
+ BuildGraphOptions* opts, int64* count,
+ ReffedClientGraph** rcg) {
+ BuildBuildGraphOptions(req, opts);
+ const uint64 hash = HashBuildGraphOptions(*opts);
+ // An obsolete graph replaced by a new one; unreffed after mu_ is released.
+ ReffedClientGraph* to_unref = nullptr;
+ {
+ mutex_lock l(mu_);
+ // Keep track of how many times this subgraph has been executed in
+ // this session.
+ int64* c = &subgraph_execution_counts_[hash];
+ *count = (*c)++;
+ auto iter = runs_.find(hash);
+ if (iter == runs_.end()) {
+ // We have not seen this subgraph before. Build the subgraph and
+ // cache it.
+ VLOG(1) << "Unseen hash " << hash << " for "
+ << BuildGraphOptionsString(*opts);
+ ClientGraph* client_graph = nullptr;
+ // Note: the execution count above has already been incremented if this
+ // returns early.
+ TF_RETURN_IF_ERROR(execution_state_->BuildGraph(*opts, &client_graph));
+ auto entry = new ReffedClientGraph(handle_, *opts, client_graph,
+ session_opts_.config.graph_options());
+ iter = runs_.insert({hash, entry}).first;
+ // Now that a replacement exists, any obsolete graph with the same hash
+ // can be released.
+ auto obs_iter = obsolete_.find(hash);
+ if (obs_iter != obsolete_.end()) {
+ to_unref = obs_iter->second;
+ obsolete_.erase(obs_iter);
+ }
+ VLOG(1) << "Preparing to execute new graph";
+ }
+ *rcg = iter->second;
+ // Reference handed to the caller; the cache keeps its own.
+ (*rcg)->Ref();
+ }
+ if (to_unref) to_unref->Unref();
+ return Status::OK();
+}
+
+// Empties "*rcg_map". When "to_unref" is non-null the graphs are appended
+// to it so the caller can unref them later; otherwise each graph is
+// unreffed right away.
+void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
+                                   RCGMap* rcg_map) {
+  VLOG(1) << "Discarding all reffed graphs";
+  for (const auto& entry : *rcg_map) {
+    ReffedClientGraph* graph = entry.second;
+    if (to_unref == nullptr) {
+      graph->Unref();
+      continue;
+    }
+    to_unref->push_back(graph);
+  }
+  rcg_map->clear();
+}
+
+// Runs one step. The step is counted in num_running_ for its whole
+// duration, and waiters on num_running_is_zero_ are woken when the last
+// running step finishes.
+Status MasterSession::Run(CallOptions* opts, const RunStepRequest* req,
+                          RunStepResponse* resp) {
+  UpdateLastAccessTime();
+  {
+    mutex_lock l(mu_);
+    num_running_ += 1;
+  }
+
+  // The step itself runs without holding mu_.
+  const Status step_status = DoRunWithLocalExecution(opts, req, resp);
+
+  mutex_lock l(mu_);
+  num_running_ -= 1;
+  if (num_running_ == 0) {
+    num_running_is_zero_.notify_all();
+  }
+  return step_status;
+}
+
+// Executes one step for "req":
+//   1. finds or builds the client graph for the request (StartStep),
+//   2. makes sure its partitions are registered on the workers,
+//   3. runs all partitions under a fresh random step id, and
+//   4. schedules an asynchronous per-step cleanup on the workers.
+//
+// Returns the first error from steps 1-3. Once the partitions have been
+// asked to run, the cleanup in step 4 is scheduled whether or not the run
+// succeeded.
+Status MasterSession::DoRunWithLocalExecution(CallOptions* opts,
+                                              const RunStepRequest* req,
+                                              RunStepResponse* resp) {
+  VLOG(2) << "DoRunWithLocalExecution "
+          << "req: " << req->DebugString();
+  PerStepState pss;
+  pss.start_micros = Env::Default()->NowMicros();
+
+  // Prepare.
+  BuildGraphOptions bgopts;
+  ReffedClientGraph* rcg = nullptr;
+  int64 count = 0;
+  TF_RETURN_IF_ERROR(StartStep(*req, &bgopts, &count, &rcg));
+
+  // Unref "rcg" when out of scope.
+  core::ScopedUnref unref(rcg);
+
+  // Registers subgraphs if haven't done so.
+  PartitionOptions popts;
+  popts.node_to_loc = SplitByWorker;
+  popts.new_name = [this](const string& prefix) {
+    mutex_lock l(mu_);
+    return strings::StrCat(prefix, "_S", next_node_id_++);
+  };
+  popts.get_incarnation = [this](const string& name) {
+    Device* d = devices_.FindDeviceByName(name);
+    if (d == nullptr) {
+      return PartitionOptions::kIllegalIncarnation;
+    } else {
+      return d->attributes().incarnation();
+    }
+  };
+  popts.control_flow_added = false;
+  // TODO(mrry): Enable DT_BFLOAT16 casting.
+  // TODO(mrry): Enable recv scheduling.
+  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(env_, popts, func_def_lib_));
+
+  // Keeps the highest 8 bits 0x01: we reserve some bits of the
+  // step_id for future use.
+  const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+  TRACEPRINTF("stepid %llu", step_id);
+
+  const Status run_status = rcg->RunPartitions(
+      env_, step_id, count, execution_state_.get(), &pss, opts, *req, resp);
+
+  pss.end_micros = Env::Default()->NowMicros();
+
+  // Schedule post-processing and cleanup to be done async. This happens
+  // even when the run failed: partitions that did start must still get a
+  // CleanupGraph call for this step id. The extra reference keeps "rcg"
+  // alive until the closure has run.
+  rcg->Ref();
+  // TODO(tucker): We're doing the stats processing prior to returning
+  // the response to the client.  Ensure it's safe to do so, then schedule
+  // in a closure.
+  SchedClosure([rcg, step_id]() {
+    Status s = rcg->CleanupPartitions(step_id);
+    if (!s.ok()) {
+      LOG(ERROR) << "Cleanup partition error: " << s;
+    }
+    rcg->Unref();
+  });
+
+  return run_status;
+}
+
+// Waits for all running steps to finish, releases every cached graph and
+// deletes the session. Always returns OK.
+Status MasterSession::Close() {
+  std::vector<ReffedClientGraph*> released;
+  {
+    mutex_lock l(mu_);
+    while (num_running_ != 0) {
+      num_running_is_zero_.wait(l);
+    }
+    // Collect the graphs under the lock; unref them after releasing it.
+    ClearRunsTable(&released, &runs_);
+    ClearRunsTable(&released, &obsolete_);
+  }
+  for (ReffedClientGraph* graph : released) {
+    graph->Unref();
+  }
+  delete this;
+  return Status::OK();
+}
+
+} // end namespace
+
+namespace internal {
+
+// Creates a heap-allocated MasterSession and returns it through its
+// interface type. The session's destructor is private; it is destroyed by
+// calling Close() on it.
+MasterSessionInterface* NewMasterSession(const SessionOptions& options,
+                                         const MasterEnv* env,
+                                         std::vector<Device*>* remote_devs) {
+  MasterSession* const session = new MasterSession(options, env, remote_devs);
+  return session;
+}
+
+} // end namespace internal
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
new file mode 100644
index 0000000000..dc24c5c671
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -0,0 +1,38 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
+
+#include <vector>
+
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+class Device;
+class MasterEnv;
+class MasterSessionInterface;
+
+namespace internal {
+
+MasterSessionInterface* NewMasterSession(const SessionOptions& options,  // Creates a new master session from the given options, environment and remote devices.
+ const MasterEnv* env,
+ std::vector<Device*>* remote_devs);  // NOTE(review): ownership of 'remote_devs' is not visible from this header -- confirm against MasterSession.
+
+} // namespace internal
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_H_
diff --git a/tensorflow/core/distributed_runtime/master_session_interface.h b/tensorflow/core/distributed_runtime/master_session_interface.h
new file mode 100644
index 0000000000..9d6516bfc5
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master_session_interface.h
@@ -0,0 +1,76 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+class ThreadPool;
+
+namespace tensorflow {
+
+class CallOptions;
+class GraphDef;
+class RunStepRequest;
+class RunStepResponse;
+class ExtendSessionRequest;
+class ExtendSessionResponse;
+
+// A "master session" encapsulates a distributed graph computation
+// (resource allocation, placement, execution, etc.).
+class MasterSessionInterface {
+ public:
+  // Virtual so that destroying an implementation through a
+  // MasterSessionInterface* is well defined. Clients are still expected to
+  // end a session with Close(), which deletes "*this".
+  virtual ~MasterSessionInterface() {}
+
+  // Initializes the Session with "def".  Must be called before Extend(),
+  // Run(), or Close().
+  //
+  // The callee may clear "def".
+  virtual Status Create(GraphDef* def) = 0;
+
+  // Returns the session handle.
+  virtual const string& handle() const = 0;
+
+  // Returns the last access time (the number of micro-seconds since
+  // some fixed point in time) of this session.
+  virtual uint64 last_access_time_usec() const = 0;
+
+  // Attempt to extend the graph according to the given "req".
+  // (See master.proto for details of valid extensions.)
+  //
+  // PRECONDITION: The current version of this session's graph
+  //   is "req->current_graph_version()".
+  //
+  // POSTCONDITION: The current version of this session's graph
+  //   is "resp->new_graph_version()".
+  //
+  // Extend() may block the caller thread for a long time.
+  virtual Status Extend(const ExtendSessionRequest* req,
+                        ExtendSessionResponse* resp) = 0;
+
+  // Run one step.
+  virtual Status Run(CallOptions* opts, const RunStepRequest* req,
+                     RunStepResponse* resp) = 0;
+
+  // Close this session and delete "*this". Returns OK if all known
+  // states are cleaned up successfully.
+  //
+  // Close() may block the caller thread for a long time. The session must
+  // not be used after Close() has been called.
+  virtual Status Close() = 0;
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_SESSION_INTERFACE_H_
diff --git a/tensorflow/core/distributed_runtime/master_test.cc b/tensorflow/core/distributed_runtime/master_test.cc
new file mode 100644
index 0000000000..a0a3708100
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/master_test.cc
@@ -0,0 +1,423 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/master.h"
+
+#include <map>
+#include <memory>
+
+#include "external/grpc/include/grpc++/grpc++.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
+
+namespace tensorflow {
+
+// Test fixture that brings up a test cluster via
+// test::TestCluster::MakeTestCluster(options, 2, ...) with one CPU and zero
+// GPU devices configured, and talks to the MasterService of the first
+// target through a gRPC stub.
+class MasterTest : public ::testing::Test {
+ protected:
+  MasterTest() {
+    SessionOptions options;
+    (*options.config.mutable_device_count())["CPU"] = 1;
+    (*options.config.mutable_device_count())["GPU"] = 0;
+    TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 2, &cluster_));
+    master_ = grpc::MasterService::NewStub(
+        NewHostPortGrpcChannel(cluster_->targets()[0]));
+  }
+
+  std::unique_ptr<test::TestCluster> cluster_;
+  std::unique_ptr<grpc::MasterService::Stub> master_;
+
+  // Helpers for MasterService.{CreateSession,ExtendSession,RunStep,
+  // CloseSession,Reset} rpc calls. Each helper issues one RPC with a fresh
+  // ClientContext and converts the gRPC status into a tensorflow::Status.
+
+  // Creates a session for "def". On success stores the session handle in
+  // "*handle" and the graph version in "*initial_version"; on failure both
+  // outputs are left untouched.
+  Status CreateSession(const GraphDef& def, string* handle,
+                       int64* initial_version) {
+    ::grpc::ClientContext ctx;
+    CreateSessionRequest req;
+    *(req.mutable_graph_def()) = def;
+    // Invokes placement frequently.
+    req.mutable_config()->set_placement_period(1);
+    CreateSessionResponse resp;
+    const Status s = FromGrpcStatus(master_->CreateSession(&ctx, req, &resp));
+    if (s.ok()) {
+      *handle = resp.session_handle();
+      *initial_version = resp.graph_version();
+    }
+    return s;
+  }
+
+  // Extends session "handle" with "def", stating that the graph is currently
+  // at "current_version". On success stores the resulting version in
+  // "*new_version"; on failure "*new_version" is left untouched.
+  Status ExtendSession(const string& handle, const GraphDef& def,
+                       int64 current_version, int64* new_version) {
+    ::grpc::ClientContext ctx;
+    ExtendSessionRequest req;
+    req.set_session_handle(handle);
+    *(req.mutable_graph_def()) = def;
+    req.set_current_graph_version(current_version);
+    ExtendSessionResponse resp;
+    const Status s = FromGrpcStatus(master_->ExtendSession(&ctx, req, &resp));
+    if (s.ok()) {
+      *new_version = resp.new_graph_version();
+    }
+    return s;
+  }
+
+  // Runs one step in session "handle". Every (name, tensor) pair in "feed"
+  // is sent as a feed and every key of "fetch" is requested as a fetch. On
+  // success each returned tensor is parsed into the Tensor that "fetch" maps
+  // its name to; a returned name that was not requested, or a tensor that
+  // fails to parse, is a CHECK failure.
+  Status RunStep(const string& handle,
+                 const std::vector<std::pair<string, const Tensor*> >& feed,
+                 const std::map<string, Tensor*>& fetch) {
+    ::grpc::ClientContext ctx;
+    RunStepRequest req;
+    req.set_session_handle(handle);
+    for (const auto& p : feed) {
+      const string& feed_name = p.first;
+      const Tensor* feed_tensor = p.second;
+      auto f = req.add_feed();
+      f->set_name(feed_name);
+      feed_tensor->AsProtoTensorContent(f->mutable_tensor());
+    }
+    for (const auto& p : fetch) {
+      const string& fetch_name = p.first;
+      req.add_fetch(fetch_name);
+    }
+    RunStepResponse resp;
+    const Status s = FromGrpcStatus(master_->RunStep(&ctx, req, &resp));
+    if (s.ok()) {
+      for (const auto& fetch_resp : resp.tensor()) {
+        auto it = fetch.find(fetch_resp.name());
+        CHECK(it != fetch.end());
+        CHECK(it->second->FromProto(fetch_resp.tensor()));
+      }
+    }
+    return s;
+  }
+
+  // Closes session "handle".
+  Status CloseSession(const string& handle) {
+    ::grpc::ClientContext ctx;
+    CloseSessionRequest req;
+    req.set_session_handle(handle);
+    CloseSessionResponse resp;
+    return FromGrpcStatus(master_->CloseSession(&ctx, req, &resp));
+  }
+
+  // Issues a Reset rpc with a default-constructed request.
+  Status Reset() {
+    ::grpc::ClientContext ctx;
+    ResetRequest req;
+    ResetResponse resp;
+    return FromGrpcStatus(master_->Reset(&ctx, req, &resp));
+  }
+};
+
+// A created session can be closed by its handle, while closing an unknown
+// handle is reported as Aborted.
+TEST_F(MasterTest, CreateClose) {
+  GraphDef def;  // Empty.
+  string handle;
+  int64 initial_version;
+  TF_ASSERT_OK(CreateSession(def, &handle, &initial_version));
+  EXPECT_TRUE(errors::IsAborted(CloseSession("randombits")));
+  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) so that a failure prints
+  // the offending Status.
+  TF_EXPECT_OK(CloseSession(handle));
+}
+
+// The master reports exactly one local device, of type "CPU".
+TEST_F(MasterTest, ListDevices) {
+  ::grpc::ClientContext ctx;
+  ListDevicesRequest req;
+  ListDevicesResponse resp;
+  // The response is only meaningful if the RPC succeeded, so stop on error.
+  const Status s = FromGrpcStatus(master_->ListDevices(&ctx, req, &resp));
+  TF_ASSERT_OK(s);
+  // Assert (not expect) the count so that local_device(0) below is never
+  // evaluated on an empty list.
+  ASSERT_EQ(1, resp.local_device_size());
+  EXPECT_EQ("CPU", resp.local_device(0).device_type());
+}
+
+// After Reset(), closing either of two previously created sessions is
+// reported as Aborted.
+TEST_F(MasterTest, Reset) {
+  GraphDef def;  // Empty.
+  string s1, s2;
+  int64 initial_version1, initial_version2;
+  TF_ASSERT_OK(CreateSession(def, &s1, &initial_version1));
+  TF_ASSERT_OK(CreateSession(def, &s2, &initial_version2));
+  // TF_EXPECT_OK (rather than EXPECT_TRUE(...ok())) so that a failure prints
+  // the offending Status.
+  TF_EXPECT_OK(Reset());
+  EXPECT_TRUE(errors::IsAborted(CloseSession(s1)));
+  EXPECT_TRUE(errors::IsAborted(CloseSession(s2)));
+}
+
+// Extends an initially empty session twice (first with constant "A", then
+// with constant "x"). Each successful extension must increase the graph
+// version and make the new node fetchable; extending an unknown handle must
+// be reported as Aborted.
+TEST_F(MasterTest, Extend) {
+  GraphDef empty_def;
+  string handle;
+  int64 version_0;
+  TF_ASSERT_OK(CreateSession(empty_def, &handle, &version_0));
+
+  Tensor want_a(DT_FLOAT, TensorShape({2, 2}));
+  test::FillValues<float>(&want_a, {3.0, 2.0, -1.0, 0.0});
+
+  Tensor want_x(DT_FLOAT, TensorShape({2, 1}));
+  test::FillValues<float>(&want_x, {2.0, 2.0});
+
+  // First extension: add "A" and fetch it back.
+  Graph graph_a(OpRegistry::Global());
+  test::graph::Constant(&graph_a, want_a, "A");
+  GraphDef def_a;
+  test::graph::ToGraphDef(&graph_a, &def_a);
+  int64 version_1;
+  TF_ASSERT_OK(ExtendSession(handle, def_a, version_0, &version_1));
+  EXPECT_GT(version_1, version_0);
+  Tensor got_a(DT_FLOAT, TensorShape({2, 2}));
+  TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &got_a}}));
+  test::ExpectTensorEqual<float>(got_a, want_a);
+
+  // Second extension: add "x". An unknown handle is rejected first.
+  Graph graph_x(OpRegistry::Global());
+  test::graph::Constant(&graph_x, want_x, "x");
+  GraphDef def_x;
+  test::graph::ToGraphDef(&graph_x, &def_x);
+  int64 version_2;
+  EXPECT_TRUE(errors::IsAborted(
+      ExtendSession("randombits", def_x, version_1, &version_2)));
+  TF_ASSERT_OK(ExtendSession(handle, def_x, version_1, &version_2));
+  EXPECT_GT(version_2, version_1);
+
+  // Both nodes are now fetchable in a single step.
+  Tensor got_x(DT_FLOAT, TensorShape({2, 1}));
+  TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &got_a}, {"x:0", &got_x}}));
+  test::ExpectTensorEqual<float>(got_a, want_a);
+  test::ExpectTensorEqual<float>(got_x, want_x);
+
+  TF_ASSERT_OK(CloseSession(handle));
+}
+
+// Extending a session a second time with the same variable-containing
+// GraphDef (at the up-to-date version) must fail with InvalidArgument.
+TEST_F(MasterTest, ExtendUpdateStatefulFails) {
+  GraphDef empty_def;
+  string handle;
+  int64 version_0;
+  TF_ASSERT_OK(CreateSession(empty_def, &handle, &version_0));
+
+  // A graph holding a single float variable of shape {512}.
+  Graph var_graph(OpRegistry::Global());
+  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
+  GraphDef var_def;
+  test::graph::ToGraphDef(&var_graph, &var_def);
+
+  int64 version_1;
+  int64 version_2;
+  TF_ASSERT_OK(ExtendSession(handle, var_def, version_0, &version_1));
+  EXPECT_GT(version_1, version_0);
+  const Status repeated = ExtendSession(handle, var_def, version_1, &version_2);
+  EXPECT_TRUE(errors::IsInvalidArgument(repeated));
+  TF_ASSERT_OK(CloseSession(handle));
+}
+
+// Extending twice from the same (by then stale) initial version: the first
+// call succeeds and the second is reported as Aborted.
+TEST_F(MasterTest, ExtendTwiceFails) {
+  GraphDef empty_def;
+  string handle;
+  int64 version_0;
+  TF_ASSERT_OK(CreateSession(empty_def, &handle, &version_0));
+
+  // A graph holding a single float variable of shape {512}.
+  Graph var_graph(OpRegistry::Global());
+  test::graph::Var(&var_graph, DT_FLOAT, TensorShape({512}));
+  GraphDef var_def;
+  test::graph::ToGraphDef(&var_graph, &var_def);
+
+  int64 version_1;
+  TF_ASSERT_OK(ExtendSession(handle, var_def, version_0, &version_1));
+  EXPECT_GT(version_1, version_0);
+  const Status stale = ExtendSession(handle, var_def, version_0, &version_1);
+  EXPECT_TRUE(errors::IsAborted(stale));
+  TF_ASSERT_OK(CloseSession(handle));
+}
+
+TEST_F(MasterTest, ConcurrentExtendOnlyOneSucceeds) {  // Races 100 identical Extend calls that all start from the same graph version.
+ GraphDef def_0; // Empty.
+ string handle;
+ int64 initial_version;
+ TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
+
+ Graph graph_1(OpRegistry::Global());
+ test::graph::Var(&graph_1, DT_FLOAT, TensorShape({512}));
+ GraphDef def_1;
+ test::graph::ToGraphDef(&graph_1, &def_1);
+
+ Notification n;  // Start gate: every closure blocks on it until Notify() below.
+ mutex mu;  // Guards 'succeeded' and 'failed'.
+ int succeeded = 0;
+ int failed = 0;
+ auto extend_fn = [this, handle, def_1, initial_version, &n, &mu, &succeeded,
+ &failed]() {
+ n.WaitForNotification();
+ int64 new_version;
+ Status s = ExtendSession(handle, def_1, initial_version, &new_version);
+ EXPECT_TRUE(s.ok() || errors::IsAborted(s));  // Losing the race must surface as Aborted, nothing else.
+ {
+ mutex_lock l(mu);
+ if (s.ok()) {
+ ++succeeded;
+ } else {
+ ++failed;
+ }
+ }
+ };
+
+ // Run 100 concurrent Extend calls and expect only one to succeed.
+ {
+ thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 100);
+ for (int i = 0; i < 100; ++i) {
+ thread_pool.Schedule(extend_fn);
+ }
+ n.Notify();  // Release all closures only after every one has been scheduled.
+ }  // NOTE(review): presumably the pool's destructor waits for all closures, since the counters are read unlocked below -- confirm.
+
+ EXPECT_EQ(failed, 99);
+ EXPECT_EQ(succeeded, 1);
+ TF_ASSERT_OK(CloseSession(handle));
+}
+
+// Runs steps concurrently with an Extend that adds constant "B" to a session
+// that initially only contains constant "A":
+//   * fetching "A" must succeed before, during and after the Extend;
+//   * fetching "A" and "B" must fail with NotFound before the Extend, may
+//     either fail with NotFound or succeed while it is in flight, and must
+//     succeed once it has completed.
+TEST_F(MasterTest, ConcurrentExtendAndRun) {
+  Graph graph_0(OpRegistry::Global());
+  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
+  test::graph::Constant(&graph_0, a_tensor, "A");
+  GraphDef def_0;
+  test::graph::ToGraphDef(&graph_0, &def_0);
+
+  string handle;
+  int64 initial_version;
+  TF_ASSERT_OK(CreateSession(def_0, &handle, &initial_version));
+
+  Graph graph_1(OpRegistry::Global());
+  Tensor b_tensor(DT_FLOAT, TensorShape({2, 2}));
+  test::FillValues<float>(&b_tensor, {1, 0, 0, 1});
+  test::graph::Constant(&graph_1, b_tensor, "B");
+  GraphDef def_1;
+  test::graph::ToGraphDef(&graph_1, &def_1);
+
+  // extend_can_start: set once "A"+"B" has been observed to fail pre-Extend.
+  // extend_done: set once the Extend call has returned.
+  Notification extend_done;
+  Notification extend_can_start;
+
+  auto get_a_fn = [this, handle, &extend_done]() {
+    Tensor A(DT_FLOAT, TensorShape({2, 2}));
+    while (!extend_done.HasBeenNotified()) {
+      TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
+    }
+    // Run at least once after the Extend has completed.
+    TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}}));
+  };
+
+  auto get_a_and_b_fn = [this, handle, &extend_done, &extend_can_start]() {
+    Tensor A(DT_FLOAT, TensorShape({2, 2}));
+    Tensor B(DT_FLOAT, TensorShape({2, 2}));
+
+    // Run at least once before the Extend has completed.
+    EXPECT_TRUE(
+        errors::IsNotFound(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}})));
+    extend_can_start.Notify();
+
+    // Concurrent with the Extend, we will either fail (as above), or
+    // succeed (as below).
+    while (!extend_done.HasBeenNotified()) {
+      Status s = RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}});
+      EXPECT_TRUE(errors::IsNotFound(s) || s.ok());
+    }
+
+    // Run at least once after the Extend has completed.
+    TF_ASSERT_OK(RunStep(handle, {}, {{"A:0", &A}, {"B:0", &B}}));
+  };
+
+  auto extend_fn = [this, handle, def_1, initial_version, &extend_done,
+                    &extend_can_start]() {
+    extend_can_start.WaitForNotification();
+    int64 version_1;
+    // Deliberately a non-fatal expectation: a fatal TF_ASSERT_OK would return
+    // from this closure before extend_done.Notify(), leaving the two polling
+    // closures above spinning forever and turning a failure into a hang.
+    TF_EXPECT_OK(ExtendSession(handle, def_1, initial_version, &version_1));
+    extend_done.Notify();
+  };
+
+  {
+    thread::ThreadPool thread_pool(Env::Default(), "extend_pool", 3);
+    thread_pool.Schedule(get_a_fn);
+    thread_pool.Schedule(get_a_and_b_fn);
+    thread_pool.Schedule(extend_fn);
+  }
+
+  TF_ASSERT_OK(CloseSession(handle));
+}
+
+TEST_F(MasterTest, EigenProblem) {
+ // A = [3 2; -1 0]; x = rand(2, 1);
+ // repeat x = normalize(A * x) until x stops changing (see the loop below).
+ // We'll try to compute the largest eigenvalue for A.
+ Graph graph(OpRegistry::Global());
+ Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+ // Store rows [3, 2] and [-1, 0] in row major format.
+ test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
+ Node* a_node = test::graph::Constant(&graph, a_tensor);
+
+ // x is a constant node in the graph; every RunStep below feeds a value under x_node's name.
+ Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+ test::FillValues<float>(&x_tensor, {0, 0});
+ Node* x_node = test::graph::Constant(&graph, x_tensor);
+
+ // y = A * x
+ Node* y_node = test::graph::Matmul(&graph, a_node, x_node, false, false);
+
+ GraphDef def;
+ test::graph::ToGraphDef(&graph, &def);
+
+ string handle;
+ int64 initial_version;
+ TF_CHECK_OK(CreateSession(def, &handle, &initial_version));
+
+ // Temps supporting the computation of the convergence condition.
+ const Eigen::array<Eigen::DenseIndex, 1> sum_along_dim(0);
+ const Eigen::array<Eigen::DenseIndex, 2> matrix_transpose({1, 0});
+ Tensor x(DT_FLOAT, TensorShape({2, 1}));
+ Tensor y(DT_FLOAT, TensorShape({2, 1}));
+ Eigen::Tensor<float, 1, Eigen::RowMajor> y_square_sum;
+ Eigen::Tensor<float, 2, Eigen::RowMajor> y_normalized(2, 1);
+ y_normalized.setRandom();  // Random starting vector for the iteration.
+ Eigen::Tensor<float, 1, Eigen::RowMajor> error_square_sum;
+ float lambda;  // Assigned on every loop iteration; the loop body runs at least once.
+
+ // The computation loop.
+ bool converged = false;
+ while (!converged) {
+ // Run one step of the graph.
+ auto x_matrix = x.matrix<float>();
+ x_matrix = y_normalized;  // Writes the current estimate into tensor x before it is fed.
+ TF_EXPECT_OK(
+ RunStep(handle, {{x_node->name(), &x}}, {{y_node->name() + ":0", &y}}));
+ auto y_matrix = y.matrix<float>();
+
+ // Client code computes the convergence condition.
+ {
+ lambda = y_matrix(0, 0) / x_matrix(0, 0);  // Eigenvalue estimate: ratio of the first components of y and x.
+ y_square_sum = y.matrix<float>().square().sum(sum_along_dim);
+ const float norm = static_cast<float>(sqrt(y_square_sum(0)));
+ y_normalized = y_matrix * (1 / norm);
+ error_square_sum = (x_matrix - y_normalized).square().sum(sum_along_dim);
+ VLOG(1) << "x = [" << x_matrix.shuffle(matrix_transpose) << "] y = ["
+ << y_matrix.shuffle(matrix_transpose) << "] lambda = " << lambda;
+ converged = sqrt(error_square_sum(0)) < 1e-10;  // Stop once x and normalized y differ by less than 1e-10 in norm.
+ }
+ }
+ EXPECT_NEAR(lambda, 2.0, 0.01);
+ TF_EXPECT_OK(CloseSession(handle));
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/process_util.cc b/tensorflow/core/distributed_runtime/process_util.cc
new file mode 100644
index 0000000000..8f97382cf8
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/process_util.cc
@@ -0,0 +1,69 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/process_util.h"
+
+#include <string.h>
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/host_info.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Creates the thread pool named "Compute". Its size is
+// options.config.inter_op_parallelism_threads(), or the number of schedulable
+// CPUs reported by port::NumSchedulableCPUs() when that setting is 0.
+// The returned pool is heap-allocated.
+static thread::ThreadPool* InitComputePool(const SessionOptions& options) {
+  const int32 configured_threads =
+      options.config.inter_op_parallelism_threads();
+  // A value of 0 selects the default: one thread per schedulable CPU.
+  const int32 pool_size = (configured_threads != 0)
+                              ? configured_threads
+                              : port::NumSchedulableCPUs();
+  return new thread::ThreadPool(Env::Default(), "Compute", pool_size);
+}
+
+} // namespace
+
+thread::ThreadPool* ComputePool(const SessionOptions& options) {  // Returns the single process-wide compute pool.
+ static thread::ThreadPool* compute_pool = InitComputePool(options);  // Function-local static: built on the first call only, so 'options' from later calls is ignored.
+ return compute_pool;  // Never deleted here; callers must not delete it either.
+}
+
+// Hands "closure" to Env::Default()->SchedClosure(). When tracing is active,
+// a kScheduleClosure event is recorded under a fresh unique id and the
+// closure is wrapped so that its execution happens inside a kRunClosure
+// ScopedActivity carrying the same id.
+void SchedClosure(std::function<void()> closure) {
+  if (!port::Tracing::IsActive()) {
+    // Fast path: no tracing, schedule the closure as-is.
+    Env::Default()->SchedClosure(closure);
+    return;
+  }
+
+  const uint64 trace_id = port::Tracing::UniqueId();
+  port::Tracing::RecordEvent(port::Tracing::EventCategory::kScheduleClosure,
+                             trace_id);
+  std::function<void()> traced_closure = [closure, trace_id]() {
+    port::Tracing::ScopedActivity region(
+        port::Tracing::EventCategory::kRunClosure, trace_id);
+    closure();
+  };
+  Env::Default()->SchedClosure(traced_closure);
+}
+
+void SchedNonBlockingClosureAfter(int micros, std::function<void()> closure) {  // Schedules 'closure' to run after a delay of 'micros' microseconds.
+ Env::Default()->SchedClosureAfter(micros, closure);  // Pure forwarder; unlike SchedClosure() above, no tracing events are recorded.
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/process_util.h b/tensorflow/core/distributed_runtime/process_util.h
new file mode 100644
index 0000000000..fb20e88b1e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/process_util.h
@@ -0,0 +1,39 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
+
+#include <functional>
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// Returns a process-wide ThreadPool for scheduling compute operations.
+// The pool is created from 'options' on the first call; 'options' passed to
+// later calls is ignored. Caller does not take ownership of the threadpool.
+thread::ThreadPool* ComputePool(const SessionOptions& options);
+
+// Schedule "closure" in the default thread queue.
+void SchedClosure(std::function<void()> closure);
+
+// Schedule "closure" after the given number of microseconds in the
+// fixed-size ThreadPool used for non-blocking compute tasks.
+void SchedNonBlockingClosureAfter(int micros, std::function<void()> closure);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_PROCESS_UTIL_H_
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
new file mode 100644
index 0000000000..387b9e4492
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -0,0 +1,91 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/remote_device.h"
+
+#include <vector>
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+using std::placeholders::_1;
+
+// TODO(zhifengc): We need to consolidate (full/partial) device name
+// parsing into one place.
+//
+// Parses and returns the local device part (e.g., cpu:0, gpu:4).
+string GetLocalDeviceName(StringPiece fullname) {
+  // Locate the final '/'; a name without any '/' is a fatal (CHECK) error.
+  const auto last_slash = fullname.rfind('/');
+  CHECK_NE(last_slash, StringPiece::npos);
+  // Drop everything up to and including that '/', keeping only the tail.
+  fullname.remove_prefix(last_slash + 1);
+  return fullname.ToString();
+}
+
+class RemoteDevice : public Device {  // Device object built from the DeviceAttributes a remote worker reported.
+ public:
+ RemoteDevice(Env* env, const DeviceAttributes& da, WorkerInterface* wi)  // Takes ownership of 'wi' (deleted in the destructor).
+ : Device(env, da, nullptr),
+ local_dev_name_(GetLocalDeviceName(da.name())),  // The part of da.name() after the last '/'.
+ wi_(wi) {}
+
+ ~RemoteDevice() override { delete wi_; }
+ Status Sync() override { return Status::OK(); }  // Does no work; always returns OK.
+ Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }  // This class provides no allocator.
+
+ private:
+ const string local_dev_name_;
+ WorkerInterface* wi_;  // Owned.
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RemoteDevice);
+};
+
+void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,  // Issues an asynchronous GetStatus to 'worker_name' and reports one RemoteDevice per returned DeviceAttributes via 'done'.
+ const string& worker_name, NewRemoteDevicesDone done) {
+ WorkerInterface* wi = worker_cache->CreateWorker(worker_name);
+ if (wi == nullptr) {  // No worker could be created: report NotFound with an empty device list, synchronously.
+ std::vector<Device*> empty;
+ done(errors::NotFound("Device ", worker_name, " is not found."), &empty);
+ return;
+ }
+ struct Call {  // Keeps the request/response protos alive for the duration of the async call.
+ GetStatusRequest req;
+ GetStatusResponse resp;
+ };
+ Call* call = new Call;  // Freed at the end of 'cb'.
+ auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
+ std::vector<Device*> remote_devices;  // Left empty when 's' is an error.
+ if (s.ok()) {
+ remote_devices.reserve(call->resp.device_attributes_size());
+ for (const DeviceAttributes& da : call->resp.device_attributes()) {
+ auto d =
+ new RemoteDevice(env, da, worker_cache->CreateWorker(worker_name));  // Each device gets its own worker handle, which the RemoteDevice then owns.
+ remote_devices.push_back(d);
+ }
+ }
+ done(s, &remote_devices);  // 'remote_devices' is a local, so 'done' must copy out any pointers it wants to keep.
+ delete wi;  // The probing worker handle is only needed for this one GetStatus call.
+ delete call;
+ };
+ wi->GetStatusAsync(&call->req, &call->resp, cb);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/remote_device.h b/tensorflow/core/distributed_runtime/remote_device.h
new file mode 100644
index 0000000000..aeefeda048
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/remote_device.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+class Device;
+class Env;
+class WorkerCacheInterface;
+
+// NewRemoteDevices discovers available devices on the
+// 'remote_worker'.  The implementation uses 'worker_cache' to
+// discover how to communicate with the 'remote_worker' (via gRPC, for
+// example).
+//
+// NewRemoteDevices does not block.
+//
+// On success, the 'done' callback is given the OK status and a vector
+// of Device*. The caller should take ownership of these devices.
+//
+// Otherwise, the 'done' callback is given an error status and the
+// vector is empty.
+typedef std::function<void(const Status&, std::vector<Device*>*)>
+ NewRemoteDevicesDone;
+void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
+ const string& remote_worker, NewRemoteDevicesDone done);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_REMOTE_DEVICE_H_
diff --git a/tensorflow/core/distributed_runtime/remote_device_test.cc b/tensorflow/core/distributed_runtime/remote_device_test.cc
new file mode 100644
index 0000000000..c575a76471
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/remote_device_test.cc
@@ -0,0 +1,89 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/remote_device.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+const char* const kSession = "remote_session";  // NOTE(review): not referenced anywhere else in this test file -- candidate for removal.
+
+// Test fixture that starts a test cluster via
+// test::TestCluster::MakeTestCluster(options, 1, ...) with two CPU devices
+// configured, builds a gRPC worker cache for its first target, and (in
+// SetUp) discovers that worker's devices through NewRemoteDevices().
+class RemoteDeviceTest : public ::testing::Test {
+ protected:
+  string remote_name_;
+  std::unique_ptr<WorkerCacheInterface> worker_cache_;
+  std::unique_ptr<WorkerInterface> wi_;
+  // Devices handed over by NewRemoteDevices(); deleted in TearDown().
+  std::vector<Device*> devices_;
+  std::unique_ptr<test::TestCluster> cluster_;
+
+  RemoteDeviceTest() {
+    SessionOptions options;
+    (*options.config.mutable_device_count())["CPU"] = 2;
+    TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 1, &cluster_));
+    const string& hostport = cluster_->targets()[0];
+    string host;
+    int port;
+    // Split "host:port"; only the host part is used below, as the job name.
+    CHECK(RE2::FullMatch(hostport, "(.+):(\\d+)", &host, &port));
+    GrpcChannelSpec spec;
+    spec.AddHostPortsJob("localhost", {hostport}, 1);
+    worker_cache_.reset(NewGrpcWorkerCache(NewGrpcChannelCache(spec)));
+    remote_name_ = strings::StrCat("/job:", host, "/replica:0/task:0");
+    wi_.reset(worker_cache_->CreateWorker(remote_name_));
+  }
+
+  void SetUp() override {
+    Notification n;
+    NewRemoteDevices(Env::Default(), worker_cache_.get(), remote_name_,
+                     [&n, this](const Status& s, std::vector<Device*>* found) {
+                       TF_CHECK_OK(s);
+                       devices_ = *found;
+                       n.Notify();
+                     });
+    n.WaitForNotification();
+    // Fatal on mismatch: the tests index devices_[0] and devices_[1], and
+    // googletest skips the test body after a fatal failure in SetUp().
+    ASSERT_EQ(devices_.size(), 2u);
+    // Sort by name so that tests can rely on a stable device order.
+    std::sort(devices_.begin(), devices_.end(), [](Device* a, Device* b) {
+      return a->name().compare(b->name()) < 0;
+    });
+  }
+
+  void TearDown() override {
+    for (auto d : devices_) delete d;
+  }
+};
+
+// Checks the name, type and memory limit of the two devices that SetUp()
+// discovered (sorted by name) against what the testlib's fake server reports.
+TEST_F(RemoteDeviceTest, GetStatus) {
+  Device* const first = devices_[0];
+  Device* const second = devices_[1];
+
+  EXPECT_EQ(first->name(), strings::StrCat(remote_name_, "/cpu:0"));
+  EXPECT_EQ(first->attributes().device_type(), DeviceType(DEVICE_CPU).type());
+  EXPECT_EQ(first->attributes().memory_limit(), 256 << 20);
+
+  EXPECT_EQ(second->name(), strings::StrCat(remote_name_, "/cpu:1"));
+  EXPECT_EQ(second->attributes().memory_limit(), 256 << 20);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
new file mode 100644
index 0000000000..6a71bb04b4
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h
@@ -0,0 +1,79 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
+
+#include <string>
+
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received. Each globally unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+// fork execution of a graph executor using "rendez" on thread 1;
+// fork execution of another graph executor using "rendez" on thread 2;
+// ...
+// join threads 1 and 2;
+//
+// In the example above, the executions in threads 1 and 2 communicate
+// with each other by send/recv operations through "rendez".
+//
+// Tensors sent and received through a rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
+class RendezvousMgrInterface {
+ public:
+ RendezvousMgrInterface() {}
+ virtual ~RendezvousMgrInterface() {}
+
+ // Returns Rendezvous supporting send and recv among workers in the
+ // "step_id". The caller takes ownership of one reference on the
+ // returned Rendezvous instance.
+ virtual Rendezvous* Find(int64 step_id) = 0;
+
+ // Finds the local rendezvous instance for the "step_id". Runs
+ // "done" when the tensor for "key" is produced or an error occurs.
+ //
+ // This method is used by the rpc handler of RecvTensor.
+ virtual void RecvLocalAsync(int64 step_id, const string& key,
+ Rendezvous::DoneCallback done) = 0;
+
+ // Synchronous wrapper for RecvLocalAsync.
+ //
+ // NOTE(review): presumably "*val" receives the tensor and "*is_dead"
+ // its dead bit -- confirm against the implementation.
+ virtual Status RecvLocal(int64 step_id, const string& key, Tensor* val,
+ bool* is_dead) = 0;
+
+ // Removes rendezvous for "step_id".
+ //
+ // TODO(zhifengc): Have a background thread in worker that
+ // periodically calls CleanupAll().
+ virtual void Cleanup(int64 step_id) = 0;
+
+ // Removes all rendezvous.
+ virtual void CleanupAll() = 0;
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RENDEZVOUS_MGR_INTERFACE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
new file mode 100644
index 0000000000..3166c94259
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -0,0 +1,341 @@
+# Description:
+# RPC communication interfaces and implementations for TensorFlow.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+filegroup(
+ name = "c_srcs",
+ data = glob([
+ "**/*.cc",
+ "**/*.h",
+ ]),
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cuda_library",
+ "tf_cc_tests",
+)
+
+# For platform specific build config
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_kernel_tests_linkstatic",
+)
+load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+
+package(default_visibility = [
+ "//tensorflow:internal",
+])
+
+cc_library(
+ name = "grpc_util",
+ srcs = [],
+ hdrs = ["grpc_util.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "grpc_client_cq_tag",
+ srcs = [],
+ hdrs = ["grpc_client_cq_tag.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":grpc_util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "grpc_remote_worker",
+ srcs = ["grpc_remote_worker.cc"],
+ hdrs = ["grpc_remote_worker.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":grpc_client_cq_tag",
+ ":grpc_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:worker_proto_cc",
+ "//tensorflow/core:worker_service_proto_cc",
+ "//tensorflow/core/distributed_runtime:process_util",
+ "//tensorflow/core/distributed_runtime:worker_cache_logger",
+ "//tensorflow/core/distributed_runtime:worker_interface",
+ ],
+)
+
+cc_library(
+ name = "grpc_channel",
+ srcs = ["grpc_channel.cc"],
+ hdrs = ["grpc_channel.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":grpc_util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "grpc_call",
+ srcs = [],
+ hdrs = ["grpc_call.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "async_service_interface",
+ srcs = [],
+ hdrs = ["async_service_interface.h"],
+ deps = [],
+)
+
+cc_library(
+ name = "grpc_worker_cache",
+ srcs = ["grpc_worker_cache.cc"],
+ hdrs = ["grpc_worker_cache.h"],
+ deps = [
+ ":grpc_channel",
+ ":grpc_client_cq_tag",
+ ":grpc_remote_worker",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:worker_cache",
+ "//tensorflow/core/distributed_runtime:worker_cache_logger",
+ "//tensorflow/core/distributed_runtime:worker_cache_partial",
+ "//tensorflow/core/distributed_runtime:worker_interface",
+ ],
+)
+
+cc_library(
+ name = "grpc_worker_service",
+ srcs = ["grpc_worker_service.cc"],
+ hdrs = ["grpc_worker_service.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":async_service_interface",
+ ":grpc_call",
+ ":grpc_util",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_runtime",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:worker_proto_cc",
+ "//tensorflow/core:worker_service_proto_cc",
+ "//tensorflow/core/distributed_runtime:graph_mgr",
+ "//tensorflow/core/distributed_runtime:process_util",
+ "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
+ "//tensorflow/core/distributed_runtime:worker_cache",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime:worker_interface",
+ ],
+)
+
+cc_library(
+ name = "grpc_remote_master",
+ srcs = ["grpc_remote_master.cc"],
+ hdrs = ["grpc_remote_master.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":grpc_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:master_service_proto_cc",
+ "//tensorflow/core/distributed_runtime:master_interface",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "grpc_master_service",
+ srcs = ["grpc_master_service.cc"],
+ hdrs = ["grpc_master_service.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":async_service_interface",
+ ":grpc_call",
+ ":grpc_util",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:master_service_proto_cc",
+ "//tensorflow/core/distributed_runtime:master",
+ "//tensorflow/core/distributed_runtime:master_interface",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "rpc_rendezvous_mgr",
+ srcs = ["rpc_rendezvous_mgr.cc"],
+ hdrs = ["rpc_rendezvous_mgr.h"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
+ "//tensorflow/core/distributed_runtime:process_util",
+ "//tensorflow/core/distributed_runtime:worker_cache",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ "//tensorflow/core/distributed_runtime:worker_interface",
+ ],
+)
+
+cc_library(
+ name = "grpc_server_lib",
+ srcs = [
+ "grpc_server_lib.cc",
+ ],
+ hdrs = ["grpc_server_lib.h"],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":async_service_interface",
+ ":grpc_channel",
+ ":grpc_master_service",
+ ":grpc_worker_cache",
+ ":grpc_worker_service",
+ ":rpc_rendezvous_mgr",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/distributed_runtime:graph_mgr",
+ "//tensorflow/core/distributed_runtime:master_env",
+ "//tensorflow/core/distributed_runtime:master_session",
+ "//tensorflow/core/distributed_runtime:process_util",
+ "//tensorflow/core/distributed_runtime:worker_env",
+ ],
+)
+
+cc_binary(
+ name = "grpc_tensorflow_server",
+ srcs = [
+ "grpc_tensorflow_server.cc",
+ ],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":grpc_server_lib",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cuda_library(
+ name = "grpc_testlib_ops",
+ testonly = 1,
+ srcs = ["grpc_testlib_ops.cc"],
+ linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
+ deps = [
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+cc_binary(
+ name = "grpc_testlib_server",
+ testonly = 1,
+ srcs = [
+ "grpc_testlib_server.cc",
+ ],
+ deps = [
+ "@grpc//:grpc++_unsecure",
+ ":grpc_server_lib",
+ ":grpc_testlib_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cuda_library(
+ name = "grpc_testlib",
+ testonly = 1,
+ srcs = ["grpc_testlib.cc"],
+ hdrs = ["grpc_testlib.h"],
+ data = [
+ ":grpc_testlib_server",
+ ],
+ deps = [
+ ":grpc_session",
+ ":grpc_testlib_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:test",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "grpc_session",
+ srcs = ["grpc_session.cc"],
+ hdrs = ["grpc_session.h"],
+ deps = [
+ ":grpc_channel",
+ ":grpc_remote_master",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/distributed_runtime:master_interface",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_tests(
+ linkstatic = tf_kernel_tests_linkstatic(),
+ tags = tf_cuda_tests_tags(),
+ tests = [
+ "grpc_channel_test.cc",
+ "grpc_session_test.cc",
+ "rpc_rendezvous_mgr_test.cc",
+ ],
+ deps = [
+ ":grpc_channel",
+ ":grpc_session",
+ ":grpc_testlib",
+ ":rpc_rendezvous_mgr",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:master_proto_cc",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/distributed_runtime:process_util",
+ ],
+)
diff --git a/tensorflow/core/distributed_runtime/rpc/async_service_interface.h b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h
new file mode 100644
index 0000000000..2f453b048e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/async_service_interface.h
@@ -0,0 +1,37 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
+
+namespace tensorflow {
+
+// Represents an abstract asynchronous service that handles incoming
+// RPCs with a polling loop.
+//
+// The interface has a single entry point, HandleRPCsLoop(); it does not
+// (yet) provide a way to request shutdown.
+class AsyncServiceInterface {
+ public:
+ virtual ~AsyncServiceInterface() {}
+
+ // A blocking method that should be called to handle incoming RPCs.
+ // This method will block until the service is shut down; when that
+ // happens depends on the implementation of the service.
+ virtual void HandleRPCsLoop() = 0;
+
+ // TODO(mrry): Add a clean shutdown method?
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_ASYNC_SERVICE_INTERFACE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_call.h b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
new file mode 100644
index 0000000000..11f139ca03
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_call.h
@@ -0,0 +1,227 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
+
+#include <functional>
+#include <utility>
+
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+
+#include "external/grpc/include/grpc++/grpc++.h"
+#include "external/grpc/include/grpc++/server_builder.h"
+
+namespace tensorflow {
+
+// CALL STRUCTURES
+// ===============
+//
+// Each pending (incoming) request corresponds to a call object that
+// encapsulates the state of the call. Templates and
+// pointers-to-member functions are used to avoid boilerplate and
+// redundant closure creation. The class hierarchy is as follows:
+//
+// * `UntypedCall<Service>`: The base class represents a call that
+// could be associated with any of the methods on a service of type
+// `Service`. Also defines a `Tag` nested class that can be used as
+// the tag in a `grpc::CompletionQueue`. Each class that
+// instantiates `Service` should have a completion queue polling
+// loop that knows about `UntypedCall<Service>::Tag` objects, and
+// invokes their `OnCompleted()` method to continue processing.
+//
+// * `Call<Service, GrpcService, Req, Resp>`: This class extends
+// `UntypedCall<Service>` and is additionally parameterized by the
+// gRPC-generated asynchronous service class, and the request and
+// response message types. It defines the state associated with a
+// call (whose type depends on the message types), and stores a
+// pointer to a `Service::HandleFoo()` handler method. Each
+// `Service::HandleFoo()` method knows about the corresponding
+// `Call` type, in order to access its state, and invoke its
+// `SendResponse()` method.
+//
+// The lifecycle of a call object is as follows.
+//
+// 1. A `Service` creates a `Call` for a particular method and
+// enqueues it in its completion queue (via an
+// `UntypedCall<Service>::Tag`).
+//
+// 2. When the tag is returned from `cq_->Next()`, the
+// `UntypedCall::RequestReceived()` method is invoked and takes
+// ownership of the call object. This indirectly invokes the
+// appropriate handler method on `Service`.
+//
+// 3. After the response has been written (perhaps in another thread),
+// the `Call::SendResponse()` method is invoked. It transfers
+// ownership of the call object back to the completion queue (via
+// an `UntypedCall::Tag`).
+//
+// 4. When the response has been sent, the tag is returned from
+// `cq_->Next()`, and the call object is deleted.
+
+// Represents a pending request with unknown message types.
+//
+// The object is reference counted (core::RefCounted). Every `Tag` created
+// for a call holds one reference for as long as the tag exists.
+template <class Service>
+class UntypedCall : public core::RefCounted {
+ public:
+ virtual ~UntypedCall() {}
+
+ // The implementation of this method should use `service` to handle
+ // an incoming request, and (perhaps asynchronously) send the
+ // response.
+ //
+ // One reference on `this` is transferred to the callee, and the
+ // callee is responsible for releasing it (typically via
+ // `Call::SendResponse()`).
+ //
+ // `ok` is true if the request was received in a "regular event",
+ // otherwise false.
+ virtual void RequestReceived(Service* service, bool ok) = 0;
+
+ // This method will be called when the response has been sent by
+ // `service` and the call is no longer used.
+ //
+ // `ok` is true if the response sending completed as a "regular
+ // event", otherwise it is false.
+ //
+ // Deliberately a non-virtual no-op: nothing is done here beyond being
+ // the callback target of the response tag.
+ void ResponseSent(Service* service, bool ok) {}
+
+ // This method will be called either (i) when the server is notified
+ // that the request has been cancelled, or (ii) when the request completes
+ // normally. The implementation should distinguish these cases by querying
+ // the `grpc::ServerContext` associated with the request.
+ virtual void RequestCancelled(Service* service, bool ok) = 0;
+
+ // Associates a tag in a `::grpc::CompletionQueue` with a callback
+ // for an incoming RPC. A Tag owns a reference on the corresponding
+ // Call object.
+ class Tag {
+ public:
+ // Pointer to the `UntypedCall` member function to run on completion.
+ using Callback = void (UntypedCall::*)(Service*, bool);
+
+ // Creates a new `Tag` for the given `UntypedCall`. When the
+ // request associated with this tag is complete, `callback` will
+ // be called.
+ //
+ // Takes a new reference on `call`, released in the destructor.
+ Tag(UntypedCall* call, Callback callback)
+ : call_(call), callback_(callback) {
+ call_->Ref();
+ }
+
+ ~Tag() { call_->Unref(); }
+
+ // Calls the callback associated with this tag.
+ //
+ // The callback takes ownership of `this->call_`.
+ void OnCompleted(Service* service, bool ok) {
+ (call_->*callback_)(service, ok);
+ }
+
+ private:
+ UntypedCall* call_; // `this` owns one reference.
+ Callback callback_;
+ };
+};
+
+// Represents a pending call with known request and response message
+// types, and a known request-handling method.
+//
+// Reference flow, as implemented below:
+//   * EnqueueRequest() creates the call, creates two tags (each of which
+//     takes its own reference) and then drops the creator's reference.
+//   * RequestReceived() takes an extra reference on behalf of the handler.
+//   * SendResponse() creates the response tag and releases the handler's
+//     reference.
+template <class Service, class GrpcService, class RequestMessage,
+          class ResponseMessage>
+class Call : public UntypedCall<Service> {
+ public:
+  // Represents the generic signature of a generated
+  // `GrpcService::RequestFoo()` method, where `Foo` is the name of an
+  // RPC method.
+  using EnqueueFunction = void (GrpcService::*)(
+      ::grpc::ServerContext*, RequestMessage*,
+      ::grpc::ServerAsyncResponseWriter<ResponseMessage>*,
+      ::grpc::CompletionQueue*, ::grpc::ServerCompletionQueue*, void*);
+
+  // Represents the generic signature of a `Service::HandleFoo()`
+  // method, where `Foo` is the name of an RPC method.
+  using HandleRequestFunction = void (Service::*)(
+      Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
+
+  // `explicit` prevents a member-function pointer from being implicitly
+  // converted into a Call.
+  explicit Call(HandleRequestFunction handle_request_function)
+      : handle_request_function_(handle_request_function), responder_(&ctx_) {}
+
+  ~Call() override {}
+
+  // Dispatches the request to the handler method on `service`. When `ok`
+  // is false the handler is not invoked and no extra reference is taken.
+  void RequestReceived(Service* service, bool ok) override {
+    if (ok) {
+      // This reference belongs to the handler; SendResponse() releases it.
+      this->Ref();
+      (service->*handle_request_function_)(this);
+    }
+  }
+
+  // Sends `response` with the given `status` to the client and releases
+  // the reference taken in RequestReceived(). The newly created tag holds
+  // its own reference, so the call stays alive until that tag is destroyed.
+  void SendResponse(::grpc::Status status) {
+    responder_.Finish(response, status,
+                      new typename UntypedCall<Service>::Tag(
+                          this, &UntypedCall<Service>::ResponseSent));
+    this->Unref();
+  }
+
+  // Runs the registered cancellation callback, if any, but only when the
+  // server context reports that the call was actually cancelled. The
+  // callback is invoked while `mu_` is held.
+  void RequestCancelled(Service* service, bool ok) override {
+    if (ctx_.IsCancelled()) {
+      mutex_lock l(mu_);
+      if (cancel_callback_) {
+        cancel_callback_();
+      }
+    }
+  }
+
+  // Registers `callback` as the function that should be called if and when this
+  // call is cancelled by the client.
+  void SetCancelCallback(std::function<void()> callback) {
+    mutex_lock l(mu_);
+    // `callback` is already a by-value copy; move it instead of copying again.
+    cancel_callback_ = std::move(callback);
+  }
+
+  // Clears any cancellation callback that has been registered for this call.
+  void ClearCancelCallback() {
+    mutex_lock l(mu_);
+    cancel_callback_ = nullptr;
+  }
+
+  // Enqueues a new request for the given service on the given
+  // completion queue, using the given `enqueue_function`.
+  //
+  // The request will be handled with the given
+  // `handle_request_function`.
+  static void EnqueueRequest(GrpcService* grpc_service,
+                             ::grpc::ServerCompletionQueue* cq,
+                             EnqueueFunction enqueue_function,
+                             HandleRequestFunction handle_request_function) {
+    auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
+        handle_request_function);
+
+    // Tag for the "done / cancelled" notification of this call.
+    call->ctx_.AsyncNotifyWhenDone(new typename UntypedCall<Service>::Tag(
+        call, &UntypedCall<Service>::RequestCancelled));
+
+    // Tag for the arrival of the request itself; `cq` serves as both the
+    // call queue and the notification queue.
+    (grpc_service->*enqueue_function)(
+        &call->ctx_, &call->request, &call->responder_, cq, cq,
+        new typename UntypedCall<Service>::Tag(
+            call, &UntypedCall<Service>::RequestReceived));
+    // The tags now hold the references; drop the creator's reference.
+    call->Unref();
+  }
+
+  // Request and response messages; public so that handlers can read the
+  // request and fill in the response before calling SendResponse().
+  RequestMessage request;
+  ResponseMessage response;
+
+ private:
+  HandleRequestFunction handle_request_function_;
+  // Declared before `responder_`, which is constructed from `&ctx_`.
+  ::grpc::ServerContext ctx_;
+  ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
+  mutex mu_;
+  std::function<void()> cancel_callback_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CALL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
new file mode 100644
index 0000000000..f9492114b6
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -0,0 +1,314 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+
+#include <unordered_map>
+
+#include "external/grpc/include/grpc++/create_channel.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+// Matches a full task name of the form /job:<job>/replica:<n>/task:<n>,
+// capturing the job name, the replica number and the task number.
+RE2* kTargetRE = new RE2("^/job:([^/]+)/replica:([0-9]+)/task:([0-9]+)$");
+// Matches "<host>:<port>"; the host part may not contain ':' or '/'.
+RE2* kHostPortRE = new RE2("([^:/]+):(\\d+)");
+// Matches "<index>:<host>:<port>", the entry format of a sparse task list.
+RE2* kSparseHostPortRE = new RE2("(\\d+):([^:/]+):(\\d+)");
+
+// Formats the worker name "/job:<job>/replica:<replica>/task:<task>".
+string MakeAddress(const string& job, int replica, int task) {
+  string address = strings::StrCat("/job:", job);
+  strings::StrAppend(&address, "/replica:", replica, "/task:", task);
+  return address;
+}
+
+} // namespace
+
+// Creates a gRPC channel to `target` using insecure (unauthenticated,
+// unencrypted) channel credentials.
+SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target) {
+ // TODO(mrry): Implement secure channels.
+ return ::grpc::CreateChannel(target, ::grpc::InsecureChannelCredentials());
+}
+
+// Adds a job named `job_id` whose tasks are reachable at `host_ports`.
+//
+// Returns InvalidArgument, leaving the spec unchanged, if
+// `tasks_per_replica` is not positive, if `job_id` has already been added,
+// or if any entry of `host_ports` is not of the form "host:port".
+//
+// NOTE(review): "index:host:port" entries are rejected here even though
+// HostPortsGrpcChannelCache can parse them -- confirm whether sparse lists
+// are meant to be accepted by the spec.
+Status GrpcChannelSpec::AddHostPortsJob(const string& job_id,
+                                        const std::vector<string>& host_ports,
+                                        int tasks_per_replica) {
+  // The channel cache divides by tasks_per_replica, so reject zero and
+  // negative values up front.
+  if (tasks_per_replica <= 0) {
+    return errors::InvalidArgument(
+        "Number of tasks per replica must be positive for job \"", job_id,
+        "\": ", tasks_per_replica);
+  }
+  if (job_ids_.count(job_id) > 0) {
+    return errors::InvalidArgument(
+        "Duplicate job ID in cluster specification: ", job_id);
+  }
+  // Validate every address before mutating any state, so that a failed call
+  // does not leave `job_id` registered and block a corrected retry.
+  for (const string& host_port : host_ports) {
+    string host;
+    int port;
+    if (!RE2::FullMatch(host_port, *kHostPortRE, &host, &port)) {
+      return errors::InvalidArgument("Could not interpret \"", host_port,
+                                     "\" as a host-port pair.");
+    }
+  }
+  HostPortsJob job;
+  job.job_id = job_id;
+  job.host_ports = host_ports;
+  job.tasks_per_replica = tasks_per_replica;
+  job_ids_.insert(job_id);
+  host_ports_jobs_.push_back(job);
+  return Status::OK();
+}
+
+// Builds a channel cache for every job in `spec`. Returns nullptr (after
+// logging an error) when the spec contains no jobs, the single per-job cache
+// when there is exactly one job, and otherwise a multi-cache wrapping all
+// per-job caches.
+GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& spec) {
+  const std::vector<GrpcChannelSpec::HostPortsJob>& jobs =
+      spec.host_ports_jobs();
+  if (jobs.empty()) {
+    LOG(ERROR) << "Empty channel spec.";
+    return nullptr;
+  }
+
+  std::vector<GrpcChannelCache*> per_job_caches;
+  per_job_caches.reserve(jobs.size());
+  for (const auto& job : jobs) {
+    GrpcChannelCache* job_cache = NewHostPortsGrpcChannelCache(
+        job.job_id, job.host_ports, job.tasks_per_replica);
+    per_job_caches.push_back(job_cache);
+  }
+
+  if (per_job_caches.size() == 1) {
+    return per_job_caches.front();
+  }
+  return NewMultiGrpcChannelCache(per_job_caches);
+}
+
+// GrpcChannelCache that caches results to FindWorkerChannel() calls.
+//
+// Subclasses implement FindChannelOnce(); every non-null channel it returns
+// is remembered per target and handed out on later lookups.
+class CachingGrpcChannelCache : public GrpcChannelCache {
+ public:
+  CachingGrpcChannelCache() {}
+
+  ~CachingGrpcChannelCache() override {}
+
+  // Returns the channel for `target`, creating and caching it on first use.
+  // Returns nullptr (and caches nothing) if no channel can be found.
+  SharedGrpcChannelPtr FindWorkerChannel(const string& target) override {
+    {
+      mutex_lock l(mu_);  // could use reader lock
+      SharedGrpcChannelPtr cached = gtl::FindPtrOrNull(channels_, target);
+      if (cached) {
+        return cached;
+      }
+    }
+    // `mu_` is not held while FindChannelOnce() runs, so two threads may
+    // both reach this point for the same target.
+    SharedGrpcChannelPtr ch = FindChannelOnce(target);
+    if (!ch) {
+      return nullptr;
+    }
+    mutex_lock l(mu_);
+    // insert() leaves an existing entry untouched. Return whatever the map
+    // holds so that concurrent callers all share the one cached channel
+    // instead of the loser of the race keeping a private, uncached one.
+    return channels_.insert({target, ch}).first->second;
+  }
+
+ protected:
+  // Find the ClientChannel for "target". Only called when no channel was
+  // found in the channels_ cache for "target". A non nullptr result will be
+  // cached in channels_.
+  virtual SharedGrpcChannelPtr FindChannelOnce(const string& target) = 0;
+
+ private:
+  // TODO(zhifengc): Eviction when the map becomes too big.
+  mutex mu_;
+  std::unordered_map<string, SharedGrpcChannelPtr> channels_ GUARDED_BY(mu_);
+};
+
+// A ChannelCache that is the union of multiple ChannelCaches.
+// Takes ownership of the caches passed to the constructor.
+class MultiGrpcChannelCache : public CachingGrpcChannelCache {
+ public:
+ explicit MultiGrpcChannelCache(const std::vector<GrpcChannelCache*>& caches)
+ : CachingGrpcChannelCache(), caches_(caches) {}
+
+ // Deletes every cache that was passed to the constructor.
+ ~MultiGrpcChannelCache() override {
+ for (GrpcChannelCache* cache : caches_) {
+ delete cache;
+ }
+ }
+
+ // Appends the workers of every underlying cache to *workers, in the order
+ // in which the caches were given to the constructor.
+ void ListWorkers(std::vector<string>* workers) override {
+ for (GrpcChannelCache* cache : caches_) {
+ cache->ListWorkers(workers);
+ }
+ }
+
+ // Translates `target` using the first underlying cache that returns a
+ // non-empty translation, and remembers that cache for later calls.
+ //
+ // NOTE(review): unlike the per-job cache, which returns "" for an unknown
+ // target, this CHECK-fails when no underlying cache can translate
+ // `target` -- confirm that crashing is intended.
+ string TranslateTask(const string& target) override {
+ mutex_lock l(mu_); // could use reader lock
+ GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
+ if (cache == nullptr) {
+ for (GrpcChannelCache* c : caches_) {
+ string r = c->TranslateTask(target);
+ if (!r.empty()) {
+ target_caches_.insert({target, c});
+ cache = c;
+ break;
+ }
+ }
+ }
+ CHECK(cache) << "Could not find GrpcChannelCache holding channel for "
+ << target;
+ // The chosen cache is asked again here, even when the loop above has
+ // just computed the translation.
+ return cache->TranslateTask(target);
+ }
+
+ protected:
+ // Asks each underlying cache in turn for a channel to `target`; records
+ // which cache supplied it and returns the first non-null channel, or
+ // nullptr if none of them has one.
+ SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
+ for (GrpcChannelCache* cache : caches_) {
+ SharedGrpcChannelPtr ch(cache->FindWorkerChannel(target));
+ if (ch) {
+ mutex_lock l(mu_);
+ target_caches_.insert({target, cache});
+ return ch;
+ }
+ }
+ return nullptr;
+ }
+
+ private:
+ // List of channels used by this MultiGrpcChannelCache.
+ const std::vector<GrpcChannelCache*> caches_;
+
+ // Guards target_caches_. This is a separate mutex from the private `mu_`
+ // of the CachingGrpcChannelCache base class.
+ mutex mu_;
+ // Cache of channels keyed by the target they are handling.
+ // The same GrpcChannelCache can appear multiple times in the cache.
+ std::unordered_map<string, GrpcChannelCache*> target_caches_ GUARDED_BY(mu_);
+};
+
+// Returns a new heap-allocated cache that is the union of `caches`. The
+// returned object deletes every element of `caches` when it is destroyed.
+GrpcChannelCache* NewMultiGrpcChannelCache(
+ const std::vector<GrpcChannelCache*>& caches) {
+ return new MultiGrpcChannelCache(caches);
+}
+
+// A channel cache for a single job, backed by a list of task addresses.
+//
+// The list given to the constructor is either dense ("host:port" entries,
+// where position i is the address of task index i) or sparse
+// ("index:host:port" entries); the two forms must not be mixed. Task index i
+// corresponds to replica i / tasks_per_replica and task i % tasks_per_replica.
+class HostPortsGrpcChannelCache : public CachingGrpcChannelCache {
+ public:
+  // `tasks_per_replica` must be positive; malformed `host_ports` entries
+  // CHECK-fail (see BuildDenseHostPortsList()).
+  HostPortsGrpcChannelCache(const string& job_id,
+                            const std::vector<string>& host_ports,
+                            int tasks_per_replica)
+      : job_id_(job_id),
+        host_ports_(BuildDenseHostPortsList(host_ports, tasks_per_replica)),
+        tasks_per_replica_(tasks_per_replica) {
+    LOG(INFO) << "Initialize HostPortsGrpcChannelCache for job " << job_id
+              << " -> {" << str_util::Join(host_ports, ", ") << "}";
+  }
+  ~HostPortsGrpcChannelCache() override {}
+
+  // Appends the name of every task that has an address to *workers. Holes
+  // in a sparse list (empty entries) are skipped.
+  void ListWorkers(std::vector<string>* workers) override {
+    int num_host_ports = 0;
+    for (size_t i = 0; i < host_ports_.size(); ++i) {
+      if (!host_ports_[i].empty()) {
+        ++num_host_ports;
+      }
+    }
+    workers->reserve(workers->size() + num_host_ports);
+    for (size_t i = 0; i < host_ports_.size(); ++i) {
+      if (!host_ports_[i].empty()) {
+        workers->emplace_back(MakeAddress(job_id_, i / tasks_per_replica_,
+                                          i % tasks_per_replica_));
+      }
+    }
+  }
+
+  // Returns the "host:port" address for a target of the form
+  // /job:<job>/replica:<n>/task:<n>, or "" if the target is malformed,
+  // names a different job, or refers to a task without an address.
+  string TranslateTask(const string& target) override {
+    RegexpStringPiece job;
+    int32 replica;
+    int32 task;
+    if (!RE2::FullMatch(target, *kTargetRE, &job, &replica, &task)) {
+      LOG(WARNING) << "Invalid target: " << target;
+      return "";
+    }
+    if (job != job_id_) {
+      // Not our job; stay quiet so that another cache can claim the target.
+      return "";
+    }
+    if (task >= tasks_per_replica_) {
+      LOG(WARNING) << "Task out of bounds for job " << job_id_ << ": " << task;
+      return "";
+    }
+    // Widen before multiplying: a large replica number would overflow `int`
+    // arithmetic before the bounds check below could reject it.
+    const size_t i = static_cast<size_t>(replica) * tasks_per_replica_ + task;
+    if (i >= host_ports_.size()) {
+      LOG(WARNING) << "Replica/task out of bounds for job " << job_id_ << ": "
+                   << target;
+      return "";
+    }
+    if (host_ports_[i].empty()) {
+      LOG(WARNING) << "Replica/task not in sparse index:host:port list for job "
+                   << job_id_ << ": " << target;
+      return "";
+    }
+    return host_ports_[i];
+  }
+
+ protected:
+  // Converts `host_ports` into a dense list indexed by task index.
+  //
+  // A dense input is returned unchanged. A sparse input is expanded into a
+  // list whose length is rounded up to a multiple of `tasks_per_replica`,
+  // with "" at every index that has no address. CHECK-fails on a
+  // non-positive `tasks_per_replica`, on entries that are neither host:port
+  // nor index:host:port, on duplicate indices, and on a mix of both forms.
+  static std::vector<string> BuildDenseHostPortsList(
+      const std::vector<string>& host_ports, int tasks_per_replica) {
+    // Guards the modulo/division below (and in ListWorkers()) against a
+    // division by zero.
+    CHECK_GT(tasks_per_replica, 0)
+        << "tasks_per_replica must be positive: {"
+        << str_util::Join(host_ports, ", ") << "}";
+    std::map<int, string> sparse_host_ports;
+    for (const string& host_port : host_ports) {
+      int i = -1;
+      string host;
+      int port = -1;
+      if (RE2::FullMatch(host_port, *kSparseHostPortRE, &i, &host, &port)) {
+        CHECK_LE(0, i);
+        CHECK_LE(0, port);
+        CHECK(sparse_host_ports.find(i) == sparse_host_ports.end())
+            << "Duplicate index " << i << ": {"
+            << str_util::Join(host_ports, ", ") << "}";
+        sparse_host_ports[i] = strings::StrCat(host, ":", port);
+      } else {
+        CHECK(RE2::FullMatch(host_port, *kHostPortRE, &host, &port))
+            << host_port
+            << " does not look like a host:port or an index:host:port";
+      }
+    }
+
+    if (sparse_host_ports.empty()) {
+      // The input is a dense list; return it directly.
+      return host_ports;
+    }
+
+    // The input is a sparse list. Convert it to a dense list.
+    CHECK_EQ(host_ports.size(), sparse_host_ports.size())
+        << "Mix of host:port and index:host:port: {"
+        << str_util::Join(host_ports, ", ") << "}";
+    // The map is ordered, so the last key is the largest index.
+    int num_tasks = sparse_host_ports.rbegin()->first + 1;
+    if (num_tasks % tasks_per_replica != 0) {
+      // Round up to a whole number of replicas.
+      num_tasks = (num_tasks / tasks_per_replica + 1) * tasks_per_replica;
+    }
+    std::vector<string> dense_host_ports;
+    dense_host_ports.resize(num_tasks);
+    for (const auto& p : sparse_host_ports) {
+      dense_host_ports[p.first] = p.second;
+    }
+    return dense_host_ports;
+  }
+
+  // Creates a new channel to the address of `target`, or returns nullptr
+  // (after logging a warning) if the target has no address in this job.
+  SharedGrpcChannelPtr FindChannelOnce(const string& target) override {
+    const string host_port = TranslateTask(target);
+    if (host_port.empty()) {
+      LOG(WARNING) << "Could not find channel for target: " << target;
+      return nullptr;
+    }
+    return NewHostPortGrpcChannel(host_port);
+  }
+
+ private:
+  const string job_id_;
+  // Dense list of "host:port" strings indexed by task index; "" marks a
+  // task index that has no address.
+  const std::vector<string> host_ports_;
+  const int tasks_per_replica_;
+  TF_DISALLOW_COPY_AND_ASSIGN(HostPortsGrpcChannelCache);
+};
+
+// Returns a new heap-allocated channel cache for the single job `job_id`,
+// whose task addresses are given by `host_ports` (host:port or
+// index:host:port entries) with `tasks_per_replica` tasks in each replica.
+GrpcChannelCache* NewHostPortsGrpcChannelCache(
+ const string& job_id, const std::vector<string>& host_ports,
+ int tasks_per_replica) {
+ return new HostPortsGrpcChannelCache(job_id, host_ports, tasks_per_replica);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
new file mode 100644
index 0000000000..f3667a567e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -0,0 +1,98 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
+
+#include <memory>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "external/grpc/include/grpc++/grpc++.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+
+namespace tensorflow {
+
+// Consolidated parameter structure to ease use of generic interfaces.
+//
+// Each job_id requires:
+// - a list of host:port (or sparse list of index:host:port)
+// - the number of tasks per replica
+class GrpcChannelSpec {
+ public:
+ // Description of a single job: its identifier, the addresses of its
+ // tasks, and how many tasks make up one replica.
+ struct HostPortsJob {
+ string job_id;
+ // Each entry is "host:port", or "index:host:port" for a sparse list.
+ std::vector<string> host_ports;
+ int tasks_per_replica;
+ };
+
+ // Records the description of one job. Defined out of line; the
+ // implementation is not part of this header.
+ // NOTE(review): presumably rejects a repeated `job_id` by consulting
+ // `job_ids_` -- confirm against the definition.
+ Status AddHostPortsJob(const string& job_id,
+ const std::vector<string>& host_ports,
+ int tasks_per_replica);
+
+ // Read-only access to the stored job descriptions.
+ const std::vector<HostPortsJob>& host_ports_jobs() const {
+ return host_ports_jobs_;
+ }
+
+ private:
+ std::vector<HostPortsJob> host_ports_jobs_;
+ // NOTE(review): looks like the set of job ids seen so far -- confirm.
+ std::set<string> job_ids_;
+};
+
+// Abstract interface for looking up gRPC channels to workers by name.
+class GrpcChannelCache {
+ public:
+ virtual ~GrpcChannelCache() {}
+
+ // Populates *workers with names of all workers which this object
+ // was created to handle. Worker names are in the format
+ // /job:<job identifier>/task:<task id>
+ // e.g. /job:mnist/task:2
+ // NOTE(review): the host-ports cache exercised in grpc_channel_test.cc
+ // lists names that also carry a replica component, e.g.
+ // /job:mnist/replica:0/task:1 -- confirm which format is canonical.
+ virtual void ListWorkers(std::vector<string>* workers) = 0;
+
+ // If found, returns a gRPC channel that is connected to the remote
+ // worker named by 'target'. 'target' is of the following
+ // format: /job:<job identifier>/task:<task id>
+ // E.g., /job:mnist/task:2
+ // NOTE(review): grpc_channel_test.cc expects nullptr for targets that
+ // cannot be resolved, and passes targets with a replica component
+ // (/job:mnist/replica:1/task:1).
+ virtual SharedGrpcChannelPtr FindWorkerChannel(const string& target) = 0;
+
+ // Translates a string in the form `/job:X/task:Z` into a host_port.
+ virtual string TranslateTask(const string& task) = 0;
+};
+
+GrpcChannelCache* NewGrpcChannelCache(const GrpcChannelSpec& p);
+
+// Below here are internal-only functions.
+
+SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target);
+
+// Returns a ChannelCache that uses a set of known host:port pairs. E.g., say,
+// job_id = 'mnist', 'host_ports' = {"h0:0", "h1:1", ..., "h11:11", "h12:12"},
+// tasks_per_replica = 8, /job:mnist/replica:1/task:3 is mapped to host:port
+// "h11:11" (11 = 8 * 1 + 3).
+//
+// The caller takes ownership of the returned object.
+GrpcChannelCache* NewHostPortsGrpcChannelCache(
+ const string& job_id, const std::vector<string>& host_ports,
+ int tasks_per_replica);
+
+// Returns a ChannelCache that is the union of a number of other ChannelCaches.
+GrpcChannelCache* NewMultiGrpcChannelCache(
+ const std::vector<GrpcChannelCache*>& caches);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CHANNEL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
new file mode 100644
index 0000000000..a951dc2fcf
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+#define IsSameAddrSp DeviceNameUtils::IsSameAddressSpace
+
+TEST(GrpcChannelTest, IsSameAddressSpace) {
+ // Exercises DeviceNameUtils::IsSameAddressSpace (aliased as IsSameAddrSp)
+ // on pairs of task/device names.
+ // Same: job, replica and task all agree; the device part may differ or
+ // be missing on either side.
+ EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
+ "/job:mnist/replica:10/task:10/cpu:1"));
+ EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:0",
+ "/job:mnist/replica:10/task:10/gpu:2"));
+ EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10",
+ "/job:mnist/replica:10/task:10/gpu:2"));
+ EXPECT_TRUE(IsSameAddrSp("/job:mnist/replica:10/task:10/cpu:1",
+ "/job:mnist/replica:10/task:10"));
+
+ // Different: the task, the replica, or the job (case-sensitive) differs.
+ EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:9/cpu:0",
+ "/job:mnist/replica:10/task:10/cpu:0"));
+ EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:9/task:10/cpu:0",
+ "/job:mnist/replica:10/task:10/cpu:0"));
+ EXPECT_FALSE(IsSameAddrSp("/job:MNIST/replica:10/task:10/cpu:0",
+ "/job:mnist/replica:10/task:10/cpu:0"));
+
+ // Invalid names: expected to be false even when both arguments are the
+ // same string or differ only in the device part.
+ EXPECT_FALSE(IsSameAddrSp("random_invalid_target", "random_invalid_target"));
+ EXPECT_FALSE(IsSameAddrSp("/job:/replica:10/task:10/cpu:0",
+ "/job:/replica:10/task:10/cpu:1"));
+ EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:xx/task:10/cpu:0",
+ "/job:mnist/replica:xx/task:10/cpu:1"));
+ EXPECT_FALSE(IsSameAddrSp("/job:mnist/replica:10/task:yy/cpu:0",
+ "/job:mnist/replica:10/task:yy/cpu:1"));
+}
+
+TEST(GrpcChannelTest, HostPorts) {
+ // Dense list: six host:port entries with two tasks per replica, i.e.
+ // replicas 0..2 with tasks 0..1 (entry index = replica * 2 + task).
+ std::unique_ptr<GrpcChannelCache> cc(NewHostPortsGrpcChannelCache(
+ "mnist", {"a:1", "b:2", "c:3", "d:4", "e:5", "f:6"}, 2));
+ // Unparseable name, unknown job, task out of range, replica out of range.
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0"));
+
+ {
+ // NOTE(mrry): The gRPC channel doesn't expose the target, so we
+ // can't compare it for equality.
+ // Instead, check that repeated lookups of one name return the same
+ // channel object and that different names return different objects.
+ auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
+ auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
+
+ auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
+ auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
+
+ auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
+ auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
+
+ EXPECT_EQ(a_1_1.get(), a_1_2.get());
+ EXPECT_EQ(d_4_1.get(), d_4_2.get());
+ EXPECT_EQ(e_5_1.get(), e_5_2.get());
+
+ EXPECT_NE(a_1_1.get(), d_4_2.get());
+ EXPECT_NE(a_1_1.get(), e_5_2.get());
+ EXPECT_NE(d_4_1.get(), e_5_2.get());
+ }
+
+ // Every one of the six entries is expected to be listed, in index order.
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:1/task:0",
+ "/job:mnist/replica:1/task:1",
+ "/job:mnist/replica:2/task:0",
+ "/job:mnist/replica:2/task:1"}),
+ workers);
+}
+
+TEST(GrpcChannelTest, SparseHostPorts) {
+ // Sparse list: only indices 0, 3 and 4 are given ("index:host:port"),
+ // again with two tasks per replica. Indices 1, 2 and 5 have no address.
+ std::unique_ptr<GrpcChannelCache> cc(
+ NewHostPortsGrpcChannelCache("mnist", {"0:a:1", "3:d:4", "4:e:5"}, 2));
+ // Unparseable name, unknown job, then the unpopulated or out-of-range
+ // slots (index = replica * 2 + task): 1, task 2, 2, 5, and replica 3.
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("invalid_target"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:other/replica:0/task:0"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:1"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:0/task:2"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:1/task:0"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:2/task:1"));
+ EXPECT_EQ(nullptr, cc->FindWorkerChannel("/job:mnist/replica:3/task:0"));
+
+ {
+ // NOTE(mrry): The gRPC channel doesn't expose the target, so we
+ // can't compare it for equality.
+ // Instead, check that repeated lookups of one name return the same
+ // channel object and that different names return different objects.
+ auto a_1_1 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
+ auto a_1_2 = cc->FindWorkerChannel("/job:mnist/replica:0/task:0");
+
+ auto d_4_1 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
+ auto d_4_2 = cc->FindWorkerChannel("/job:mnist/replica:1/task:1");
+
+ auto e_5_1 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
+ auto e_5_2 = cc->FindWorkerChannel("/job:mnist/replica:2/task:0");
+
+ EXPECT_EQ(a_1_1.get(), a_1_2.get());
+ EXPECT_EQ(d_4_1.get(), d_4_2.get());
+ EXPECT_EQ(e_5_1.get(), e_5_2.get());
+
+ EXPECT_NE(a_1_1.get(), d_4_2.get());
+ EXPECT_NE(a_1_1.get(), e_5_2.get());
+ EXPECT_NE(d_4_1.get(), e_5_2.get());
+ }
+
+ // Only the three populated slots are expected to be listed.
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:1/task:1",
+ "/job:mnist/replica:2/task:0"}),
+ workers);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
new file mode 100644
index 0000000000..300481303b
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h
@@ -0,0 +1,56 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
+
+#include "external/grpc/include/grpc++/grpc++.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// Represents a pending asynchronous client call as a tag that can be
+// stored in a `grpc::CompletionQueue`.
+class GrpcClientCQTag {
+ public:
+  // Stores `context` and the completion callback. The context is deleted
+  // when this tag is destroyed.
+  GrpcClientCQTag(::grpc::ClientContext* context, StatusCallback cb)
+      : context_(context), cb_(cb) {}
+
+  ~GrpcClientCQTag() { delete context_; }
+
+  // Runs the stored callback with the call's status converted through
+  // FromGrpcStatus(). When `ok` is false the gRPC error message is first
+  // written to the verbose log.
+  void OnCompleted(bool ok) {
+    if (!ok) {
+      VLOG(2) << "Call returned with non-ok status: "
+              << status_.error_message();
+    }
+    const Status converted = FromGrpcStatus(status_);
+    cb_(converted);
+  }
+
+  // Destination for the final status of the call.
+  ::grpc::Status* status() { return &status_; }
+
+  // The client context this tag holds.
+  ::grpc::ClientContext* context() { return context_; }
+
+ private:
+  ::grpc::ClientContext* context_;
+  ::grpc::Status status_;
+  StatusCallback cb_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcClientCQTag);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_CLIENT_CQ_TAG_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
new file mode 100644
index 0000000000..b8d50c5695
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -0,0 +1,181 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// GrpcMasterService implements the RPC service MasterService.
+//
+// A GrpcMasterService maintains the state of live graph computation
+// sessions, each session orchestrates both local and remote devices
+// to carry out the graph computation.
+//
+// A GrpcMasterService knows ahead of time local devices available as
+// client devices.
+//
+// A GrpcMasterService discovers remote devices in the background and
+// keeps track of statistics of those remote devices.
+//
+// Each session analyses the graph, places nodes across available
+// devices, and ultimately drives the graph computation by initiating
+// RunGraph on workers.
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
+
+#include "external/grpc/include/grpc++/server_builder.h"
+
+#include "tensorflow/core/distributed_runtime/master.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
+
+namespace tensorflow {
+
+// Asynchronous gRPC front end that forwards each MasterService RPC to a
+// Master instance it owns.
+class GrpcMasterService : public AsyncServiceInterface {
+ public:
+ // Creates the Master, registers the async service with `builder`, and
+ // takes ownership of a completion queue obtained from `builder`.
+ // NOTE(review): the meaning of the 0.0 passed to Master is not visible
+ // here -- confirm against master.h.
+ GrpcMasterService(MasterEnv* env, ::grpc::ServerBuilder* builder)
+ : master_impl_(new Master(env, 0.0)) {
+ builder->RegisterService(&master_service_);
+ cq_ = builder->AddCompletionQueue().release();
+ }
+
+ // NOTE(review): the queue is deleted without a visible Shutdown() or
+ // drain; gRPC expects both before destruction -- confirm.
+ ~GrpcMasterService() {
+ delete cq_;
+ delete master_impl_;
+ }
+
+// This macro creates a new request for the given RPC method name
+// (e.g., `ENQUEUE_REQUEST(RunStep);`), and enqueues it on
+// `this->cq_`.
+//
+// This macro is invoked one or more times for each RPC method to
+// ensure that there are sufficient completion queue entries to
+// handle incoming requests without blocking.
+//
+// The implementation of the request handler for each RPC method
+// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
+// to keep accepting new requests.
+#define ENQUEUE_REQUEST(method) \
+ do { \
+ Call<GrpcMasterService, grpc::MasterService::AsyncService, \
+ method##Request, method##Response>:: \
+ EnqueueRequest(&master_service_, cq_, \
+ &grpc::MasterService::AsyncService::Request##method, \
+ &GrpcMasterService::method##Handler); \
+ } while (0)
+
+ // Seeds the completion queue with pending requests (one per method,
+ // except RunStep which gets 100), then dispatches completion events
+ // until Next() returns false.
+ void HandleRPCsLoop() {
+ ENQUEUE_REQUEST(CreateSession);
+ ENQUEUE_REQUEST(ExtendSession);
+ for (int i = 0; i < 100; ++i) {
+ ENQUEUE_REQUEST(RunStep);
+ }
+ ENQUEUE_REQUEST(CloseSession);
+ ENQUEUE_REQUEST(ListDevices);
+ ENQUEUE_REQUEST(Reset);
+
+ void* tag;
+ bool ok;
+ while (cq_->Next(&tag, &ok)) {
+ // NOTE(review): aborts the process on any event with ok == false,
+ // although `ok` is also forwarded to OnCompleted() below -- confirm
+ // this is intended (e.g. during server shutdown).
+ CHECK(ok);
+ UntypedCall<GrpcMasterService>::Tag* callback_tag =
+ static_cast<UntypedCall<GrpcMasterService>::Tag*>(tag);
+ callback_tag->OnCompleted(this, ok);
+ // Each tag is heap-allocated and consumed exactly once here.
+ delete callback_tag;
+ }
+ }
+
+ private:
+ Master* master_impl_; // Owned.
+ ::grpc::ServerCompletionQueue* cq_; // Owned.
+ grpc::MasterService::AsyncService master_service_;
+
+ // Shorthand for the Call specialization used by the handlers below.
+ template <class RequestMessage, class ResponseMessage>
+ using MasterCall = Call<GrpcMasterService, grpc::MasterService::AsyncService,
+ RequestMessage, ResponseMessage>;
+
+ // Each handler forwards to the matching Master method, sends the response
+ // from the completion callback, and then re-enqueues a request for the
+ // same method.
+
+ // RPC handler for creating a session.
+ void CreateSessionHandler(
+ MasterCall<CreateSessionRequest, CreateSessionResponse>* call) {
+ master_impl_->CreateSession(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(CreateSession);
+ }
+
+ // RPC handler for extending a session.
+ void ExtendSessionHandler(
+ MasterCall<ExtendSessionRequest, ExtendSessionResponse>* call) {
+ master_impl_->ExtendSession(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(ExtendSession);
+ }
+
+ // RPC handler for running one step in a session.
+ // A heap-allocated CallOptions links cancellation of the incoming call to
+ // StartCancel(); it is deleted in the completion callback after the
+ // cancel callback has been cleared.
+ void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
+ CallOptions* call_opts = new CallOptions;
+ call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+ master_impl_->RunStep(call_opts, &call->request, &call->response,
+ [call, call_opts](const Status& status) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(RunStep);
+ }
+
+ // RPC handler for deleting a session.
+ void CloseSessionHandler(
+ MasterCall<CloseSessionRequest, CloseSessionResponse>* call) {
+ master_impl_->CloseSession(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(CloseSession);
+ }
+
+ // RPC handler for listing devices.
+ void ListDevicesHandler(
+ MasterCall<ListDevicesRequest, ListDevicesResponse>* call) {
+ master_impl_->ListDevices(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(ListDevices);
+ }
+
+ // RPC handler for resetting all sessions.
+ void ResetHandler(MasterCall<ResetRequest, ResetResponse>* call) {
+ master_impl_->Reset(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(Reset);
+ }
+#undef ENQUEUE_REQUEST
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcMasterService);
+};
+
+// Builds a GrpcMasterService for `env`, registering it with `builder`.
+// An environment with no local devices is a fatal error.
+AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env,
+                                            ::grpc::ServerBuilder* builder) {
+  CHECK(!env->local_devices.empty());
+  AsyncServiceInterface* service = new GrpcMasterService(env, builder);
+  return service;
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
new file mode 100644
index 0000000000..d23a3f3ed3
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.h
@@ -0,0 +1,33 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
+
+namespace grpc {
+class ServerBuilder;
+} // namespace grpc
+
+namespace tensorflow {
+
+class AsyncServiceInterface;
+class MasterEnv;
+
+AsyncServiceInterface* NewGrpcMasterService(MasterEnv* env,
+ ::grpc::ServerBuilder* builder);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_MASTER_SERVICE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
new file mode 100644
index 0000000000..e358aed31f
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -0,0 +1,79 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
+
+#include "tensorflow/core/distributed_runtime/master_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/protobuf/master_service.grpc.pb.h"
+
+namespace tensorflow {
+
+// GrpcRemoteMaster is an implementation of the MasterInterface
+// that uses gRPC to talk to the Master service.
+class GrpcRemoteMaster : public MasterInterface {
+ public:
+  // Creates a MasterService stub on top of `client_channel`.
+  explicit GrpcRemoteMaster(SharedGrpcChannelPtr client_channel)
+      : stub_(grpc::MasterService::NewStub(client_channel)) {}
+
+  ~GrpcRemoteMaster() override {}
+
+  // Every method below performs the same-named synchronous stub call with a
+  // fresh ClientContext and converts the resulting gRPC status through
+  // FromGrpcStatus().
+
+  Status CreateSession(const CreateSessionRequest* request,
+                       CreateSessionResponse* response) override {
+    ::grpc::ClientContext context;
+    const ::grpc::Status rpc_status =
+        stub_->CreateSession(&context, *request, response);
+    return FromGrpcStatus(rpc_status);
+  }
+
+  Status ExtendSession(const ExtendSessionRequest* request,
+                       ExtendSessionResponse* response) override {
+    ::grpc::ClientContext context;
+    const ::grpc::Status rpc_status =
+        stub_->ExtendSession(&context, *request, response);
+    return FromGrpcStatus(rpc_status);
+  }
+
+  Status RunStep(const RunStepRequest* request,
+                 RunStepResponse* response) override {
+    ::grpc::ClientContext context;
+    const ::grpc::Status rpc_status =
+        stub_->RunStep(&context, *request, response);
+    return FromGrpcStatus(rpc_status);
+  }
+
+  Status CloseSession(const CloseSessionRequest* request,
+                      CloseSessionResponse* response) override {
+    ::grpc::ClientContext context;
+    const ::grpc::Status rpc_status =
+        stub_->CloseSession(&context, *request, response);
+    return FromGrpcStatus(rpc_status);
+  }
+
+  Status ListDevices(const ListDevicesRequest* request,
+                     ListDevicesResponse* response) override {
+    ::grpc::ClientContext context;
+    const ::grpc::Status rpc_status =
+        stub_->ListDevices(&context, *request, response);
+    return FromGrpcStatus(rpc_status);
+  }
+
+  Status Reset(const ResetRequest* request, ResetResponse* response) override {
+    ::grpc::ClientContext context;
+    const ::grpc::Status rpc_status =
+        stub_->Reset(&context, *request, response);
+    return FromGrpcStatus(rpc_status);
+  }
+
+ private:
+  std::unique_ptr<grpc::MasterService::Stub> stub_;
+};
+
+// Wraps `channel` in a heap-allocated GrpcRemoteMaster and returns it through
+// the MasterInterface base.
+MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel) {
+  MasterInterface* master = new GrpcRemoteMaster(channel);
+  return master;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h
new file mode 100644
index 0000000000..461e4ca0bd
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h
@@ -0,0 +1,27 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
+
+#include "tensorflow/core/distributed_runtime/master_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+
+namespace tensorflow {
+// Returns a MasterInterface wrapped around the gRPC channel `channel`.
+MasterInterface* NewGrpcMaster(SharedGrpcChannelPtr channel);
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_MASTER_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
new file mode 100644
index 0000000000..0040631aac
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -0,0 +1,203 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
+
+#include "external/grpc/include/grpc++/grpc++.h"
+
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
+
+namespace tensorflow {
+
+// WorkerInterface implementation that issues each call as an asynchronous
+// gRPC request on a WorkerService stub, completing on a shared completion
+// queue.
+class GrpcRemoteWorker : public WorkerInterface {
+ public:
+  // `completion_queue` and `logger` are stored as raw pointers; this class
+  // never deletes them.
+  explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
+                            ::grpc::CompletionQueue* completion_queue,
+                            WorkerCacheLogger* logger)
+      : stub_(grpc::WorkerService::NewStub(channel)),
+        cq_(completion_queue),
+        logger_(logger) {}
+
+  ~GrpcRemoteWorker() override {}
+
+  // The methods below, other than RecvTensorAsync, forward directly to
+  // IssueRequest() with the matching stub method; `done` runs when the RPC
+  // completes.
+
+  void GetStatusAsync(const GetStatusRequest* request,
+                      GetStatusResponse* response,
+                      StatusCallback done) override {
+    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncGetStatus,
+                 done);
+  }
+
+  void RegisterGraphAsync(const RegisterGraphRequest* request,
+                          RegisterGraphResponse* response,
+                          StatusCallback done) override {
+    IssueRequest(request, response,
+                 &grpc::WorkerService::Stub::AsyncRegisterGraph, done);
+  }
+
+  void DeregisterGraphAsync(const DeregisterGraphRequest* request,
+                            DeregisterGraphResponse* response,
+                            StatusCallback done) override {
+    IssueRequest(request, response,
+                 &grpc::WorkerService::Stub::AsyncDeregisterGraph, done);
+  }
+
+  // `call_opts` is passed through so the RPC can be cancelled.
+  void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
+                     RunGraphResponse* response, StatusCallback done) override {
+    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncRunGraph,
+                 done, call_opts);
+  }
+
+  void CleanupGraphAsync(const CleanupGraphRequest* request,
+                         CleanupGraphResponse* response,
+                         StatusCallback done) override {
+    IssueRequest(request, response,
+                 &grpc::WorkerService::Stub::AsyncCleanupGraph, done);
+  }
+
+  void CleanupAllAsync(const CleanupAllRequest* request,
+                       CleanupAllResponse* response,
+                       StatusCallback done) override {
+    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncCleanupAll,
+                 done);
+  }
+
+  // Issues a RecvTensor RPC. If the request has dma_ok set, a heap copy with
+  // dma_ok cleared is sent instead and freed in the completion callback.
+  // When logging is active the callback also records timing and size for
+  // the transfer. `allocator` is not used by this implementation.
+  void RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request,
+                       RecvTensorResponse* response,
+                       TensorBufAllocator allocator,
+                       StatusCallback done) override {
+    VLOG(1) << "RecvTensorAsync req: " << request->DebugString();
+    int64 start_usec = Env::Default()->NowMicros();
+    // Don't propagate dma_ok over gRPC.
+    RecvTensorRequest* req_copy = nullptr;
+    if (request->dma_ok()) {
+      req_copy = new RecvTensorRequest;
+      *req_copy = *request;
+      req_copy->set_dma_ok(false);
+    }
+    // Type-specialized logging for this method.
+    StatusCallback logging_callback = [this, request, req_copy, response, done,
+                                       start_usec](Status s) {
+      if (logger_->LoggingActive()) {
+        int64 end_usec = Env::Default()->NowMicros();
+        int64 step_id = request->step_id();
+        int64 bytes = response->tensor().ByteSize();
+        int64 send_start_usec = start_usec;
+        // If a send start time was reported by the other side, use
+        // that instead.  Maybe we should mark the display if we're using
+        // our local time instead of the remote start time?
+        if (response->send_start_micros()) {
+          // send_start_micros is the timestamp taken when the remote
+          // machine began to send the RecvTensor response.
+          // Due to clock skew between source and dest machines, it is
+          // possible that send_start_micros can be larger than end_usec or
+          // less than start_usec.
+          // To respect causality, we enforce the invariants that the RecvTensor
+          // response can not have been sent before the RecvTensor request, and
+          // must have been sent before it was received.
+          send_start_usec = std::max(start_usec, response->send_start_micros());
+          send_start_usec = std::min(send_start_usec, end_usec - 1);
+        }
+        // The key is expected to have five ';'-separated parts; anything
+        // else is logged and not recorded.
+        const string& key = request->rendezvous_key();
+        std::vector<string> key_parts = str_util::Split(key, ';');
+        if (key_parts.size() != 5) {
+          LOG(WARNING) << "Bad key: " << key;
+        } else {
+          logger_->RecordRecvTensor(step_id, send_start_usec, end_usec,
+                                    key_parts[3],  // tensor name
+                                    key_parts[0],  // src_device
+                                    key_parts[2],  // dst_device
+                                    bytes);
+        }
+      }
+      VLOG(2) << "done callback, req: " << request->DebugString()
+              << " response " << response->DebugString();
+      // Safe when no copy was made: deleting nullptr is a no-op.
+      delete req_copy;
+      done(s);
+    };
+
+    IssueRequest(req_copy ? req_copy : request, response,
+                 &grpc::WorkerService::Stub::AsyncRecvTensor, logging_callback,
+                 call_opts);
+  }
+
+  void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
+                    StatusCallback done) override {
+    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncLogging,
+                 done);
+  }
+
+  void TracingAsync(const TracingRequest* request, TracingResponse* response,
+                    StatusCallback done) override {
+    IssueRequest(request, response, &grpc::WorkerService::Stub::AsyncTracing,
+                 done);
+  }
+
+ private:
+  // Pointer-to-member type of the stub's Async<Method> functions.
+  template <class RequestMessage, class ResponseMessage>
+  using AsyncMethod =
+      std::unique_ptr<::grpc::ClientAsyncResponseReader<ResponseMessage>> (
+          grpc::WorkerService::Stub::*)(::grpc::ClientContext*,
+                                        const RequestMessage&,
+                                        ::grpc::CompletionQueue*);
+
+  // Utility method for issuing a generic asynchronous request. The
+  // given callback, `done`, will be called when the RPC completes.
+  //
+  // The ClientContext is heap-allocated and handed to the GrpcClientCQTag.
+  // When `call_opts` is given, its cancel callback calls TryCancel() on that
+  // context and is cleared again once the RPC completes. The released
+  // response reader is deleted in the completion callback.
+  // NOTE(review): the tag itself is not deleted here; presumably whoever
+  // drains `cq_` does so -- confirm.
+  template <class RequestMessage, class ResponseMessage>
+  void IssueRequest(const RequestMessage* request, ResponseMessage* response,
+                    AsyncMethod<RequestMessage, ResponseMessage> async_method,
+                    StatusCallback done, CallOptions* call_opts = nullptr) {
+    ::grpc::ClientContext* context = new ::grpc::ClientContext;
+    if (call_opts) {
+      call_opts->SetCancelCallback([context]() { context->TryCancel(); });
+    }
+    auto rpc = (stub_.get()->*async_method)(context, *request, cq_).release();
+    GrpcClientCQTag* tag =
+        new GrpcClientCQTag(context, [rpc, done, call_opts](Status s) {
+          if (call_opts) {
+            call_opts->ClearCancelCallback();
+          }
+          delete rpc;
+          done(s);
+        });
+    rpc->Finish(response, tag->status(), tag);
+  }
+
+  std::unique_ptr<grpc::WorkerService::Stub> stub_;
+  ::grpc::CompletionQueue* cq_;  // Not owned.
+
+  // Support for logging.
+  WorkerCacheLogger* logger_;  // Not owned.
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
+};
+
+// Creates a heap-allocated GrpcRemoteWorker that talks over `channel`,
+// completes its calls on `completion_queue`, and reports to `logger`.
+WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
+                                     ::grpc::CompletionQueue* completion_queue,
+                                     WorkerCacheLogger* logger) {
+  WorkerInterface* worker =
+      new GrpcRemoteWorker(channel, completion_queue, logger);
+  return worker;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
new file mode 100644
index 0000000000..dfb72bdde2
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
@@ -0,0 +1,38 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#define THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+
+#include <memory>
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+
+namespace grpc {
+class CompletionQueue;
+}
+
+namespace tensorflow {
+
+class WorkerCacheLogger;
+class WorkerInterface;
+
+// Returns a newly allocated WorkerInterface that talks to a remote worker
+// over `channel`, using `completion_queue` for its asynchronous calls and
+// `logger` for logging support.
+// NOTE(review): the result is a raw pointer to a heap object; presumably the
+// caller takes ownership -- confirm at the call sites.
+WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
+ ::grpc::CompletionQueue* completion_queue,
+ WorkerCacheLogger* logger);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
new file mode 100644
index 0000000000..ddac7fd2cd
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -0,0 +1,116 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+#include <memory>
+
+#include "external/grpc/include/grpc++/grpc++.h"
+#include "external/grpc/include/grpc++/security/credentials.h"
+#include "external/grpc/include/grpc++/server_builder.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/graph_mgr.h"
+#include "tensorflow/core/distributed_runtime/master_env.h"
+#include "tensorflow/core/distributed_runtime/master_session.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+namespace tensorflow {
+
+// Starts the master and worker gRPC services for the task described by
+// `options`, listening on the port that the channel spec assigns to this
+// task. Setup runs on a dedicated launcher thread.
+//
+// This function does not return: the Thread destructor waits for its thread
+// to terminate, and the service threads run their RPC loops forever.
+//
+// Dies with a CHECK failure if no local device could be created, if the
+// channel spec does not yield a "host:port" endpoint for this task, or if
+// the gRPC server cannot be started.
+void StartTensorFlowServer(const GrpcServerOptions& options) {
+  // NOTE: The Thread destructor waits until the thread terminates
+  // (i.e. forever in this case).
+  std::unique_ptr<Thread> launcher_thread(Env::Default()->StartThread(
+      ThreadOptions(), "TF_service_launcher", [options]() {
+        // Configure the MasterEnv and WorkerEnv, which provide service-global
+        // context for the master and worker services, respectively.
+
+        // The master and worker share the same worker cache (for RPC
+        // connections to other workers) and devices (so that the master
+        // may run some ops locally as a "client" device). The master
+        // requires a device to serve as a "client device", so that remote
+        // devices can copy the feeds from the master.
+        MasterEnv master_env;
+        WorkerEnv worker_env;
+        master_env.env = Env::Default();
+        worker_env.env = Env::Default();
+
+        // Configure shared devices between master and worker.
+        string name_prefix =
+            strings::StrCat("/job:", options.job_name, "/replica:0", "/task:",
+                            options.task_index);
+        DeviceFactory::AddDevices(options.default_session_options, name_prefix,
+                                  &master_env.local_devices);
+        // The worker name is derived from the first local device, so at
+        // least one device must exist.
+        CHECK(!master_env.local_devices.empty())
+            << "No local devices were created for " << name_prefix;
+        worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
+        string unused;
+        CHECK(DeviceNameUtils::SplitDeviceName(
+            master_env.local_devices[0]->name(), &worker_env.worker_name,
+            &unused));
+
+        GrpcChannelCache* channel_cache =
+            NewGrpcChannelCache(options.channel_spec);
+        int port;
+        const std::vector<string> host_port =
+            str_util::Split(channel_cache->TranslateTask(name_prefix), ':');
+        // Guard the host_port[1] access below against a task that is missing
+        // from the channel spec or an endpoint without a port.
+        CHECK(host_port.size() >= 2)
+            << "Could not find a host:port endpoint for task " << name_prefix;
+        CHECK(str_util::NumericParse32(host_port[1], &port));
+
+        worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);
+
+        // Finish setting up master environment.
+        master_env.ops = OpRegistry::Global();
+        master_env.worker_cache = worker_env.worker_cache;
+        master_env.master_session_factory = internal::NewMasterSession;
+
+        // Finish setting up worker environment.
+        worker_env.graph_mgr = new GraphMgr(&worker_env);
+        worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
+        worker_env.compute_pool = ComputePool(options.default_session_options);
+
+        // Build the gRPC server that will host both the master and the
+        // worker services.
+        ::grpc::ServerBuilder builder;
+        builder.AddListeningPort(strings::StrCat("0.0.0.0:", port),
+                                 ::grpc::InsecureServerCredentials());
+        auto master_service = NewGrpcMasterService(&master_env, &builder);
+        auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
+        auto server = builder.BuildAndStart();
+        // BuildAndStart() yields a null server on failure (e.g. the port is
+        // already in use); fail loudly instead of serving nothing.
+        CHECK(server != nullptr)
+            << "Could not start gRPC server on port " << port;
+
+        // Start threads to handle the incoming RPCs for the master and worker.
+        // NOTE(mrry): The Thread destructor waits until the thread terminates
+        // (i.e. forever in this case).
+        std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
+            ThreadOptions(), "TF_master_service",
+            [master_service]() { master_service->HandleRPCsLoop(); }));
+        std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
+            ThreadOptions(), "TF_worker_service",
+            [worker_service]() { worker_service->HandleRPCsLoop(); }));
+      }));
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
new file mode 100644
index 0000000000..59abb31a15
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -0,0 +1,53 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+// Defines the configuration for a single task (typically a process)
+// that is part of a gRPC-based TensorFlow cluster.
+struct GrpcServerOptions {
+ // The identity of the job to which this task belongs. The names
+ // of the devices in this task will be prefixed with
+ // "/job:<job_name>/replica:0/task:<task_index>"
+ string job_name;
+ int32 task_index = 0;
+
+ // A channel specification, which defines (i) the set of jobs that
+ // comprise the cluster, and (ii) within each job, the endpoints
+ // exposed by each task. NOTE: This spec also defines the endpoint
+ // on which this task will listen.
+ GrpcChannelSpec channel_spec;
+
+ // SessionOptions that will be used as defaults when configuring
+ // sessions in this task. `default_session_options.target` is
+ // ignored.
+ SessionOptions default_session_options;
+};
+
+// Starts a gRPC-based TensorFlow server with the given options. The server
+// hosts both the master and the worker service for this task.
+// This function will not return.
+void StartTensorFlowServer(const GrpcServerOptions& options);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SERVER_LIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
new file mode 100644
index 0000000000..6924fc5537
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -0,0 +1,233 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/common_runtime/session_factory.h"
+#include "tensorflow/core/distributed_runtime/master_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_master.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+
+namespace tensorflow {
+
+// Length of the "grpc://" scheme prefix, not counting the terminating NUL.
+const size_t kSchemePrefix = sizeof("grpc://") - 1;
+
+// Creates the master stub over a channel to the address obtained by dropping
+// the first kSchemePrefix characters of `options.target`. The graph version
+// starts at -1 until a session has been created.
+// NOTE(review): assumes `options.target` starts with "grpc://"; the only
+// check visible in this file is GrpcSessionFactory::AcceptsOptions().
+GrpcSession::GrpcSession(const SessionOptions& options)
+ : options_(options),
+ master_(NewGrpcMaster(
+ NewHostPortGrpcChannel(options.target.substr(kSchemePrefix)))),
+ current_graph_version_(-1) {}
+
+// Does not send CloseSession; callers close the session via Close().
+GrpcSession::~GrpcSession() {}
+
+namespace {
+// Rewrites the "value" attr of Const nodes so that the tensor is carried in
+// tensor_content rather than in repeated proto fields, which is slightly
+// better (less copies and lower peak memory usage) when used with rpc
+// subsystems. Nodes that already use tensor_content, or whose tensor proto
+// is at most 64 bytes, are left alone.
+void ReEncodeConsts(GraphDef* gdef) {
+  for (NodeDef& node : *gdef->mutable_node()) {
+    if (node.op() != "Const") continue;
+
+    auto* attrs = node.mutable_attr();
+    auto value_it = attrs->find("value");
+    if (value_it == attrs->end()) continue;
+    TensorProto* tensor_proto = value_it->second.mutable_tensor();
+
+    // Only a constant that is encoded with repeated proto fields and is
+    // moderately large gets re-encoded in tensor_content. This is mildly
+    // helpful for reducing the peak memory usage on the server side, where
+    // GraphDef/NodeDef are copied quite often.
+    if (!tensor_proto->tensor_content().empty() ||
+        tensor_proto->ByteSize() <= 64) {
+      continue;
+    }
+
+    Tensor parsed(tensor_proto->dtype());
+    if (parsed.FromProto(*tensor_proto)) {
+      parsed.AsProtoTensorContent(tensor_proto);
+    }
+  }
+}
+} // namespace
+
+// Asks the master to create a session for `graph`, using the config from the
+// session options. Returns InvalidArgument if this object already holds a
+// session handle; otherwise returns the status of the CreateSession call and,
+// on success, records the handle and graph version from the response.
+Status GrpcSession::Create(const GraphDef& graph) {
+  if (!handle_.empty()) {
+    return errors::InvalidArgument("A session is alive.");
+  }
+
+  CreateSessionRequest request;
+  *request.mutable_config() = options_.config;
+  GraphDef* request_graph = request.mutable_graph_def();
+  *request_graph = graph;
+  ReEncodeConsts(request_graph);
+
+  CreateSessionResponse response;
+  const Status status = master_->CreateSession(&request, &response);
+  if (!status.ok()) {
+    return status;
+  }
+
+  // Adopt the handle and graph version reported by the master.
+  mutex_lock l(mu_);
+  handle_.swap(*response.mutable_session_handle());
+  current_graph_version_ = response.graph_version();
+  return status;
+}
+
+// Adds the nodes in `graph` to the session's graph. If no session exists
+// yet this is equivalent to Create(graph). Otherwise an ExtendSession
+// request carrying the current graph version is sent, and on success the
+// version from the response is recorded. `mu_` is held for the duration of
+// the ExtendSession call.
+Status GrpcSession::Extend(const GraphDef& graph) {
+ if (handle_.empty()) {
+ // Session was uninitialized, so simply initialize the session with 'graph'.
+ return Create(graph);
+ }
+ mutex_lock l(mu_);
+ ExtendSessionRequest req;
+ req.set_session_handle(handle_);
+ *req.mutable_graph_def() = graph;
+ req.set_current_graph_version(current_graph_version_);
+ ExtendSessionResponse resp;
+ Status s = master_->ExtendSession(&req, &resp);
+ if (s.ok()) {
+ current_graph_version_ = resp.new_graph_version();
+ }
+ return s;
+}
+
+// Runs one step on the master: feeds `inputs`, fetches `output_names` and
+// runs `target_nodes`.
+//
+// On success, when `output_names` is non-empty, `*outputs` is resized to
+// `output_names.size()` and `(*outputs)[i]` holds the tensor fetched for
+// `output_names[i]`. A name that appears more than once is requested from
+// the master only once and its tensor is copied into every matching slot.
+//
+// Returns:
+//   InvalidArgument if fetches are requested but `outputs` is null, if no
+//     session has been created, or if a returned tensor cannot be parsed;
+//   Internal if the master returns a tensor that was not requested;
+//   otherwise the status of the RunStep call.
+Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
+                        const std::vector<string>& output_names,
+                        const std::vector<string>& target_nodes,
+                        std::vector<Tensor>* outputs) {
+  // `outputs` is dereferenced whenever there are fetches, so reject a null
+  // pointer up front instead of crashing after the RPC has completed.
+  if (!output_names.empty() && outputs == nullptr) {
+    return errors::InvalidArgument(
+        "`outputs` must be non-null when fetches are requested.");
+  }
+
+  // Convert to proto
+  RunStepRequest req;
+  RunStepResponse resp;
+
+  for (const auto& it : inputs) {
+    Tensor input_tensor = it.second;
+    auto feed = req.add_feed();
+    feed->set_name(it.first);
+    TensorProto* proto = feed->mutable_tensor();
+    input_tensor.AsProtoTensorContent(proto);
+  }
+
+  // Build an index from fetch tensor name to the position of its first
+  // occurrence in `output_names`. The position (not the map size) is used
+  // so that duplicate names cannot shift the offsets of later fetches.
+  std::unordered_map<string, int> output_name_to_offset;
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    const string& output_name = output_names[i];
+    const bool first_occurrence =
+        output_name_to_offset
+            .insert(std::make_pair(output_name, static_cast<int>(i)))
+            .second;
+    if (first_occurrence) {
+      req.add_fetch(output_name);
+    }
+  }
+  for (const string& target : target_nodes) {
+    req.add_target(target);
+  }
+
+  TF_RETURN_IF_ERROR(RunProto(&req, &resp));
+
+  if (!output_names.empty()) {
+    outputs->resize(output_names.size());
+  }
+
+  // Convert response back to Tensors in the correct order.
+  for (const NamedTensorProto& tensor : resp.tensor()) {
+    auto fetch_it = output_name_to_offset.find(tensor.name());
+    if (fetch_it == output_name_to_offset.end()) {
+      return errors::Internal("Received response for unrequested fetch: ",
+                              tensor.name());
+    }
+
+    Tensor output;
+    if (!output.FromProto(tensor.tensor())) {
+      return errors::InvalidArgument("Could not parse returned proto for ",
+                                     tensor.name());
+    }
+
+    (*outputs)[fetch_it->second] = output;
+  }
+
+  // If `output_names` contained duplicates, fill the remaining slots from
+  // the slot of the first occurrence.
+  if (output_name_to_offset.size() != output_names.size()) {
+    for (size_t i = 0; i < output_names.size(); ++i) {
+      const size_t first =
+          static_cast<size_t>(output_name_to_offset[output_names[i]]);
+      if (first != i) {
+        (*outputs)[i] = (*outputs)[first];
+      }
+    }
+  }
+
+  return Status::OK();
+}
+
+// Stamps `req` with this session's handle and forwards it to the master's
+// RunStep. Returns InvalidArgument, without contacting the master, if no
+// session has been created.
+Status GrpcSession::RunProto(RunStepRequest* req, RunStepResponse* resp) {
+ if (handle_.empty()) {
+ return errors::InvalidArgument("A session is not created yet....");
+ }
+
+ req->set_session_handle(handle_);
+ return master_->RunStep(req, resp);
+}
+
+// Partial run is not implemented for remote sessions: always returns an
+// Internal error and leaves all arguments untouched.
+Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) {
+ return errors::Internal("Partial run is not supported for remote session.");
+}
+
+// Partial run is not implemented for remote sessions: always returns an
+// Internal error and leaves `outputs` untouched.
+Status GrpcSession::PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) {
+ return errors::Internal("Partial run is not supported for remote session.");
+}
+
+// Sends CloseSession for the current handle and returns the resulting
+// status. Returns InvalidArgument if no session has been created.
+Status GrpcSession::Close() {
+ if (handle_.empty()) {
+ return errors::InvalidArgument("A session is not created yet....");
+ }
+ CloseSessionRequest req;
+ req.set_session_handle(handle_);
+ // The local handle is cleared before the RPC is issued, so this object
+ // forgets the session even if CloseSession itself fails.
+ handle_.clear();
+ CloseSessionResponse resp;
+ return master_->CloseSession(&req, &resp);
+}
+
+// Queries the master for its devices and returns the local devices followed
+// by the remote devices. If the ListDevices call fails, the error is logged
+// and an empty vector is returned.
+std::vector<DeviceAttributes> GrpcSession::ListDevices() {
+  ListDevicesRequest request;
+  ListDevicesResponse response;
+  const Status status = master_->ListDevices(&request, &response);
+  if (!status.ok()) {
+    LOG(ERROR) << "Could not list devices: " << status;
+    return {};
+  }
+
+  std::vector<DeviceAttributes> result;
+  result.insert(result.end(), response.local_device().begin(),
+                response.local_device().end());
+  result.insert(result.end(), response.remote_device().begin(),
+                response.remote_device().end());
+  return result;
+}
+
+// SessionFactory that produces GrpcSession objects for targets using the
+// "grpc://" scheme.
+class GrpcSessionFactory : public SessionFactory {
+ public:
+ // Returns true iff `options.target` starts with "grpc://".
+ bool AcceptsOptions(const SessionOptions& options) override {
+ return StringPiece(options.target).starts_with("grpc://");
+ }
+
+ // Returns a newly allocated GrpcSession configured with `options`.
+ Session* NewSession(const SessionOptions& options) override {
+ return new GrpcSession(options);
+ }
+};
+
+// Registers a GrpcSessionFactory under the name "GRPC_SESSION" when
+// constructed.
+class GrpcSessionRegistrar {
+ public:
+ GrpcSessionRegistrar() {
+ SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory());
+ }
+};
+// File-local instance whose construction performs the registration during
+// static initialization.
+static GrpcSessionRegistrar registrar;
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
new file mode 100644
index 0000000000..9bc6034ba6
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -0,0 +1,97 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+class MasterInterface;
+
+// A Session instance lets the caller drive a TensorFlow graph
+// computation on potentially remote sets of devices. This is a thin
+// wrapper around tensorflow::grpc::MasterService.
+//
+// Multiple threads must synchronize their accesses to a single
+// session.
+class GrpcSession : public Session {
+ public:
+ // Do not use; just present for easier swig wrapping.
+ explicit GrpcSession(const SessionOptions& options);
+
+ ~GrpcSession() override;
+
+ // Creates a session with the "target". The session carries out
+ // the graph computation defined by "graph". The graph version is
+ // reported by the master and tracked in current_graph_version_.
+ Status Create(const GraphDef& graph) override;
+
+ // Runs one step: feeds "inputs", fetches "output_names" and runs
+ // "target_nodes". Fetched tensors are stored in "*outputs".
+ Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ std::vector<Tensor>* outputs) override;
+
+ // Adds "graph" to the session's graph (or creates the session if none
+ // exists yet).
+ Status Extend(const GraphDef& graph) override;
+ // Closes the session on the master and forgets its handle.
+ Status Close() override;
+
+ // NOTE: This API is still experimental and may change.
+ ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
+
+ // NOTE: This API is still experimental and may change.
+ ::tensorflow::Status PRun(
+ const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
+
+ // Returns the devices reported by the master.
+ std::vector<DeviceAttributes> ListDevices();
+
+ private:
+ // Options this session was constructed with.
+ SessionOptions options_;
+ // Interface to the master service that all calls are forwarded to.
+ std::unique_ptr<MasterInterface> master_;
+ mutex mu_;
+
+ // handle_ returned by the master to identify this session.
+ // NOTE(review): not annotated GUARDED_BY(mu_), unlike the graph version.
+ string handle_;
+
+ // The current version of the graph.
+ int64 current_graph_version_ GUARDED_BY(mu_);
+
+ // Sends "req" to the master as a RunStep call for this session.
+ Status RunProto(RunStepRequest* req, RunStepResponse* resp);
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcSession);
+};
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_SESSION_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
new file mode 100644
index 0000000000..86a9b07c2c
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -0,0 +1,750 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/graph/default_device.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/master.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/port.h"
+
+namespace tensorflow {
+
+// Returns SessionOptions whose config sets the "CPU" device count to
+// `num_cpus` and the "GPU" device count to `num_gpus`.
+static SessionOptions Devices(int num_cpus, int num_gpus) {
+  SessionOptions options;
+  auto* device_count = options.config.mutable_device_count();
+  (*device_count)["CPU"] = num_cpus;
+  (*device_count)["GPU"] = num_gpus;
+  return options;
+}
+
+// Fills `graph_def` with a three-node graph c = MatMul(a, b), where a is the
+// 1x2 constant [1, 2] and b is the 2x1 constant [2, 1]. The generated node
+// names are returned in `node_names` as {a, b, c}.
+void CreateGraphDef(GraphDef* graph_def, string node_names[3]) {
+ Graph graph(OpRegistry::Global());
+
+ Tensor a_tensor(DT_FLOAT, TensorShape({1, 2}));
+ test::FillValues<float>(&a_tensor, {1, 2});
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ node_names[0] = a->name();
+
+ Tensor b_tensor(DT_FLOAT, TensorShape({2, 1}));
+ test::FillValues<float>(&b_tensor, {2, 1});
+ Node* b = test::graph::Constant(&graph, b_tensor);
+ node_names[1] = b->name();
+
+ Node* c = test::graph::Matmul(&graph, a, b, false, false);
+ node_names[2] = c->name();
+
+ test::graph::ToGraphDef(&graph, graph_def);
+}
+
+// Asserts that "val" is a single float tensor. The only float is
+// "expected_val". The value is compared with exact equality.
+// NOTE(review): a failing ASSERT_* here returns from this helper only, not
+// from the calling test -- confirm that is the intended behavior.
+static void IsSingleFloatValue(const Tensor& val, float expected_val) {
+ ASSERT_EQ(val.dtype(), DT_FLOAT);
+ ASSERT_EQ(val.NumElements(), 1);
+ ASSERT_EQ(val.flat<float>()(0), expected_val);
+}
+
+// Returns SessionOptions that point at `target` (with the "grpc://" scheme
+// prepended) and set the config's placement period to `placement_period`.
+static SessionOptions Options(const string& target, int placement_period) {
+ SessionOptions options;
+ // NOTE(mrry): GrpcSession requires a grpc:// scheme prefix in the target
+ // string.
+ options.target = strings::StrCat("grpc://", target);
+ options.config.set_placement_period(placement_period);
+ return options;
+}
+
+// Creates a session for `options` via NewSession() and CHECK-fails if the
+// result is null. The caller receives the raw pointer.
+static Session* NewRemote(const SessionOptions& options) {
+ return CHECK_NOTNULL(NewSession(options));
+}
+
+// Repeatedly creates, runs and closes a session on the same GrpcSession
+// object, exercising Run() with no fetches, with only a target node, and
+// with both a fetch and a target node.
+TEST(GrpcSessionTest, BasicNonProtoAPI) {
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ for (int iters = 0; iters < 25; ++iters) {
+ TF_CHECK_OK(session->Create(graph));
+ {
+ // Run with no feeds, fetches or targets.
+ std::vector<std::pair<string, Tensor>> inputs;
+ TF_CHECK_OK(session->Run(inputs, {}, {}, {}));
+ }
+ {
+ // Just run to target node
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<string> targets = {node_names[2]};
+ TF_CHECK_OK(session->Run(inputs, {}, targets, nullptr));
+ }
+ {
+ // Run to a target node and a real tensor
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<string> names = {node_names[2] + ":0"};
+ std::vector<string> targets = {node_names[1]};
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(session->Run(inputs, names, targets, &outputs));
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ // [1 2] * [2 1]' = 4.
+ ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
+ }
+
+ TF_CHECK_OK(session->Close());
+ }
+}
+
+// Fetches c, a and b (in that order) in a single Run() and checks that the
+// returned tensors appear in the same order as the requested names.
+TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+ ASSERT_TRUE(session->Create(graph).ok());
+
+ // Test that the order of the output names matches the order of the
+ // returned Tensors.
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<string> names = {node_names[2] + ":0", node_names[0] + ":0",
+ node_names[1] + ":0"};
+
+ std::vector<string> target_ops = {node_names[1]};
+ std::vector<Tensor> outputs;
+ ASSERT_TRUE(session->Run(inputs, names, target_ops, &outputs).ok());
+ // outputs[0] is c; outputs[1] and outputs[2] are a and b, of which only
+ // the first element is checked.
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
+ ASSERT_TRUE(outputs[1].IsInitialized());
+ ASSERT_EQ(1.0, outputs[1].flat<float>()(0));
+ ASSERT_TRUE(outputs[2].IsInitialized());
+ ASSERT_EQ(2.0, outputs[2].flat<float>()(0));
+ ASSERT_TRUE(session->Close().ok());
+}
+
+// Creates a session whose config has a device filter naming only the first
+// cluster device, then checks that a graph placed on that device runs while
+// a graph placed on the second device is rejected with INVALID_ARGUMENT.
+TEST(GrpcSessionTest, NonLocalWithFilters) {
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ SessionOptions options;
+ options.target = strings::StrCat("grpc://", cluster->targets()[0]);
+ options.config.add_device_filters(cluster->devices()[0].name());
+
+ std::unique_ptr<Session> session(NewRemote(options));
+ ASSERT_TRUE(session != nullptr);
+
+ {
+ // Placement on the filtered-in device: expected to succeed.
+ GraphDef graph_copy(graph);
+ graph::SetDefaultDevice(cluster->devices()[0].name(), &graph_copy);
+ TF_CHECK_OK(session->Create(graph_copy));
+ TF_CHECK_OK(session->Run({}, {}, {}, nullptr));
+ TF_CHECK_OK(session->Close());
+ }
+ {
+ // Placement on a device outside the filter: Run is expected to fail.
+ GraphDef graph_copy(graph);
+ graph::SetDefaultDevice(cluster->devices()[1].name(), &graph_copy);
+ TF_CHECK_OK(session->Create(graph_copy));
+ auto status = session->Run({}, {}, {}, nullptr);
+ EXPECT_EQ(tensorflow::error::INVALID_ARGUMENT, status.code());
+ TF_CHECK_OK(session->Close());
+ }
+}
+
+// A = [3 2; -1 0]; x = rand(2, 1); We want to compute the largest
+// eigenvalue for A, which is 2.0. Iteratively, we do
+//   repeat x = y / y.norm(); y = A * x; end
+// At the end, we expect "lambda" converges to 2.0.
+//
+// `target` is passed to Options(), which prepends the "grpc://" scheme, and
+// identifies the master the session connects to.
+void FindMaxEigen(const string& target) {
+  Graph graph(OpRegistry::Global());
+
+  Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
+  // Store rows [3, 2] and [-1, 0] in row major format.
+  test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
+  Node* a = test::graph::Constant(&graph, a_tensor);
+
+  // x is a constant node in the graph; its value is replaced by the feed on
+  // every step.
+  Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
+  test::FillValues<float>(&x_tensor, {0, 0});
+  Node* x = test::graph::Constant(&graph, x_tensor);
+
+  // y = A * x
+  Node* y = test::graph::Matmul(&graph, a, x, false, false);
+
+  // y2 = y.^2
+  Node* y2 = test::graph::Unary(&graph, "Square", y);
+
+  // const tensor for reduction
+  Tensor rdim_tensor(DT_INT32, TensorShape({}));
+  rdim_tensor.scalar<int32>()() = 0;
+  Node* rdim = test::graph::Constant(&graph, rdim_tensor);
+
+  // y2_sum = sum(y2)
+  Node* y2_sum = test::graph::Reduce(&graph, "Sum", y2, rdim);
+
+  // y_norm = sqrt(y2_sum)
+  Node* y_norm = test::graph::Unary(&graph, "Sqrt", y2_sum);
+
+  // y_normalized = y ./ y_norm
+  Node* y_normalized = test::graph::Binary(&graph, "Div", y, y_norm);
+
+  GraphDef def;
+  test::graph::ToGraphDef(&graph, &def);
+
+  std::unique_ptr<Session> session(NewRemote(Options(target, 1)));
+  ASSERT_TRUE(session != nullptr);
+  TF_CHECK_OK(session->Create(def));
+
+  // Setup feeds and fetches. `lambda` is initialized so that the final
+  // EXPECT_NEAR never reads an indeterminate value.
+  float lambda = 0.0f;
+  Tensor feed_value(DT_FLOAT, TensorShape({2, 1}));
+  feed_value.matrix<float>()(0, 0) = -3.1415;
+  feed_value.matrix<float>()(1, 0) = +2.7183;
+
+  for (int i = 0; i < 25; ++i) {
+    std::vector<Tensor> outputs;
+    TF_CHECK_OK(session->Run({{x->name(), feed_value}},
+                             {y->name(), y_normalized->name()}, {}, &outputs));
+    // Distinct names keep the fetched tensors from shadowing the graph
+    // nodes `y` and `y_normalized`.
+    const Tensor& y_value = outputs[0];
+    const Tensor& y_normalized_value = outputs[1];
+    // Print out lambda, x, and y.
+    CHECK_EQ(2, feed_value.NumElements());
+    CHECK_EQ(2, y_value.NumElements());
+    lambda = y_value.flat<float>()(0) / feed_value.flat<float>()(0);
+    printf("%06d lambda = %8.6f x = [%8.6f %8.6f] y = [%8.6f %8.6f]\n", i,
+           lambda, feed_value.flat<float>()(0), feed_value.flat<float>()(1),
+           y_value.flat<float>()(0), y_value.flat<float>()(1));
+    // The normalized y becomes the x that is fed on the next step.
+    feed_value = y_normalized_value;
+  }
+  EXPECT_NEAR(2.0, lambda, 1e-6);
+}
+
+// Runs the power-iteration computation against the first target of a test
+// cluster.
+TEST(FindMaxEigenTest, RemoteDevice) {
+  std::unique_ptr<test::TestCluster> cluster;
+  // Check the status, as the other tests do, so that a failed cluster start
+  // is reported here instead of dereferencing a null `cluster` below.
+  TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+  FindMaxEigen(cluster->targets()[0]);
+}
+
+// Assigns device `dev` to the first node of `graph` whose name is `name`.
+// Aborts via LOG(FATAL) if no node has that name.
+void SetDevice(GraphDef* graph, const string& name, const string& dev) {
+  for (auto& node : *graph->mutable_node()) {
+    if (node.name() != name) continue;
+    node.set_device(dev);
+    return;
+  }
+  LOG(FATAL) << "Name '" << name << "' not found.";
+}
+
+// Multiplies a 1 x kSize vector by a kSize x 1 vector for every assignment
+// of the three nodes (a, b, c) to the cluster's devices, and checks the
+// result each time. A fresh session is created for every assignment.
+TEST(GrpcSessionTest, MultiDevices) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ Graph graph(OpRegistry::Global());
+ const int kSize = 1048576;
+
+ // c = a * b = 2 * 3 * kSize
+ Tensor a_tensor(DT_FLOAT, TensorShape({1, kSize}));
+ Tensor b_tensor(DT_FLOAT, TensorShape({kSize, 1}));
+ for (int i = 0; i < kSize; ++i) {
+ a_tensor.flat<float>()(i) = 2;
+ b_tensor.flat<float>()(i) = 3;
+ }
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ Node* b = test::graph::Constant(&graph, b_tensor);
+ Node* c = test::graph::Matmul(&graph, a, b, false, false);
+
+ GraphDef def;
+ test::graph::ToGraphDef(&graph, &def);
+
+ // In this test, we force each node (a, b, c) on every possible device.
+ // We test all possible cases.
+ for (const auto& a_dev : cluster->devices()) {
+ for (const auto& b_dev : cluster->devices()) {
+ for (const auto& c_dev : cluster->devices()) {
+ LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name()
+ << " c: " << c_dev.name();
+
+ // `def` is reused across iterations; only the device fields change.
+ SetDevice(&def, a->name(), a_dev.name());
+ SetDevice(&def, b->name(), b_dev.name());
+ SetDevice(&def, c->name(), c_dev.name());
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1000)));
+ ASSERT_TRUE(session != nullptr);
+ TF_CHECK_OK(session->Create(def));
+ {
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(session->Run({}, {c->name()}, {}, &outputs));
+ ASSERT_EQ(1, outputs.size());
+ IsSingleFloatValue(outputs[0], 6.0 * kSize);
+ }
+ TF_CHECK_OK(session->Close());
+ }
+ }
+ }
+}
+
+// Passes a 2x2 string tensor through an Identity node for every assignment
+// of the two nodes (a, b) to the cluster's devices. A failing Run() is
+// tolerated only when at least one of the two devices is a GPU.
+TEST(GrpcSessionTest, MultiDevices_String) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 1), 2, &cluster));
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1000)));
+ ASSERT_TRUE(session != nullptr);
+
+ // b = a
+ Graph graph(OpRegistry::Global());
+ Tensor a_tensor(DT_STRING, TensorShape({2, 2}));
+ for (int i = 0; i < 4; ++i) {
+ a_tensor.flat<string>()(i) = "hello, world";
+ }
+ Node* a = test::graph::Constant(&graph, a_tensor);
+ Node* b = test::graph::Identity(&graph, a);
+
+ GraphDef def;
+ test::graph::ToGraphDef(&graph, &def);
+
+ // In this test, we force each node (a, b) on every possible device.
+ // We test all possible cases.
+ for (const auto& a_dev : cluster->devices()) {
+ for (const auto& b_dev : cluster->devices()) {
+ LOG(INFO) << "a: " << a_dev.name() << " b: " << b_dev.name();
+ SetDevice(&def, a->name(), a_dev.name());
+ SetDevice(&def, b->name(), b_dev.name());
+
+ // The same session object is re-created and closed on each iteration.
+ TF_CHECK_OK(session->Create(def));
+ {
+ std::vector<Tensor> outputs;
+ Status s = session->Run({}, {b->name()}, {}, &outputs);
+ if (s.ok()) {
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_EQ(outputs[0].dtype(), DT_STRING);
+ ASSERT_EQ(outputs[0].NumElements(), 4);
+ for (int i = 0; i < outputs[0].NumElements(); ++i) {
+ EXPECT_EQ(outputs[0].flat<string>()(i), "hello, world");
+ }
+ } else {
+ // A failure is only acceptable if a GPU device was involved.
+ LOG(ERROR) << "Error: " << s;
+ ASSERT_TRUE((a_dev.device_type() == DEVICE_GPU) ||
+ (b_dev.device_type() == DEVICE_GPU));
+ ASSERT_FALSE(s.ok());
+ }
+ }
+ TF_CHECK_OK(session->Close());
+ }
+ }
+}
+
+ // Runs two subgraphs that share the same source node within one session:
+ // a -> b and a -> c, with a, b and c each pinned to a different device.
+ TEST(GrpcSessionTest, SendRecv_Node_Naming) {
+   std::unique_ptr<test::TestCluster> cluster;
+   TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 3, &cluster));
+   std::unique_ptr<Session> session(
+       NewRemote(Options(cluster->targets()[0], 1)));
+   ASSERT_TRUE(session != nullptr);
+
+   // This test case needs at least 3 devices.
+   const auto& all_devices = cluster->devices();
+   CHECK_GE(all_devices.size(), 3);
+   const DeviceAttributes& src = all_devices[0];
+   const DeviceAttributes& dst0 = all_devices[1];
+   const DeviceAttributes& dst1 = all_devices[2];
+   LOG(INFO) << "src = " << src.name() << " dst0 = " << dst0.name()
+             << " dst1 = " << dst1.name();
+
+   // a is a 1x1 float constant; b and c are both Identity(a).
+   Graph graph(OpRegistry::Global());
+   Tensor a_tensor(DT_FLOAT, TensorShape({1, 1}));
+   a_tensor.flat<float>()(0) = 100;
+   Node* a = test::graph::Constant(&graph, a_tensor);
+   Node* b = test::graph::Identity(&graph, a);
+   Node* c = test::graph::Identity(&graph, a);
+
+   GraphDef def;
+   test::graph::ToGraphDef(&graph, &def);
+
+   // Pin a to 'src', b to 'dst0' and c to 'dst1' in the base graph.
+   SetDevice(&def, a->name(), src.name());
+   SetDevice(&def, b->name(), dst0.name());
+   SetDevice(&def, c->name(), dst1.name());
+   TF_CHECK_OK(session->Create(def));
+
+   // First fetch b (subgraph a -> b), then fetch c (subgraph a -> c). Both
+   // must yield the value produced by a.
+   for (Node* fetch_node : {b, c}) {
+     std::vector<Tensor> fetched;
+     TF_CHECK_OK(session->Run({}, {fetch_node->name()}, {}, &fetched));
+     ASSERT_EQ(1, fetched.size());
+     IsSingleFloatValue(fetched[0], 100);
+   }
+
+   TF_CHECK_OK(session->Close());
+ }
+
+TEST(GrpcSessionTest, Error) {
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+ const string& master = cluster->targets()[0];
+ const string& dev_a = cluster->devices()[0].name();
+ const string& dev_b = cluster->devices()[1].name();
+ LOG(INFO) << "master " << master << "dev_a " << dev_a << "dev_b " << dev_b;
+ GraphDef gdef;
+ std::vector<string> fetches;
+ {
+ Graph g(OpRegistry::Global());
+
+ // a2 = a + error(a)
+ //
+ // Subgraph for "a" fails. The master will cancel the subgraph for
+ // "b" and then returns the Session::Run.
+ auto a = test::graph::Constant(&g, Tensor());
+ a->set_assigned_device_name(dev_a);
+ auto a_err = test::graph::Error(&g, a, "fantasia!");
+ a_err->set_assigned_device_name(dev_a);
+ auto a2 = test::graph::Add(&g, a, a_err);
+ a2->set_assigned_device_name(dev_a);
+ fetches.push_back(a2->name());
+
+ // b2 = b + delay(b)
+ //
+ // Subgraph for "b" sleeps at the node "b_delay". When the sleep
+ // finishes, the subgraph "b" will continue execution till it
+ // notices that it is cancelled. Meanwhile, subgraph's executor
+ // and its related state (registered ops) should still be alive.
+ auto b = test::graph::Constant(&g, Tensor());
+ b->set_assigned_device_name(dev_b);
+ auto b_delay = test::graph::Delay(&g, b, Microseconds(1000000));
+ b_delay->set_assigned_device_name(dev_b);
+ auto b2 = test::graph::Add(&g, b, b_delay);
+ b2->set_assigned_device_name(dev_b);
+ fetches.push_back(b2->name());
+ test::graph::ToGraphDef(&g, &gdef);
+ }
+ std::unique_ptr<Session> session(NewRemote(Options(master, 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ TF_CHECK_OK(session->Create(gdef));
+ {
+ Status status = session->Run({}, fetches, {}, nullptr);
+ EXPECT_FALSE(status.ok());
+ EXPECT_NE(status.ToString().find("fantasia!"), string::npos);
+ }
+ // session->Close() shall clean up all states related to the session->
+ // E.g., deregisters subgraph with workers, etc.
+ TF_CHECK_OK(session->Close());
+
+ // Sleep a bit so that most of asynchronous works finishes before
+ // the test process finishes.
+ Env::Default()->SleepForMicroseconds(2000000);
+}
+
+ // Checks that a variable's value is visible across independent sessions
+ // connected to the same master: one session initializes it, then pairs of
+ // fresh sessions increment it and read it back.
+ TEST(SessionTest, SharedVar) {
+   std::unique_ptr<test::TestCluster> cluster;
+   TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 1, &cluster));
+   const string master = cluster->targets()[0];
+   CHECK_EQ(cluster->devices().size(), 1);
+
+   GraphDef gdef;
+   string init_name;
+   string inc_name;
+   string get_name;
+   {
+     // var is a float scalar; init assigns 1 to it, update assigns var + 1.
+     Graph g(OpRegistry::Global());
+     Tensor one(DT_FLOAT, TensorShape({}));
+     one.scalar<float>()() = 1.0;
+     Node* var = test::graph::Var(&g, DT_FLOAT, one.shape());
+     Node* init = test::graph::Assign(&g, var, test::graph::Constant(&g, one));
+     init_name = init->name();
+     Node* update = test::graph::Assign(
+         &g, var, test::graph::Add(&g, var, test::graph::Constant(&g, one)));
+     inc_name = update->name();
+     get_name = var->name();
+     test::graph::ToGraphDef(&g, &gdef);
+   }
+
+   // Each session below is owned by a std::unique_ptr so that it is released
+   // even when a failing ASSERT_* returns from the test early.
+
+   // Init a variable
+   {
+     std::unique_ptr<Session> sess(NewRemote(Options(master, 1)));
+     ASSERT_TRUE(sess != nullptr);
+     TF_CHECK_OK(sess->Create(gdef));
+     std::vector<std::pair<string, Tensor>> inp;
+     TF_CHECK_OK(sess->Run(inp, {}, {init_name}, nullptr));
+     TF_CHECK_OK(sess->Close());
+   }
+
+   for (int rep = 1; rep < 10; ++rep) {
+     // Update a variable
+     {
+       std::unique_ptr<Session> sess(NewRemote(Options(master, 1)));
+       ASSERT_TRUE(sess != nullptr);
+       TF_CHECK_OK(sess->Create(gdef));
+       std::vector<std::pair<string, Tensor>> inp;
+       TF_CHECK_OK(sess->Run(inp, {}, {inc_name}, nullptr));
+       TF_CHECK_OK(sess->Close());
+     }
+
+     // Gets the variable's value. After `rep` increments of the initial 1
+     // it should be 1 + rep.
+     {
+       std::unique_ptr<Session> sess(NewRemote(Options(master, 1)));
+       ASSERT_TRUE(sess != nullptr);
+       TF_CHECK_OK(sess->Create(gdef));
+       std::vector<std::pair<string, Tensor>> inp;
+       std::vector<Tensor> ret;
+       TF_CHECK_OK(sess->Run(inp, {get_name}, {}, &ret));
+       ASSERT_EQ(ret.size(), 1);
+       EXPECT_EQ(ret[0].scalar<float>()(), 1.0 * (1 + rep));
+       TF_CHECK_OK(sess->Close());
+     }
+   }
+ }
+
+ // Parses `graph_def_ascii` as a text-format GraphDef, tries to create it in
+ // a session on a fresh two-process test cluster, and expects Create() to
+ // fail with an error message containing `error_substring`.
+ //
+ // CHECK-fails if the text cannot be parsed. Fatal gtest assertions in here
+ // return from this helper only, not from the calling test.
+ void CreateInvalidGraph(const string& graph_def_ascii,
+                         const string& error_substring) {
+   GraphDef graph;
+   CHECK(protobuf::TextFormat::ParseFromString(graph_def_ascii, &graph));
+
+   std::unique_ptr<test::TestCluster> cluster;
+   TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+   std::unique_ptr<Session> session(
+       NewRemote(Options(cluster->targets()[0], 1)));
+   // Guard the dereference below, as the other tests in this file do.
+   ASSERT_TRUE(session != nullptr);
+   Status s = session->Create(graph);
+
+   ASSERT_FALSE(s.ok());
+   EXPECT_NE(s.error_message().find(error_substring), string::npos);
+ }
+
+ // Each case below is a graph with a single Const node whose *name* is
+ // malformed; Create() is expected to fail with "Illegal op name".
+ TEST(SessionTest, InvalidOpName) {
+ // Case 1: node name 'a:b' (contains a colon followed by a non-numeric part).
+ CreateInvalidGraph(R"(
+ node {
+ name: 'a:b' op: 'Const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ )",
+ "Illegal op name");
+
+ // Case 2: node name 'a:0' (contains a colon followed by a number).
+ CreateInvalidGraph(R"(
+ node {
+ name: 'a:0' op: 'Const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ )",
+ "Illegal op name");
+
+ // Case 3: node name '_a' (starts with an underscore).
+ CreateInvalidGraph(R"(
+ node {
+ name: '_a' op: 'Const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ )",
+ "Illegal op name");
+ }
+
+ // Each case below is a two-node graph (a, and b = MatMul(<input>, a)) in
+ // which b's first *input* string is malformed; Create() is expected to fail
+ // with "Illegal op input name".
+ // NOTE(review): node a uses op 'const' (lower case) here, unlike 'Const' in
+ // the other tests -- confirm this is intentional.
+ TEST(SessionTest, InvalidOpInputName) {
+ // Case 1: input 'a:first' (non-numeric suffix after the colon).
+ CreateInvalidGraph(R"(
+ node {
+ name: 'a' op: 'const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ node {
+ name:'b' op:'MatMul' input:'a:first' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ "Illegal op input name");
+
+ // Case 2: input '_a' (starts with an underscore).
+ CreateInvalidGraph(R"(
+ node {
+ name: 'a' op: 'const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ node {
+ name:'b' op:'MatMul' input:'_a' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ "Illegal op input name");
+
+ // Case 3: input '_a:0' (starts with an underscore, with a numeric suffix).
+ CreateInvalidGraph(R"(
+ node {
+ name: 'a' op: 'const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ node {
+ name:'b' op:'MatMul' input:'_a:0' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ "Illegal op input name");
+
+ // Case 4: input 'a:01' (numeric suffix with a leading zero).
+ CreateInvalidGraph(R"(
+ node {
+ name: 'a' op: 'const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ node {
+ name:'b' op:'MatMul' input:'a:01' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ "Illegal op input name");
+ }
+
+ // Checks the validation done by Session::Extend() on a session that already
+ // holds a one-node graph: a malformed input name is rejected, a valid node
+ // is accepted, and re-adding that same node is rejected.
+ TEST(SessionTest, ExtendValidation) {
+ GraphDef graph;
+ bool success = protobuf::TextFormat::ParseFromString(R"(
+ node {
+ name: 'a' op: 'Const'
+ attr { key: 'dtype' value { type: DT_FLOAT } }
+ attr { key: 'value' value {
+ tensor { dtype: DT_FLOAT tensor_shape { dim [{size:1}, {size:1}] }
+ float_val: [100] }
+ } }
+ }
+ )",
+ &graph);
+ // NOTE(mrry): CHECK not done inline to avoid a compilation error in
+ // open-source (due to a multi-line string in a macro argument).
+ ASSERT_TRUE(success);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ TF_CHECK_OK(session->Create(graph));
+
+ // 1. Fail with an unknown input name.
+ GraphDef extension;
+ success = protobuf::TextFormat::ParseFromString(R"(
+ node {
+ name:'b' op:'MatMul' input:'a:first' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ &extension);
+ ASSERT_TRUE(success);
+
+ Status s = session->Extend(extension);
+ ASSERT_FALSE(s.ok());
+ EXPECT_NE(s.error_message().find("Illegal op input name"), string::npos);
+
+ // 2. Succeed with a valid node.
+ success = protobuf::TextFormat::ParseFromString(R"(
+ node {
+ name:'b' op:'MatMul' input:'a' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ &extension);
+ ASSERT_TRUE(success);
+ TF_CHECK_OK(session->Extend(extension));
+
+ // 3. Fail with a duplicate node (the same 'b' that step 2 just added).
+ success = protobuf::TextFormat::ParseFromString(R"(
+ node {
+ name:'b' op:'MatMul' input:'a' input:'a'
+ attr { key: 'T' value { type: DT_FLOAT } }
+ attr { key: 'transpose_a' value { b: false } }
+ attr { key: 'transpose_b' value { b: false } }
+ attr { key: '_kernel' value { s: 'eigen' } }
+ }
+ )",
+ &extension);
+ ASSERT_TRUE(success);
+ s = session->Extend(extension);
+ ASSERT_FALSE(s.ok());
+ EXPECT_NE(s.error_message().find("'b', which was created by a previous call"),
+ string::npos);
+ }
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
new file mode 100644
index 0000000000..51f24fdbf6
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server.cc
@@ -0,0 +1,98 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+
+#include "external/grpc/include/grpc++/grpc++.h"
+#include "external/grpc/include/grpc++/security/credentials.h"
+#include "external/grpc/include/grpc++/server_builder.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+// This binary starts a TensorFlow server (master and worker).
+namespace tensorflow {
+namespace {
+
+ // Parses the --cluster_spec, --job_name and --task_id flags into `*options`.
+ //
+ // --cluster_spec is a comma-separated list of jobs, each of the form
+ // <job_name>|<host:port>(;<host:port>)*. Every job is added to
+ // options->channel_spec.
+ //
+ // Returns InvalidArgument if the flags cannot be parsed, if a job entry is
+ // malformed, if options->job_name does not appear in the cluster spec, or
+ // if options->task_index is not a valid task of that job. Errors from
+ // AddHostPortsJob() are propagated unchanged.
+ Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) {
+   string cluster_spec;
+   const bool parse_result =
+       ParseFlags(&argc, argv, {Flag("cluster_spec", &cluster_spec),   //
+                                Flag("job_name", &options->job_name),  //
+                                Flag("task_id", &options->task_index)});
+   if (!parse_result) {
+     return errors::InvalidArgument("Error parsing command-line flags");
+   }
+
+   size_t my_num_tasks = 0;
+   for (const string& job_str : str_util::Split(cluster_spec, ',')) {
+     // job_str is of form <job_name>|<host_ports>, i.e. exactly two pieces
+     // separated by "|".
+     const std::vector<string> job_pieces = str_util::Split(job_str, '|');
+     if (job_pieces.size() != 2) {
+       // Malformed user input is reported to the caller rather than
+       // aborting the process.
+       return errors::InvalidArgument(
+           "Invalid job specification \"", job_str,
+           "\": expected <NAME>|<HOST:PORT>(;<HOST:PORT>)*");
+     }
+     const string& job = job_pieces[0];
+     const StringPiece spec = job_pieces[1];
+     // The number of tasks in a job is the number of ";"-separated
+     // host:port entries.
+     const std::vector<string> host_ports = str_util::Split(spec, ';');
+     size_t num_tasks = host_ports.size();
+     if (job == options->job_name) {
+       my_num_tasks = num_tasks;
+     }
+     TF_RETURN_IF_ERROR(
+         options->channel_spec.AddHostPortsJob(job, host_ports, num_tasks));
+     LOG(INFO) << "Peer " << job << " " << num_tasks << " {"
+               << str_util::Join(host_ports, ", ") << "}";
+   }
+   if (my_num_tasks == 0) {
+     return errors::InvalidArgument("Job name \"", options->job_name,
+                                    "\" does not appear in the cluster spec");
+   }
+   if (options->task_index >= my_num_tasks) {
+     return errors::InvalidArgument("Task index ", options->task_index,
+                                    " is invalid (job \"", options->job_name,
+                                    "\" contains ", my_num_tasks, " tasks)");
+   }
+   return Status::OK();
+ }
+
+} // namespace
+} // namespace tensorflow
+
+ // Entry point: parses the task flags and starts the TensorFlow server.
+ // Prints the error and a usage summary to stderr and returns -1 when the
+ // flags are invalid.
+ int main(int argc, char* argv[]) {
+   tensorflow::port::InitMain(argv[0], &argc, &argv);
+
+   tensorflow::GrpcServerOptions options;
+   const tensorflow::Status parse_status =
+       tensorflow::ParseFlagsForTask(argc, argv, &options);
+   if (parse_status.ok()) {
+     tensorflow::StartTensorFlowServer(options);
+     return 0;
+   }
+
+   std::cerr << "ERROR: " << parse_status.error_message() << std::endl;
+   std::cerr << "Usage: " << argv[0]
+             << " --cluster_spec=SPEC --job_name=NAME --task_id=ID"
+             << std::endl;
+   std::cerr << "Where:" << std::endl
+             << " SPEC is <JOB>(,<JOB>)*" << std::endl
+             << " JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*" << std::endl
+             << " NAME is a valid job name ([a-z][0-9a-z]*)" << std::endl
+             << " HOST is a hostname or IP address" << std::endl
+             << " PORT is a port number" << std::endl;
+   return -1;
+ }
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc
new file mode 100644
index 0000000000..ee5973c83a
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_tensorflow_server_lib.cc
@@ -0,0 +1,123 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+#include "external/grpc/include/grpc++/grpc++.h"
+#include "external/grpc/include/grpc++/security/credentials.h"
+#include "external/grpc/include/grpc++/server_builder.h"
+
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/distributed_runtime/graph_mgr.h"
+#include "tensorflow/core/distributed_runtime/master_env.h"
+#include "tensorflow/core/distributed_runtime/master_session.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_master_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
+#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+ // This library starts a TensorFlow server (master and worker) for test purposes.
+namespace tensorflow {
+
+ // Describes one server task: which job and task index it is, and how to
+ // reach its peers.
+ // NOTE(review): StartTensorFlowServer() in this file takes a `TaskOptions`,
+ // not a `GrpcTaskOptions` -- confirm which type name is intended.
+ struct GrpcTaskOptions {
+ // This process belongs to the "job_name".
+ string job_name;
+
+ // This process is the task-th task within the replica. 0th, 1st,
+ // 2nd, etc.
+ int32 task = 0;
+
+ // Specification of peers.
+ GrpcChannelSpec channel_spec;
+
+ // Session options for this task; presumably the defaults used when
+ // creating its devices and compute pool -- confirm against the caller.
+ SessionOptions default_session_options;
+ };
+
+ // Starts the master and worker gRPC services for the task described by
+ // `task_options`, both listening on the address that the channel cache
+ // assigns to this task.
+ //
+ // The setup runs as a closure on a one-thread pool. The pool is then
+ // deleted, and its destructor waits for that closure, so this call is not
+ // expected to return while the services are running. Returns OK if it does.
+ Status StartTensorFlowServer(const TaskOptions& task_options) {
+   thread::ThreadPool* thread_pool =
+       new thread::ThreadPool(Env::Default(), "server", 1);
+   // Only `task_options` is captured (by value); nothing else from the
+   // enclosing scope is needed by the closure.
+   thread_pool->Schedule([task_options]() {
+     // This process provides both the worker service and the master
+     // service. We let these two services share the same channel cache
+     // (rpc connections) and cpu devices (used by the master as the
+     // client device). These client devices require a worker service
+     // so that remote devices can copy the feeds from the client
+     // device in the master.
+     tensorflow::MasterEnv master_env;
+     string name_prefix =
+         strings::StrCat("/job:", task_options.job_name, "/replica:0", "/task:",
+                         task_options.task);
+     DeviceFactory::AddDevices(task_options.default_session_options, name_prefix,
+                               &master_env.local_devices);
+
+     // Create the DeviceMgr before initializing the RPC layer, because that
+     // needs to know how many devices of each kind exist.
+     WorkerEnv worker_env;
+     worker_env.device_mgr = new DeviceMgr(master_env.local_devices);
+
+     // Finish setting up Env for Worker service. The worker name is taken
+     // from the first local device's name; the remainder is not needed.
+     string donotcare;
+     CHECK(DeviceNameUtils::SplitDeviceName(master_env.local_devices[0]->name(),
+                                            &worker_env.worker_name,
+                                            &donotcare));
+     worker_env.env = Env::Default();
+
+     GrpcChannelCache* channel_cache =
+         NewGrpcChannelCache(task_options.channel_spec);
+     string server_address = channel_cache->TranslateTask(name_prefix);
+     worker_env.worker_cache = NewGrpcWorkerCache(channel_cache);
+     worker_env.graph_mgr = new GraphMgr(&worker_env);
+     worker_env.rendezvous_mgr = new RpcRendezvousMgr(&worker_env);
+     worker_env.compute_pool = ComputePool(task_options.default_session_options);
+
+     // Finish setting up Env for Master service.
+     master_env.env = Env::Default();
+     master_env.ops = OpRegistry::Global();
+     master_env.worker_cache = worker_env.worker_cache;
+     master_env.master_session_factory = internal::NewMasterSession;
+
+     ::grpc::ServerBuilder builder;
+     builder.AddListeningPort(server_address,
+                              ::grpc::InsecureServerCredentials());
+     auto master_service = NewGrpcMasterService(&master_env, &builder);
+     auto worker_service = NewGrpcWorkerService(&worker_env, &builder);
+     // Finally assemble the server. The handle is held for the rest of the
+     // closure so that the server stays alive while RPCs are handled.
+     auto server = builder.BuildAndStart();
+
+     // Each service runs its RPC loop on its own thread.
+     // NOTE(review): presumably the Thread destructors join these loops at
+     // the end of the closure, keeping it alive -- confirm.
+     std::unique_ptr<Thread> master_thread(Env::Default()->StartThread(
+         ThreadOptions(), "master_service_thread",
+         [master_service]() { master_service->HandleRPCsLoop(); }));
+
+     std::unique_ptr<Thread> worker_thread(Env::Default()->StartThread(
+         ThreadOptions(), "worker_service_thread",
+         [worker_service]() { worker_service->HandleRPCsLoop(); }));
+   });
+
+   // The ThreadPool destructor waits until all work is done (i.e. forever).
+   delete thread_pool;
+   return Status::OK();
+ }
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
new file mode 100644
index 0000000000..85b3dae56f
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.cc
@@ -0,0 +1,84 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_testlib.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_session.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace test {
+
+ // Starts `n` grpc_testlib_server subprocesses on freshly picked localhost
+ // ports, all belonging to the single job "localhost", then queries the
+ // first one for the list of devices in the cluster.
+ //
+ // The CPU and GPU counts passed to each server come from
+ // options.config.device_count(), defaulting to 1 CPU and 0 GPUs.
+ //
+ // On success stores the cluster in *out_cluster and returns OK. Returns
+ // Internal if a subprocess cannot be started. CHECK-fails if n < 1.
+ Status TestCluster::MakeTestCluster(const SessionOptions& options, int n,
+                                     std::unique_ptr<TestCluster>* out_cluster) {
+   CHECK_GE(n, 1);
+   std::unique_ptr<TestCluster> ret(new TestCluster);
+
+   ret->targets_.resize(n);
+
+   // One "localhost:<port>" target per server process.
+   std::vector<int> port(n);
+   for (int i = 0; i < n; ++i) {
+     port[i] = testing::PickUnusedPortOrDie();
+     ret->targets_[i] = strings::StrCat("localhost:", port[i]);
+   }
+
+   // Every server is given the full list of targets as the one job.
+   const string tf_jobs = strings::StrCat("--tf_jobs=localhost|",
+                                          str_util::Join(ret->targets_, ";"));
+
+   int num_cpus = 1;
+   int num_gpus = 0;
+   auto iter = options.config.device_count().find("CPU");
+   if (iter != options.config.device_count().end()) {
+     num_cpus = iter->second;
+   }
+   iter = options.config.device_count().find("GPU");
+   if (iter != options.config.device_count().end()) {
+     num_gpus = iter->second;
+   }
+
+   for (int i = 0; i < n; ++i) {
+     const std::vector<string> argv(
+         {strings::StrCat(testing::TensorFlowSrcRoot(),
+                          "/core/distributed_runtime/rpc/grpc_testlib_server"),
+          /* see grpc_testlib_server.cc for flags */
+          tf_jobs, "--tf_job=localhost", strings::StrCat("--tf_task=", i),
+          strings::StrCat("--num_cpus=", num_cpus),
+          strings::StrCat("--num_gpus=", num_gpus)});
+     ret->subprocesses_.emplace_back(testing::CreateSubProcess(argv));
+     bool success = ret->subprocesses_[i]->Start();
+     if (!success) {
+       // `ret` is destroyed on return, which kills the subprocesses that
+       // were already started.
+       return errors::Internal("Could not start subprocess");
+     }
+   }
+
+   // Ask the first server for the devices available in the cluster.
+   SessionOptions options_copy(options);
+   options_copy.target = strings::StrCat("grpc://", ret->targets_[0]);
+   std::unique_ptr<GrpcSession> session(new GrpcSession(options_copy));
+   ret->devices_ = session->ListDevices();
+
+   *out_cluster = std::move(ret);
+   return Status::OK();
+ }
+
+ // Kills every server subprocess owned by this cluster by calling Kill(9)
+ // on each (presumably signal 9, SIGKILL).
+ TestCluster::~TestCluster() {
+   for (size_t i = 0; i < subprocesses_.size(); ++i) {
+     subprocesses_[i]->Kill(9);
+   }
+ }
+
+} // end namespace test
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
new file mode 100644
index 0000000000..7460c1c9b4
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
@@ -0,0 +1,73 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+
+class Device;
+
+namespace test {
+
+// Provides a handle to a set of TensorFlow servers (masters and
+// workers) for testing purposes.
+//
+// This class currently runs the servers in separate processes; the
+// lifetime of this object is coterminous with the lifetimes of those
+// processes.
+ class TestCluster {
+ public:
+ // Creates a new test cluster based on the given `options` (which
+ // configure the number of devices of each type) and a count of
+ // processes `n`. On success, the test cluster is stored in
+ // *out_cluster, and this function returns OK. Otherwise an error is
+ // returned.
+ static Status MakeTestCluster(const SessionOptions& options, int n,
+ std::unique_ptr<TestCluster>* out_cluster);
+ // Ends the lifetime of the server processes (see the class comment).
+ ~TestCluster();
+
+ // Returns a vector of string "<hostname>:<port>" pairs that may be
+ // used as targets to construct a GrpcSession.
+ const std::vector<string>& targets() const { return targets_; }
+
+ // Returns a vector of devices available in this test cluster.
+ const std::vector<DeviceAttributes>& devices() const { return devices_; }
+
+ private:
+ // Instances are only created through MakeTestCluster().
+ TestCluster() = default;
+
+ // Handles to the server processes belonging to this cluster.
+ std::vector<std::unique_ptr<testing::SubProcess>> subprocesses_;
+ // Backing storage for targets().
+ std::vector<string> targets_;
+ // Backing storage for devices().
+ std::vector<DeviceAttributes> devices_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(TestCluster);
+ };
+
+} // end namespace test
+} // end namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc
new file mode 100644
index 0000000000..e2518f8fce
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_ops.cc
@@ -0,0 +1,91 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+namespace test {
+
+ // ErrorOp::Compute returns an error.
+ //
+ // The "Error" op takes one input and declares one output of the same type
+ // T. Its CPU kernel never produces the output: it sets the context status
+ // to an Internal error carrying the "message" attribute.
+ REGISTER_OP("Error")
+ .Input("in: T")
+ .Output("out: T")
+ .Attr("T: type")
+ .Attr("message: string");
+ class ErrorOp : public OpKernel {
+ public:
+ // Reads the "message" attribute; construction fails if it is missing.
+ explicit ErrorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("message", &errmsg_));
+ }
+
+ // Always fails with errors::Internal(<message>).
+ void Compute(OpKernelContext* ctx) override {
+ ctx->SetStatus(errors::Internal(errmsg_));
+ }
+
+ private:
+ // Value of the "message" attribute, used as the error text.
+ string errmsg_;
+ };
+ REGISTER_KERNEL_BUILDER(Name("Error").Device(DEVICE_CPU), ErrorOp);
+
+ // The "InvalidRefType" op declares a single output of type Ref(TIn), but
+ // its CPU kernel outputs a reference to a scalar tensor of type TOut.
+ // NOTE(review): presumably used to exercise handling of a ref output whose
+ // actual dtype differs from the declared one (TIn != TOut) -- confirm.
+ REGISTER_OP("InvalidRefType")
+ .Output("out: Ref(TIn)")
+ .Attr("TIn: type")
+ .Attr("TOut: type");
+ class InvalidRefType : public OpKernel {
+ public:
+ // Reads the "TOut" attribute and allocates the scalar tensor of that type
+ // which Compute() hands out.
+ explicit InvalidRefType(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("TOut", &dtout_));
+ output_ = Tensor(dtout_, TensorShape({}));
+ }
+
+ // Sets output 0 to a reference to output_, associated with mu_.
+ void Compute(OpKernelContext* ctx) override {
+ ctx->set_output_ref(0, &mu_, &output_);
+ }
+
+ private:
+ // Value of the "TOut" attribute: the dtype of output_.
+ DataType dtout_;
+ // Mutex passed along with the output reference.
+ mutex mu_;
+ // The tensor exposed as the op's ref output.
+ Tensor output_;
+ };
+ REGISTER_KERNEL_BUILDER(Name("InvalidRefType").Device(DEVICE_CPU),
+ InvalidRefType);
+
+ // DelayOp::ComputeAsync forwards its input to its output and schedules the
+ // completion callback via SchedClosureAfter() using the "micros" attribute
+ // (a delay in microseconds, going by the attribute name).
+ REGISTER_OP("Delay")
+ .Input("in: T")
+ .Output("out: T")
+ .Attr("T: type")
+ .Attr("micros: int");
+ class DelayOp : public AsyncOpKernel {
+ public:
+ // Reads the "micros" attribute; construction fails if it is missing.
+ explicit DelayOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("micros", &micros_));
+ }
+
+ // Sets the output immediately; `done` is deferred rather than called
+ // inline.
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ ctx->set_output(0, ctx->input(0));
+ ctx->env()->SchedClosureAfter(micros_, done);
+ }
+
+ private:
+ // Value of the "micros" attribute.
+ int64 micros_;
+ };
+ REGISTER_KERNEL_BUILDER(Name("Delay").Device(DEVICE_CPU), DelayOp);
+
+} // namespace test
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
new file mode 100644
index 0000000000..62c88daa17
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib_server.cc
@@ -0,0 +1,92 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "external/grpc/include/grpc++/grpc++.h"
+#include "external/grpc/include/grpc++/security/credentials.h"
+#include "external/grpc/include/grpc++/server_builder.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/command_line_flags.h"
+
+// This binary starts a TensorFlow server (master and worker) for test purposes.
+namespace tensorflow {
+namespace {
+
+// Populates "*options" from the command-line flags of a test server task.
+//
+// Recognized flags:
+//   tf_jobs:  comma-separated list of <job_name>|<host_ports> entries, where
+//             <host_ports> is a ';'-separated list with one entry per task.
+//   tf_job:   name of the job this task belongs to (options->job_name).
+//   tf_task:  index of this task within its job (options->task_index).
+//   num_cpus, num_gpus: device counts stored in the default session config.
+//
+// Returns InvalidArgument if the flags cannot be parsed, if an entry of
+// tf_jobs does not have exactly two '|'-separated pieces, or if tf_job does
+// not name a job in tf_jobs that has at least one task. Errors reported by
+// channel_spec.AddHostPortsJob() are propagated unchanged.
+Status ParseFlagsForTask(int argc, char* argv[], GrpcServerOptions* options) {
+  string job_spec;
+  int num_cpus = 1;
+  int num_gpus = 0;
+  const bool parse_result =
+      ParseFlags(&argc, argv, {Flag("tf_jobs", &job_spec),             //
+                               Flag("tf_job", &options->job_name),     //
+                               Flag("tf_task", &options->task_index),  //
+                               Flag("num_cpus", &num_cpus),            //
+                               Flag("num_gpus", &num_gpus)});
+  if (!parse_result) {
+    return errors::InvalidArgument("Error parsing command-line flags");
+  }
+
+  // Task count of the job named by tf_job; stays 0 if that job never appears
+  // in tf_jobs.
+  uint32 my_tasks_per_replica = 0;
+  for (const string& job_str : str_util::Split(job_spec, ',')) {
+    // job_str is of form <job_name>|<host_ports>: split it into its 2
+    // pieces, separated by "|".
+    const std::vector<string> job_pieces = str_util::Split(job_str, '|');
+    if (job_pieces.size() != 2) {
+      // Malformed user input is reported through the returned Status (which
+      // main() already logs) instead of aborting the process.
+      return errors::InvalidArgument("Invalid job string: ", job_str);
+    }
+    const string& job = job_pieces[0];
+    // <host_ports> holds one ';'-separated entry per task of the job.
+    const StringPiece spec = job_pieces[1];
+    const std::vector<string> host_ports = str_util::Split(spec, ';');
+    const uint32 tasks_per_replica = host_ports.size();
+    if (job == options->job_name) {
+      my_tasks_per_replica = tasks_per_replica;
+    }
+    TF_RETURN_IF_ERROR(options->channel_spec.AddHostPortsJob(
+        job, host_ports, tasks_per_replica));
+    LOG(INFO) << "Peer " << job << " " << tasks_per_replica << " {"
+              << str_util::Join(host_ports, ", ") << "}";
+  }
+  if (my_tasks_per_replica == 0) {
+    return errors::InvalidArgument("Invalid job specification");
+  }
+
+  (*options->default_session_options.config.mutable_device_count())["CPU"] =
+      num_cpus;
+  (*options->default_session_options.config.mutable_device_count())["GPU"] =
+      num_gpus;
+  return Status::OK();
+}
+
+} // namespace
+} // namespace tensorflow
+
+// Entry point of the test server binary: parses the task flags and, when they
+// are valid, starts the TensorFlow server. Returns -1 on a flag error.
+int main(int argc, char* argv[]) {
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  tensorflow::GrpcServerOptions server_options;
+  const tensorflow::Status parse_status =
+      tensorflow::ParseFlagsForTask(argc, argv, &server_options);
+  if (parse_status.ok()) {
+    tensorflow::StartTensorFlowServer(server_options);
+    // NOTE(mrry): Unreachable code.
+    return 0;
+  }
+  LOG(ERROR) << "Could not parse flags: " << parse_status.error_message();
+  return -1;
+}
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_util.h b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
new file mode 100644
index 0000000000..fc4b699e2a
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_util.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
+
+#include <memory>
+
+#include "external/grpc/include/grpc++/grpc++.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Converts a gRPC status into a TensorFlow status. A non-OK status keeps its
+// message; its code is cast numerically to tensorflow::error::Code.
+inline Status FromGrpcStatus(const ::grpc::Status& s) {
+  if (!s.ok()) {
+    const auto code = static_cast<tensorflow::error::Code>(s.error_code());
+    return Status(code, s.error_message());
+  }
+  return Status::OK();
+}
+
+// Converts a TensorFlow status into a gRPC status. A non-OK status keeps its
+// message; its code is cast numerically to ::grpc::StatusCode.
+inline ::grpc::Status ToGrpcStatus(const ::tensorflow::Status& s) {
+  if (!s.ok()) {
+    const auto code = static_cast<::grpc::StatusCode>(s.code());
+    return ::grpc::Status(code, s.error_message());
+  }
+  return ::grpc::Status::OK;
+}
+
+typedef std::shared_ptr<::grpc::Channel> SharedGrpcChannelPtr;
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_UTIL_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
new file mode 100644
index 0000000000..8658b2f31e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -0,0 +1,85 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_client_cq_tag.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
+#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Worker cache that hands out gRPC-backed remote workers. Every worker it
+// creates shares one completion queue, which is drained by a dedicated
+// polling thread owned by the cache.
+class GrpcWorkerCache : public WorkerCachePartial {
+ public:
+  // Takes ownership of "channel_cache" and starts the polling thread.
+  explicit GrpcWorkerCache(GrpcChannelCache* channel_cache)
+      : channel_cache_(channel_cache) {
+    // TODO(mrry): Investigate possible performance improvements by
+    // replacing this thread with a threadpool.
+    polling_thread_ = Env::Default()->StartThread(
+        ThreadOptions(), "grpc_worker_cache", [this]() {
+          void* tag;
+          bool ok;
+          // Each tag is a heap-allocated GrpcClientCQTag: run its completion
+          // handler, then free it.
+          while (completion_queue_.Next(&tag, &ok)) {
+            GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag);
+            callback_tag->OnCompleted(ok);
+            delete callback_tag;
+          }
+        });
+  }
+
+  // Explicit destructor to control destruction order: the queue is shut down
+  // first so that the polling loop can finish, then the thread is joined, and
+  // only then is the channel cache released.
+  ~GrpcWorkerCache() override {
+    completion_queue_.Shutdown();
+    delete polling_thread_;  // Blocks until thread exits.
+    delete channel_cache_;
+  }
+
+  // Appends/reports the workers known to the underlying channel cache.
+  void ListWorkers(std::vector<string>* workers) override {
+    channel_cache_->ListWorkers(workers);
+  }
+
+  // Returns a new remote worker for "target", or nullptr if the channel cache
+  // has no channel for it. Callers are expected to handle the nullptr case
+  // (e.g. by reporting an unknown-worker error) rather than have the whole
+  // process abort.
+  WorkerInterface* CreateWorker(const string& target) override {
+    SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
+    if (!channel) return nullptr;
+    return NewGrpcRemoteWorker(channel, &completion_queue_, &logger_);
+  }
+
+  void SetLogging(bool v) override { logger_.SetLogging(v); }
+
+  void ClearLogs() override { logger_.ClearLogs(); }
+
+  bool RetrieveLogs(int64 step_id, StepStats* ss) override {
+    return logger_.RetrieveLogs(step_id, ss);
+  }
+
+ private:
+  GrpcChannelCache* channel_cache_;  // Owned.
+  ::grpc::CompletionQueue completion_queue_;
+  Thread* polling_thread_;  // Owned.
+  WorkerCacheLogger logger_;
+};
+
+// Factory for the gRPC worker cache. The returned cache owns "cc".
+WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc) {
+  WorkerCacheInterface* cache = new GrpcWorkerCache(cc);
+  return cache;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h
new file mode 100644
index 0000000000..9332d38922
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h
@@ -0,0 +1,28 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+
+namespace tensorflow {
+
+// The returned WorkerCacheInterface object takes the ownership of "cc".
+WorkerCacheInterface* NewGrpcWorkerCache(GrpcChannelCache* cc);
+
+} // namespace tensorflow
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_CACHE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
new file mode 100644
index 0000000000..ed69f4beb9
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -0,0 +1,415 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
+
+#include <deque>
+
+#include "external/grpc/include/grpc++/server_builder.h"
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
+#include "tensorflow/core/common_runtime/gpu/process_state.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/common_runtime/local_device.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/distributed_runtime/graph_mgr.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/protobuf/worker_service.grpc.pb.h"
+#include "tensorflow/core/protobuf/worker_service.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+static Tensor empty_tensor(DT_FLOAT);
+
+class GrpcWorkerService : public AsyncServiceInterface {
+ public:
+  // Registers the async worker service with "builder" and takes ownership of
+  // a new server completion queue obtained from it. "env" is not owned.
+  GrpcWorkerService(WorkerEnv* env, ::grpc::ServerBuilder* builder)
+      : env_(env), cancellation_manager_(new CancellationManager) {
+    builder->RegisterService(&worker_service_);
+    auto queue = builder->AddCompletionQueue();
+    cq_ = queue.release();
+  }
+
+  // Releases everything the service owns: the completion queue and the
+  // service-wide cancellation manager allocated in the constructor (which was
+  // previously leaked).
+  ~GrpcWorkerService() {
+    delete cq_;
+    delete cancellation_manager_;
+  }
+
+// This macro creates a new request for the given RPC method name
+// (e.g., `ENQUEUE_REQUEST(GetStatus);`), and enqueues it on
+// `this->cq_`.
+//
+// This macro is invoked one or more times for each RPC method to
+// ensure that there are sufficient completion queue entries to
+// handle incoming requests without blocking.
+//
+// The implementation of the request handler for each RPC method
+// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
+// to keep accepting new requests.
+#define ENQUEUE_REQUEST(method) \
+ do { \
+ Call<GrpcWorkerService, grpc::WorkerService::AsyncService, \
+ method##Request, method##Response>:: \
+ EnqueueRequest(&worker_service_, cq_, \
+ &grpc::WorkerService::AsyncService::Request##method, \
+ &GrpcWorkerService::method##Handler); \
+ } while (0)
+
+  // Blocks forever handling requests from the completion queue: first seeds
+  // the queue with one pending request per RPC method, then dispatches every
+  // completed tag until Next() reports that the queue is finished.
+  void HandleRPCsLoop() {
+    // TODO(mrry): This may require performance engineering. We can
+    // add more threads to service the completion queue, and add more
+    // of various request types if they are short and frequent.
+    // Currently we allow unbounded numbers of pending calls for each
+    // method, by re-enqueuing a request before the previous one
+    // completes, and we may decide to bound some of the request
+    // types.
+    ENQUEUE_REQUEST(GetStatus);
+    ENQUEUE_REQUEST(CleanupAll);
+    ENQUEUE_REQUEST(RegisterGraph);
+    ENQUEUE_REQUEST(DeregisterGraph);
+
+    // TODO(mrry): Consider enqueuing more of these request types.
+    ENQUEUE_REQUEST(RecvTensor);
+    ENQUEUE_REQUEST(RunGraph);
+
+    ENQUEUE_REQUEST(CleanupGraph);
+    ENQUEUE_REQUEST(Logging);
+    ENQUEUE_REQUEST(Tracing);
+
+    using CallTag = UntypedCall<GrpcWorkerService>::Tag;
+    for (;;) {
+      void* raw_tag = nullptr;
+      bool ok = false;
+      if (!cq_->Next(&raw_tag, &ok)) break;
+      // Tags are heap-allocated; run the completion step, then free the tag.
+      CallTag* completed = static_cast<CallTag*>(raw_tag);
+      completed->OnCompleted(this, ok);
+      delete completed;
+    }
+  }
+
+ private:
+ WorkerEnv* env_; // Not owned.
+ ::grpc::ServerCompletionQueue* cq_; // Owned.
+
+ grpc::WorkerService::AsyncService worker_service_;
+
+ mutex mu_;
+ CancellationManager* cancellation_manager_ GUARDED_BY(mu_);
+
+ // The following section contains one request handler method per
+ // RPC. The `FooHandler` method is called (indirectly) by
+ // `HandleRPCsLoop()` when the next Foo RPC is received. Each
+ // `FooHandler` call schedules a closure on `env_->compute_pool`,
+ // and is responsible for requesting the next Foo call by calling
+ // `ENQUEUE_REQUEST(Foo)`.
+
+ template <class RequestMessage, class ResponseMessage>
+ using WorkerCall = Call<GrpcWorkerService, grpc::WorkerService::AsyncService,
+ RequestMessage, ResponseMessage>;
+
+  // Replies with the attributes of every device listed by env_->device_mgr,
+  // then re-arms the queue for the next GetStatus RPC.
+  void GetStatusHandler(WorkerCall<GetStatusRequest, GetStatusResponse>* call) {
+    auto reply = [this, call]() {
+      std::vector<DeviceAttributes> attrs;
+      env_->device_mgr->ListDeviceAttributes(&attrs);
+      call->response.mutable_device_attributes()->Reserve(attrs.size());
+      for (DeviceAttributes& attr : attrs) {
+        // Swap rather than copy: "attrs" is discarded right after.
+        call->response.add_device_attributes()->Swap(&attr);
+      }
+      call->SendResponse(::grpc::Status::OK);
+    };
+    env_->compute_pool->Schedule(reply);
+    ENQUEUE_REQUEST(GetStatus);
+  }
+
+  // Clears the containers named in the request on env_->device_mgr and
+  // replies OK, then re-arms the queue for the next CleanupAll RPC.
+  void CleanupAllHandler(
+      WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
+    auto cleanup = [this, call]() {
+      const auto& requested = call->request.container();
+      std::vector<string> containers(requested.begin(), requested.end());
+      env_->device_mgr->ClearContainers(containers);
+      call->SendResponse(::grpc::Status::OK);
+    };
+    env_->compute_pool->Schedule(cleanup);
+    ENQUEUE_REQUEST(CleanupAll);
+  }
+
+  // Registers the request's graph with env_->graph_mgr, writing the resulting
+  // handle into the response, and replies with the registration status.
+  void RegisterGraphHandler(
+      WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
+    auto register_graph = [this, call]() {
+      const RegisterGraphRequest& req = call->request;
+      const Status status = env_->graph_mgr->Register(
+          req.session_handle(), req.graph_def(), req.graph_options(),
+          call->response.mutable_graph_handle());
+      call->SendResponse(ToGrpcStatus(status));
+    };
+    env_->compute_pool->Schedule(register_graph);
+    ENQUEUE_REQUEST(RegisterGraph);
+  }
+
+  // Deregisters the graph named by the request's handle and replies with the
+  // status reported by env_->graph_mgr.
+  void DeregisterGraphHandler(
+      WorkerCall<DeregisterGraphRequest, DeregisterGraphResponse>* call) {
+    auto deregister_graph = [this, call]() {
+      const Status status =
+          env_->graph_mgr->Deregister(call->request.graph_handle());
+      call->SendResponse(ToGrpcStatus(status));
+    };
+    env_->compute_pool->Schedule(deregister_graph);
+    ENQUEUE_REQUEST(DeregisterGraph);
+  }
+
+  // Hands the call to DoRunGraph() on the compute pool, then re-arms the
+  // queue for the next RunGraph RPC.
+  void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
+    auto run = [this, call]() { DoRunGraph(call); };
+    env_->compute_pool->Schedule(run);
+    ENQUEUE_REQUEST(RunGraph);
+  }
+
+  // Hands the call to DoRecvTensor() on the compute pool, then re-arms the
+  // queue for the next RecvTensor RPC.
+  void RecvTensorHandler(
+      WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
+    auto recv = [this, call]() { DoRecvTensor(call); };
+    env_->compute_pool->Schedule(recv);
+    ENQUEUE_REQUEST(RecvTensor);
+  }
+
+  // Asks env_->rendezvous_mgr to clean up the request's step and replies OK,
+  // then re-arms the queue for the next CleanupGraph RPC.
+  void CleanupGraphHandler(
+      WorkerCall<CleanupGraphRequest, CleanupGraphResponse>* call) {
+    auto cleanup = [this, call]() {
+      env_->rendezvous_mgr->Cleanup(call->request.step_id());
+      call->SendResponse(::grpc::Status::OK);
+    };
+    env_->compute_pool->Schedule(cleanup);
+    ENQUEUE_REQUEST(CleanupGraph);
+  }
+
+  // Runs DoLogging() on the compute pool and replies with its status, then
+  // re-arms the queue for the next Logging RPC.
+  void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
+    auto log = [this, call]() {
+      const Status status = DoLogging(call);
+      call->SendResponse(ToGrpcStatus(status));
+    };
+    env_->compute_pool->Schedule(log);
+    ENQUEUE_REQUEST(Logging);
+  }
+
+  // Runs DoTracing() and replies with its status, then re-arms the queue for
+  // the next Tracing RPC. The closure is scheduled on env_->compute_pool,
+  // like every other handler in this class, rather than via the global
+  // SchedClosure().
+  void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
+    env_->compute_pool->Schedule([this, call]() {
+      Status s = DoTracing(call);
+      call->SendResponse(ToGrpcStatus(s));
+    });
+    ENQUEUE_REQUEST(Tracing);
+  }
+#undef ENQUEUE_REQUEST
+
+ private:
+ // The following section contains the implementation of RunGraph(),
+ // RecvTensor(), Logging(), and Tracing(), which are the four
+ // non-trivial and potentially long-running RPCs performed by a
+ // TensorFlow worker.
+
+  // Schedules a delayed abort of the rendezvous that env_->rendezvous_mgr
+  // returns for "step_id".
+  void AbortStep(int64 step_id) {
+    Rendezvous* step_rendezvous = env_->rendezvous_mgr->Find(step_id);
+    auto abort_step = [step_rendezvous, step_id]() {
+      step_rendezvous->StartAbort(errors::Aborted("Step ", step_id));
+      // Presumably releases the reference handed out by Find() - confirm
+      // against the rendezvous manager interface.
+      step_rendezvous->Unref();
+    };
+    // Delay a bit before aborting the step. This way, the root
+    // cause may return first back to the client instead of this
+    // cancellation generated abort error.
+    SchedNonBlockingClosureAfter(1000000, abort_step);
+  }
+
+  // Builds the input and output tensor maps for a RunGraph call.
+  //
+  // Every tensor in req.send() is materialized on device "CPU:0" and stored
+  // in "*in" under its key; every key in req.recv_key() is added to "*out"
+  // with a placeholder empty tensor. Returns the first error from the device
+  // lookup or from tensor parsing, otherwise OK.
+  Status PrepareRunGraph(const RunGraphRequest& req, GraphMgr::NamedTensors* in,
+                         GraphMgr::NamedTensors* out) {
+    if (req.send_size() > 0) {
+      // TODO(zhifengc): Let the caller decide on which device to
+      // allocate the tensor.
+      Device* host_device = nullptr;
+      TF_RETURN_IF_ERROR(
+          env_->device_mgr->LookupDevice("CPU:0", &host_device));
+      AllocatorAttributes default_attrs;
+      Tensor parsed;
+      for (const NamedTensor& sent : req.send()) {
+        TF_RETURN_IF_ERROR(host_device->MakeTensorFromProto(
+            sent.val(), default_attrs, &parsed));
+        in->insert({sent.key(), parsed});
+      }
+    }
+    for (const string& recv_key : req.recv_key()) {
+      out->insert({recv_key, empty_tensor});
+    }
+    return Status::OK();
+  }
+
+ // Executes one RunGraph RPC: prepares the input/output tensor maps, starts
+ // asynchronous execution of the registered graph, and sends the response
+ // from the completion callback. On a preparation error the response is
+ // sent immediately and nothing is executed.
+ void DoRunGraph(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
+ const int64 step_id = call->request.step_id();
+ TRACEPRINTF("RunGraph: %lld", step_id);
+ GraphMgr::NamedTensors in;
+ // "out" is heap-allocated because it must outlive this function: it is
+ // read and deleted by the ExecuteAsync callback below.
+ GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
+ Status s = PrepareRunGraph(call->request, &in, out);
+ if (!s.ok()) {
+ delete out;
+ call->SendResponse(ToGrpcStatus(s));
+ return;
+ }
+ StepStatsCollector* collector = nullptr;
+ // TODO(mrry): Collect results from a profiler if available.
+ // Per-step cancellation manager; deleted by the ExecuteAsync callback.
+ CancellationManager* cm = new CancellationManager;
+ // If the RPC itself is cancelled, cancel the step and abort its rendezvous.
+ call->SetCancelCallback([this, cm, step_id]() {
+ cm->StartCancel();
+ AbortStep(step_id);
+ });
+ CancellationToken token;
+ {
+ // Chain the per-step manager to the service-wide one, so cancelling
+ // cancellation_manager_ also cancels this step.
+ mutex_lock l(mu_);
+ token = cancellation_manager_->get_cancellation_token();
+ cancellation_manager_->RegisterCallback(token,
+ [cm]() { cm->StartCancel(); });
+ }
+ env_->graph_mgr->ExecuteAsync(
+ call->request.graph_handle(), step_id, call->request.exec_opts(),
+ collector, cm, in, out, [this, call, cm, out, token](Status s) {
+ // Detach both cancellation hooks before "cm" is destroyed.
+ call->ClearCancelCallback();
+ {
+ mutex_lock l(mu_);
+ cancellation_manager_->DeregisterCallback(token);
+ }
+ delete cm;
+
+ // On success, copy every produced tensor into the response.
+ if (s.ok()) {
+ for (const auto& p : *out) {
+ const string& key = p.first;
+ const Tensor& val = p.second;
+ auto* recv = call->response.add_recv();
+ recv->set_key(key);
+ // TODO(zhifengc): Deal with gpu -> cpu copy.
+ TensorProto* proto = recv->mutable_val();
+ val.AsProtoField(proto);
+ }
+ }
+ delete out;
+ call->SendResponse(ToGrpcStatus(s));
+ });
+ }
+
+  // Helper for RecvTensor. Parses and validates "key", looks up the device
+  // that hosts the tensor and stores it in "*src_dev". Returns Aborted if the
+  // device's incarnation differs from the one encoded in the key.
+  Status PrepareRecvTensor(const string& key, Device** src_dev) {
+    // Validate the key.
+    Rendezvous::ParsedKey parsed_key;
+    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed_key));
+
+    // Figures out which device the tensor is hosted on.
+    TF_RETURN_IF_ERROR(
+        env_->device_mgr->LookupDevice(parsed_key.src_device, src_dev));
+
+    // Does the device have the right incarnation number we expect?
+    const bool same_incarnation =
+        (*src_dev)->attributes().incarnation() == parsed_key.src_incarnation;
+    if (same_incarnation) {
+      return Status::OK();
+    }
+    return errors::Aborted(
+        "RecvTensor expects a different device incarnation: ",
+        parsed_key.src_incarnation, " vs. ",
+        (*src_dev)->attributes().incarnation(),
+        ". Your worker job was probably restarted. Check your "
+        "worker job for the reason why it was restarted.");
+  }
+
+ // Serves one RecvTensor RPC: validates the rendezvous key, waits
+ // asynchronously for the tensor to be produced locally, serializes it into
+ // the response (going through GPUUtil when it lives on a GPU) and sends the
+ // response. Any error is returned to the caller as the RPC status.
+ void DoRecvTensor(WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
+ const int64 step_id = call->request.step_id();
+ const string& key = call->request.rendezvous_key();
+ TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
+ Device* src_dev = nullptr;
+ Status s = PrepareRecvTensor(key, &src_dev);
+ if (!s.ok()) {
+ call->SendResponse(ToGrpcStatus(s));
+ return;
+ }
+
+ // Request the tensor associated with the rendezvous key. Any time
+ // while waiting for the tensor to be produced, up until the start
+ // of execution of the callback lambda body below, an RPC
+ // cancellation should abort the rendezvous.
+ call->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
+ env_->rendezvous_mgr->RecvLocalAsync(
+ step_id, key,
+ [this, call, src_dev](const Status& status,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& val, const bool is_dead) {
+ // From here on an RPC cancellation no longer aborts the rendezvous.
+ call->ClearCancelCallback();
+ Status s = status;
+ if (s.ok()) {
+ // DMA can only be used for Tensors that do not fall into
+ // the following three odd edge cases: 1) a zero-size
+ // buffer, 2) a dead tensor which has an uninit value, and
+ // 3) the tensor has the on_host allocation attribute,
+ // i.e. it's in CPU RAM *independent of its assigned
+ // device type*.
+ // const size_t bytes = is_dead ? 0 : val.TotalBytes();
+ const bool on_host = send_args.alloc_attrs.on_host();
+ const DeviceContext* send_dev_context = send_args.device_context;
+ call->response.set_is_dead(is_dead);
+ // Invoked once the tensor proto has been filled in (synchronously in
+ // the CPU branch, by GPUUtil in the GPU branch): stamps the send time
+ // and sends the response.
+ StatusCallback response_ready = [call](const Status& s) {
+ // The value is now ready to be returned on the wire.
+ call->response.set_send_start_micros(Env::Default()->NowMicros());
+ call->SendResponse(ToGrpcStatus(s));
+ };
+ {
+ // Non-DMA cases.
+ if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
+ CHECK(send_dev_context)
+ << "send dev name: " << src_dev->name()
+ << " gpu_info: " << src_dev->tensorflow_gpu_device_info();
+ // "val" is on a GPU. Uses GPUUtil to fill the response proto.
+ GPUUtil::SetProtoFromGPU(val, src_dev, send_dev_context,
+ call->response.mutable_tensor(),
+ is_dead, response_ready);
+ } else {
+ // "val" is in CPU memory.
+ TensorProto* proto = call->response.mutable_tensor();
+ val.AsProtoTensorContent(proto);
+ response_ready(Status::OK());
+ }
+ }
+ } else {
+ // !s.ok()
+ call->SendResponse(ToGrpcStatus(s));
+ }
+ });
+ }
+
+ // Handles a Logging RPC. Not implemented yet: ignores "call" and always
+ // returns an Unimplemented status.
+ Status DoLogging(WorkerCall<LoggingRequest, LoggingResponse>* call) {
+ // TODO(mrry): Platform-specific logging support.
+ return errors::Unimplemented("Logging");
+ }
+
+  // Handles a Tracing RPC. Not implemented yet: ignores "call" and always
+  // returns an Unimplemented status.
+  Status DoTracing(WorkerCall<TracingRequest, TracingResponse>* call) {
+    // TODO(mrry): Platform-specific tracing support.
+    const Status unimplemented = errors::Unimplemented("Tracing");
+    return unimplemented;
+  }
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
+};
+
+} // namespace
+
+// Factory for the gRPC worker service. The service registers itself with
+// "builder"; "env" is not owned by the returned object.
+AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env,
+                                            ::grpc::ServerBuilder* builder) {
+  AsyncServiceInterface* service = new GrpcWorkerService(env, builder);
+  return service;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
new file mode 100644
index 0000000000..4b46aed835
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
@@ -0,0 +1,34 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+
+namespace grpc {
+class ServerBuilder;
+} // namespace grpc
+
+namespace tensorflow {
+
+class AsyncServiceInterface;
+class WorkerEnv;
+
+// Returns an implementation of WorkerService rpc service.
+AsyncServiceInterface* NewGrpcWorkerService(WorkerEnv* env,
+ ::grpc::ServerBuilder* builder);
+
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
new file mode 100644
index 0000000000..6b69e37e15
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.cc
@@ -0,0 +1,196 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+
+#include <unordered_set>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Remote rendezvous whose cross-process receives are served by issuing a
+// RecvTensor call to the worker that hosts the source device (see
+// RecvFromRemoteAsync below).
+class RpcRemoteRendezvous : public BaseRemoteRendezvous {
+ public:
+ // NOTE(review): the meaning of the trailing "false" argument is defined by
+ // BaseRemoteRendezvous - confirm in base_rendezvous_mgr.h.
+ RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
+ : BaseRemoteRendezvous(env, step_id, false) {}
+
+ protected:
+ // Fetches the tensor identified by "key"/"parsed" from the remote worker
+ // and reports the outcome through "done".
+ void RecvFromRemoteAsync(const string& key,
+ const Rendezvous::ParsedKey& parsed,
+ const Rendezvous::Args& args,
+ DoneCallback done) override;
+
+ private:
+ // Private destructor: instances cannot be destroyed directly by outside
+ // code (presumably they are released via reference counting - confirm).
+ ~RpcRemoteRendezvous() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
+};
+
+// Used only to retrieve tensors from remote processes.
+// One in-flight RecvTensor RPC to a remote worker. Used only to retrieve
+// tensors from remote processes. The call owns the WorkerInterface it is
+// given and deletes it on destruction.
+class RpcRecvTensorCall : public BaseRecvTensorCall {
+ public:
+  RpcRecvTensorCall(WorkerCacheInterface* wc, WorkerInterface* wi,
+                    int64 step_id, const string& key,
+                    const string& remote_dev, Allocator* allocator,
+                    Device* dst_device)
+      : wi_(wi),
+        wc_(wc),
+        remote_dev_(remote_dev),
+        allocator_(allocator),
+        dst_(dst_device) {
+    // The request only carries the step and the rendezvous key.
+    req_.set_step_id(step_id);
+    req_.set_rendezvous_key(key);
+  }
+
+  ~RpcRecvTensorCall() override { delete wi_; }
+
+  // Issues the RecvTensor RPC. When it completes, the RPC status is folded
+  // into status() and "recv_done" is invoked.
+  void Start(std::function<void()> recv_done) override {
+    auto on_response = [this, recv_done](const Status& rpc_status) {
+      {
+        mutex_lock l(mu_);
+        status_.Update(rpc_status);
+      }
+      recv_done();
+    };
+    wi_->RecvTensorAsync(&opts_, &req_, &resp_,
+                         nullptr /* TensorBufAllocator */, on_response);
+  }
+
+  // Records "s" as the call's status and cancels the outstanding RPC. The
+  // cancellation is started after mu_ has been released.
+  void StartAbort(const Status& s) override {
+    {
+      mutex_lock l(mu_);
+      status_.Update(s);
+    }
+    opts_.StartCancel();
+  }
+
+  // Returns a snapshot of the current status.
+  Status status() const override {
+    mutex_lock l(mu_);
+    return status_;
+  }
+
+  // Accessors for the RPC response.
+  const TensorProto& tensor_proto() const { return resp_.tensor(); }
+
+  const RecvTensorResponse& response() const { return resp_; }
+
+  bool is_dead() const { return resp_.is_dead(); }
+
+ private:
+  WorkerInterface* wi_;       // Owned.
+  WorkerCacheInterface* wc_;  // Not owned.
+  string remote_dev_;
+  Allocator* allocator_;
+  Device* dst_;
+  CallOptions opts_;
+  RecvTensorRequest req_;
+  RecvTensorResponse resp_;
+
+  mutable mutex mu_;
+  Status status_ GUARDED_BY(mu_);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
+};
+
+
+// Receives the tensor identified by "key"/"parsed" from the worker hosting
+// parsed.src_device.
+//
+// Validates the source device name, obtains a worker for it from the worker
+// cache and looks up the local destination device. If any of these steps
+// fails, "done" is invoked immediately with the error and an empty tensor.
+// Otherwise a RecvTensor call is registered (so that it can be aborted) and
+// started; when it completes, the received proto is converted into a tensor
+// on the destination device and passed to "done".
+void RpcRemoteRendezvous::RecvFromRemoteAsync(
+    const string& key, const Rendezvous::ParsedKey& parsed,
+    const Rendezvous::Args& recv_args, DoneCallback done) {
+  Status s;
+
+  // key.src_device identifies a remote device.
+  string src_worker;
+  string src_rel_device;
+  if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_worker,
+                                        &src_rel_device)) {
+    s = errors::Internal(parsed.src_device,
+                         " is invalid remote source device.");
+  }
+  WorkerCacheInterface* worker_cache = env_->worker_cache;
+  if (s.ok() && worker_cache == nullptr) {
+    s = errors::Internal("No remote worker cache available.");
+  }
+  // Only create the worker once the device name and the cache are known to
+  // be valid; in particular, never dereference a null worker cache.
+  WorkerInterface* rwi = nullptr;
+  if (s.ok()) {
+    rwi = worker_cache->CreateWorker(src_worker);
+    if (rwi == nullptr) {
+      s = errors::Internal("No worker known as ", src_worker);
+    }
+  }
+
+  Device* dst_device = nullptr;
+  if (s.ok()) {
+    s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+  }
+  if (!s.ok()) {
+    // No call will take ownership of the worker, so release it here
+    // (deleting a null pointer is a no-op).
+    delete rwi;
+    done(s, Args(), recv_args, Tensor{}, false);
+    return;
+  }
+  Allocator* allocator = dst_device->GetAllocator(recv_args.alloc_attrs);
+
+  // Prepare a RecvTensor call that can handle being aborted. The call takes
+  // ownership of "rwi".
+  RpcRecvTensorCall* call =
+      new RpcRecvTensorCall(worker_cache, rwi, step_id_, key,
+                            parsed.src_device, allocator, dst_device);
+
+  // Record "call" in active_ so that it can be aborted cleanly.
+  RegisterCall(call);
+
+  // Start "call".
+  call->Start([this, call, parsed, recv_args, done]() {
+    // Removes "call" from active_. Prevent StartAbort().
+    DeregisterCall(call);
+    // If StartAbort was called prior to DeregisterCall, then the
+    // current status should be bad.
+    Status s = call->status();
+    Tensor val;
+    if (s.ok()) {
+      Device* dst_device = nullptr;
+      s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
+      if (s.ok()) {
+        s = dst_device->MakeTensorFromProto(call->tensor_proto(),
+                                            recv_args.alloc_attrs, &val);
+      }
+    }
+    done(s, Args(), recv_args, val, call->is_dead());
+    delete call;
+  });
+}
+
+} // namespace
+
+// Creates the RPC-backed rendezvous for step "step_id".
+BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
+                                               const WorkerEnv* worker_env) {
+  BaseRemoteRendezvous* rendezvous =
+      new RpcRemoteRendezvous(worker_env, step_id);
+  return rendezvous;
+}
+
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
new file mode 100644
index 0000000000..65b21b425c
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h
@@ -0,0 +1,57 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
+
+#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
+#include "tensorflow/core/distributed_runtime/worker_env.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// RendezvousMgr keeps track of a set of local rendezvous instances.
+// All tensors sent by this worker are buffered in a RendezvousMgr
+// until the tensor is received. Each global unique "step_id"
+// corresponds to one local rendezvous instance managed by a
+// RendezvousMgr.
+//
+// E.g.,
+// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
+// fork execution of a graph executor using "rendez" on thread 1;
+// fork execution of another graph executor using "rendez" on thread 2;
+// ...
+// join threads 1 and 2;
+//
+// In the example above, the executions on threads 1 and 2 communicate
+// with each other by send/recv operations through "rendez".
+//
+// Tensors sent and recved through rendezvous managed by this
+// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
+class RpcRendezvousMgr : public BaseRendezvousMgr {
+ public:
+ // Forwards "env" unchanged to the BaseRendezvousMgr constructor.
+ explicit RpcRendezvousMgr(const WorkerEnv* env) : BaseRendezvousMgr(env) {}
+
+ protected:
+ // Overrides the BaseRendezvousMgr factory method. The definition in
+ // rpc_rendezvous_mgr.cc returns a newly allocated rendezvous object
+ // constructed from "worker_env" and "step_id".
+ BaseRemoteRendezvous* Create(int64 step_id,
+ const WorkerEnv* worker_env) override;
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
new file mode 100644
index 0000000000..0f855e8f28
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -0,0 +1,172 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
+
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+// string -> Tensor<string>
+Tensor V(const string& content) {
+  // Wrap "content" in a scalar (empty-shape) DT_STRING tensor.
+  Tensor result(DT_STRING, TensorShape({}));
+  auto scalar = result.scalar<string>();
+  scalar() = content;
+  return result;
+}
+
+// Tensor<string> -> string
+string V(const Tensor& tensor) {
+  // Only scalar DT_STRING tensors are accepted; anything else fails a CHECK.
+  CHECK_EQ(tensor.dtype(), DT_STRING);
+  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+  const string& content = tensor.scalar<string>()();
+  return content;
+}
+
+// Sends a string tensor through the rendezvous for one step and reads it
+// back with RecvLocal(); both endpoints name devices of the same task.
+TEST(RpcRendezvousMgrTest, LocalSendRecv) {
+  WorkerEnv env;
+  env.env = Env::Default();
+  env.worker_name = "/job:mnist/replica:1/task:2";
+  RpcRendezvousMgr rmgr(&env);
+  const int64 step_id = 123;
+  const string key = Rendezvous::CreateKey(
+      "/job:mnist/replica:1/task:2/cpu:0", 7890,
+      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
+  {
+    // Producer side: publish "peach" under "key".
+    Rendezvous* rendez = rmgr.Find(step_id);
+    core::ScopedUnref unref(rendez);
+    Rendezvous::Args send_args;
+    TF_ASSERT_OK(rendez->Send(key, send_args, V("peach"), false));
+  }
+  {
+    // Consumer side: the placeholder tensor is replaced by the sent value.
+    Tensor received(DT_FLOAT);
+    bool is_dead = false;
+    TF_ASSERT_OK(rmgr.RecvLocal(step_id, key, &received, &is_dead));
+    EXPECT_EQ(V(received), "peach");
+  }
+  rmgr.Cleanup(step_id);
+}
+
+// Checks that a Recv() with no matching Send() returns an Aborted status
+// once the rendezvous is aborted, either explicitly via StartAbort() or
+// through RpcRendezvousMgr::Cleanup().
+TEST(RpcRendezvousMgrTest, LocalAbort) {
+ WorkerEnv env;
+ env.env = Env::Default();
+ env.worker_name = "/job:mnist/replica:1/task:2";
+ RpcRendezvousMgr rmgr(&env);
+ const string key = Rendezvous::CreateKey(
+ "/job:mnist/replica:1/task:2/cpu:0", 7890,
+ "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
+ { // Explicit Abort().
+ const int64 step_id = 123;
+ Rendezvous* rendez = rmgr.Find(step_id);
+ core::ScopedUnref unref(rendez);
+ // The closure sleeps 100ms before aborting, so the Recv() below is
+ // expected to be already waiting when the abort arrives. "env" is
+ // captured by value, "rendez" as a raw pointer.
+ SchedClosure([env, rendez]() {
+ env.env->SleepForMicroseconds(100 * 1000);
+ rendez->StartAbort(errors::Aborted(""));
+ });
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ // Nothing was sent for "key", so the only expected outcome is Aborted.
+ EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
+ }
+ { // Cleanup causes Abort().
+ const int64 step_id = 321;
+ Rendezvous* rendez = rmgr.Find(step_id);
+ core::ScopedUnref unref(rendez);
+ // Same pattern as above, but the abort is triggered indirectly by
+ // cleaning up the step; "rmgr" is captured by reference.
+ SchedClosure([env, &rmgr, step_id]() {
+ env.env->SleepForMicroseconds(100 * 1000);
+ rmgr.Cleanup(step_id);
+ });
+ Tensor val(DT_STRING);
+ bool val_dead = false;
+ Rendezvous::Args args;
+ EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
+ }
+}
+
+// After CleanupAll(), a Recv() on a rendezvous obtained earlier must report
+// Aborted even though a value had been sent under the same key.
+TEST(RpcRendezvousMgrTest, CleanupAll) {
+  WorkerEnv env;
+  env.env = Env::Default();
+  env.worker_name = "/job:mnist/replica:1/task:2";
+  RpcRendezvousMgr rmgr(&env);
+  const string key = Rendezvous::CreateKey(
+      "/job:mnist/replica:1/task:2/cpu:0", 7890,
+      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
+
+  const int64 step_id = 123;
+  Rendezvous* rendez = rmgr.Find(step_id);
+  core::ScopedUnref unref(rendez);
+  Rendezvous::Args args;
+  TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
+
+  rmgr.CleanupAll();
+
+  Tensor received(DT_STRING);
+  bool is_dead = false;
+  const Status recv_status = rendez->Recv(key, args, &received, &is_dead);
+  EXPECT_TRUE(errors::IsAborted(recv_status));
+}
+
+// Minimal DeviceContext used by the tests below: it only carries an integer
+// tag so a test can recognise the context it handed to Send().
+class DummyDeviceContext : public DeviceContext {
+ public:
+  explicit DummyDeviceContext(int stream_id) : stream_id_(stream_id) {}
+  ~DummyDeviceContext() override = default;
+
+  // Returns the tag supplied at construction time.
+  int stream_id() const { return stream_id_; }
+
+ private:
+  const int stream_id_;
+};
+
+// Verifies that the DeviceContext attached to the sender's Args is handed to
+// the RecvLocalAsync() callback together with the tensor value.
+TEST(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
+  DummyDeviceContext* dc = new DummyDeviceContext(123);
+
+  WorkerEnv env;
+  env.env = Env::Default();
+  env.worker_name = "/job:mnist/replica:1/task:2";
+  RpcRendezvousMgr rmgr(&env);
+  const int64 step_id = 123;
+  const string key = Rendezvous::CreateKey(
+      "/job:mnist/replica:1/task:2/cpu:0", 7890,
+      "/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0));
+  {
+    Rendezvous* rendez = rmgr.Find(step_id);
+    core::ScopedUnref unref(rendez);
+    Rendezvous::Args args;
+    args.device_context = dc;
+    TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
+  }
+  {
+    Notification n;
+    rmgr.RecvLocalAsync(
+        step_id, key,
+        [&n](const Status& s, const Rendezvous::Args& send_args,
+             const Rendezvous::Args& recv_args, const Tensor& val,
+             bool is_dead) {
+          // Fail with the real error first: on a bad status the send args
+          // may carry no device context, and the cast below would then be
+          // dereferenced as a null pointer.
+          TF_CHECK_OK(s);
+          auto send_dev_context =
+              static_cast<DummyDeviceContext*>(send_args.device_context);
+          CHECK_EQ(123, send_dev_context->stream_id());
+          CHECK_EQ(V(val), "peach");
+          n.Notify();
+        });
+    n.WaitForNotification();
+  }
+  rmgr.Cleanup(step_id);
+  // Drop the reference obtained from "new" above.
+  dc->Unref();
+}
+
+// NOTE: Remote Send/Recv is better tested in worker_test.cc
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc b/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc
new file mode 100644
index 0000000000..94714f4709
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc
@@ -0,0 +1,309 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/simple_graph_execution_state.h"
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/simple_placer.h"
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/dot.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+#include "tensorflow/core/util/util.h"
+
+namespace tensorflow {
+
+// Renders the three name lists as human-readable text. Every entry is
+// followed by ", " (including the last one) and the sections are separated
+// by newlines.
+string BuildGraphOptions::DebugString() const {
+  string text = "Feed endpoints: ";
+  for (const auto& feed : feed_endpoints) {
+    strings::StrAppend(&text, feed, ", ");
+  }
+
+  strings::StrAppend(&text, "\nFetch endpoints: ");
+  for (const auto& fetch : fetch_endpoints) {
+    strings::StrAppend(&text, fetch, ", ");
+  }
+
+  strings::StrAppend(&text, "\nTarget nodes: ");
+  for (const auto& target : target_nodes) {
+    strings::StrAppend(&text, target, ", ");
+  }
+  return text;
+}
+
+// Stores the op registry and the two pointers carried by "options"; nothing
+// is copied or built here. Both graphs start out null and are only created
+// later (see BuildGraph(), which initializes the base graph lazily).
+SimpleGraphExecutionState::SimpleGraphExecutionState(
+ const OpRegistryInterface* ops,
+ const SimpleGraphExecutionStateOptions& options)
+ : ops_(ops),
+ device_set_(options.device_set),
+ session_options_(options.session_options),
+ base_(nullptr),
+ placed_(nullptr) {
+ // TODO(mrry): Publish placement visualizations or handle the log
+ // placement option.
+}
+
+// Frees the two graphs owned by this object. Both pointers may still be
+// null if no graph was ever built; deleting a null pointer is a no-op.
+SimpleGraphExecutionState::~SimpleGraphExecutionState() {
+ mutex_lock l(mu_);
+ delete base_;
+ delete placed_;
+}
+
+// Installs "*graph_def" as the graph definition of this state object.
+// Returns InvalidArgument if a non-empty definition is already stored,
+// otherwise the status of AddDefaultAttrsToGraphDef().
+Status SimpleGraphExecutionState::Create(GraphDef* graph_def) {
+  // A non-empty stored GraphDef means Create() has already run.
+  const bool already_created = original_graph_def_.node_size() > 0;
+  if (already_created) {
+    return errors::InvalidArgument(
+        "Cannot call Create on SimpleGraphExecutionState twice");
+  }
+
+  // Swap instead of copy: the caller's GraphDef ends up holding whatever
+  // this object stored before the call.
+  original_graph_def_.Swap(graph_def);
+  VLOG(2) << "Incoming def: " << original_graph_def_.DebugString();
+
+  // The trailing 0 presumably is the index of the first node to process
+  // (Extend() passes the old node count there) -- TODO confirm.
+  return AddDefaultAttrsToGraphDef(&original_graph_def_, *ops_, 0);
+}
+
+// Builds a new execution state whose graph is this state's graph plus the
+// nodes in "extension_def".
+//
+// Returns InvalidArgument if a node name in "extension_def" already exists
+// in this graph, or the error from adding default attrs / creating the new
+// state. On success the caller owns "*out"; on failure "*out" is untouched.
+Status SimpleGraphExecutionState::Extend(
+    const GraphDef& extension_def, SimpleGraphExecutionState** out) const {
+  std::unordered_set<string> new_names;
+  // 1. Build an index of the new node names.
+  for (const NodeDef& node : extension_def.node()) {
+    new_names.insert(node.name());
+  }
+
+  // 2. Add the non-duplicates from the old graph to the new graph.
+  //    Return an error if the same node name appears in both the
+  //    old graph and the extension.
+  GraphDef gdef;
+  for (const NodeDef& node : original_graph_def_.node()) {
+    if (new_names.count(node.name()) == 0) {
+      *gdef.add_node() = node;
+    } else {
+      return errors::InvalidArgument(tensorflow::strings::Printf(
+          "GraphDef argument to Extend includes node '%s', which was created "
+          "by a previous call to Create or Extend in this session.",
+          node.name().c_str()));
+    }
+  }
+
+  // Only the appended nodes still need default attrs; the old nodes got
+  // theirs when they were first added.
+  int old_node_size = gdef.node_size();
+  gdef.mutable_node()->MergeFrom(extension_def.node());
+  TF_RETURN_IF_ERROR(AddDefaultAttrsToGraphDef(&gdef, *ops_, old_node_size));
+
+  // 3. Add the extension.
+  SimpleGraphExecutionStateOptions combined_options;
+  combined_options.device_set = device_set_;
+  // Carry the session options over as well; otherwise the extended state
+  // would hand a null SessionOptions to the placer and silently drop the
+  // settings this state was created with.
+  combined_options.session_options = session_options_;
+
+  SimpleGraphExecutionState* new_execution_state =
+      new SimpleGraphExecutionState(ops_, combined_options);
+  Status new_execution_state_status = new_execution_state->Create(&gdef);
+  if (!new_execution_state_status.ok()) {
+    delete new_execution_state;
+    return new_execution_state_status;
+  }
+  *out = new_execution_state;
+
+  // Ensure that any state created in the precursor is accessible in the
+  // new graph.
+  {
+    mutex_lock l(mu_);
+    for (const auto& placement : stateful_placements_) {
+      (*out)->stateful_placements_.insert(placement);
+    }
+  }
+
+  // TODO(mrry): This is likely to be used for non-throughput-sensitive
+  // interactive workloads, but in future we may want to transfer other
+  // parts of the placement and/or cost model.
+  return Status::OK();
+}
+
+// Converts the stored GraphDef into a Graph, fills the name -> cost-id map,
+// and runs the preliminary placement. On success the new graph becomes
+// base_; on placement failure the map is emptied again and base_ is left
+// unchanged.
+Status SimpleGraphExecutionState::InitBaseGraph() {
+  std::unique_ptr<Graph> graph(new Graph(ops_));
+  GraphConstructorOptions ctor_opts;
+  TF_RETURN_IF_ERROR(
+      ConvertGraphDefToGraph(ctor_opts, original_graph_def_, graph.get()));
+
+  // The map has to be populated before placement runs, because
+  // SimplePlacement() passes it to the placer.
+  for (const Node* node : graph->nodes()) {
+    VLOG(2) << "Mapping " << node->name() << " to " << node->cost_id();
+    node_name_to_cost_id_map_[node->name()] = node->cost_id();
+  }
+
+  const Status placement_status = PreliminaryPlace(*graph);
+  if (placement_status.ok()) {
+    base_ = graph.release();
+  } else {
+    node_name_to_cost_id_map_.clear();
+  }
+  return placement_status;
+}
+
+// Looks up "name" in the placed full graph and copies its NodeDef to "*out".
+// Returns NotFound if the name is unknown, if no node with the mapped id
+// exists, or if no placed graph has been built yet.
+Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
+                                                      NodeDef* out) {
+  // Hold mu_ for the whole lookup: node_name_to_cost_id_map_ and placed_
+  // are both written by InitBaseGraph()/PreliminaryPlace(), which run with
+  // mu_ held, so reading either of them unlocked would be a data race.
+  mutex_lock l(mu_);  // could use reader lock
+  // placed_ stays null until a placement has succeeded.
+  if (placed_ != nullptr) {
+    NodeNameToCostIdMap::const_iterator iter =
+        node_name_to_cost_id_map_.find(name);
+    if (iter != node_name_to_cost_id_map_.end()) {
+      const Node* node = placed_->FindNodeId(iter->second);
+      if (node) {
+        *out = node->def();
+        return Status::OK();
+      }
+    }
+  }
+  return errors::NotFound("Node name: ", name);
+}
+
+// Places a copy of "base". If placement succeeds the copy replaces placed_
+// (the previous graph is deleted) and the positions of its stateful nodes
+// are recorded; if it fails the copy is discarded and placed_ is untouched.
+Status SimpleGraphExecutionState::PreliminaryPlace(const Graph& base) {
+  VLOG(1) << "PreliminaryPlace";
+  std::unique_ptr<Graph> candidate(new Graph(ops_));
+  CopyGraph(base, candidate.get());
+
+  Status status = DoPlace(candidate.get());
+  if (status.ok()) {
+    delete placed_;
+    placed_ = candidate.release();
+    FreezeStatefulNodes(true /*is_prelim*/);
+  }
+  return status;
+}
+
+// Records the assigned device of stateful nodes in stateful_placements_.
+//   is_prelim == true : scan every node of placed_.
+//   is_prelim == false: only handle the nodes queued in
+//                       missing_stateful_placements_, then empty that queue.
+void SimpleGraphExecutionState::FreezeStatefulNodes(bool is_prelim) {
+  if (!is_prelim) {
+    // During later placements it's possible for new stateful nodes to
+    // appear.  They are noticed while we're pinning the pre-existing
+    // stateful nodes to their prior positions, and after they've been
+    // placed this function is entered to record their placements.
+    for (Node* node : missing_stateful_placements_) {
+      CHECK(node->op_def().is_stateful());
+      stateful_placements_[node->name()] = node->assigned_device_name();
+    }
+    missing_stateful_placements_.clear();
+    return;
+  }
+
+  // During the preliminary placement every stateful Node got placed
+  // somewhere, and we need to remember where, so it doesn't move.
+  for (Node* node : placed_->nodes()) {
+    if (!node->op_def().is_stateful()) continue;
+    stateful_placements_[node->name()] = node->assigned_device_name();
+  }
+}
+
+// Pins each stateful node of "graph" to the device recorded for its name in
+// stateful_placements_. Stateful nodes with no recorded device are queued in
+// missing_stateful_placements_ instead.
+void SimpleGraphExecutionState::PlaceStatefulNodes(Graph* graph) {
+  for (Node* node : graph->nodes()) {
+    if (!node->op_def().is_stateful()) continue;
+
+    PlaceMap::const_iterator known = stateful_placements_.find(node->name());
+    if (known != stateful_placements_.end()) {
+      node->set_assigned_device_name(known->second);
+    } else {
+      // NOTE(tucker): I don't understand why this can occur.  So far,
+      // I've only seen it in eval instances, started from a checkpoint.
+      missing_stateful_placements_.push_back(node);
+    }
+  }
+}
+
+// Assigns devices to the nodes of "graph" and returns the placer's status.
+// SimplePlacement() is currently the only strategy.
+Status SimpleGraphExecutionState::DoPlace(Graph* graph) {
+  // TODO(mrry): Port other placement algorithms from whitepaper.
+  return SimplePlacement(graph);
+}
+
+// Produces the ClientGraph selected by "options" from the placed full graph.
+// The base graph is built and placed on the first call. On success the
+// caller owns "*out"; on any error "*out" is not written.
+Status SimpleGraphExecutionState::BuildGraph(const BuildGraphOptions& options,
+                                             ClientGraph** out) {
+  VLOG(1) << "BuildGraph";
+  mutex_lock l(mu_);
+  // Lazily initialize the base graph.
+  if (base_ == nullptr) {
+    TF_RETURN_IF_ERROR(InitBaseGraph());
+  }
+
+  if (base_ == nullptr || placed_ == nullptr) {
+    return ::tensorflow::errors::Internal(
+        "There was a problem building the graph.");
+  }
+
+  // Work on a private copy so the placed full graph stays intact.
+  std::unique_ptr<ClientGraph> pruned(new ClientGraph(ops_));
+  CopyGraph(*placed_, &pruned->graph);
+
+  // Extract the subset of the graph that needs to be run, adding feed/fetch
+  // ops as needed.
+  const Status rewrite_status = subgraph::RewriteGraphForExecution(
+      &pruned->graph, options.feed_endpoints, options.fetch_endpoints,
+      options.target_nodes, device_set_->client_device()->attributes());
+  if (!rewrite_status.ok()) {
+    return rewrite_status;
+  }
+
+  // Copy the extracted graph in order to make its node ids dense,
+  // since the local CostModel used to record its stats is sized by
+  // the largest node id.
+  std::unique_ptr<ClientGraph> dense(new ClientGraph(ops_));
+  CopyGraph(pruned->graph, &dense->graph);
+
+  // TODO(vrv): We should check invariants of the graph here.
+
+  *out = dense.release();
+  return Status::OK();
+}
+
+// Returns OK when "n" carries no device specification, or when "device" is
+// one of the devices matching that specification. Returns InvalidArgument
+// (with the NodeDef attached) for an unparsable specification or when
+// "device" is not among the matches.
+Status SimpleGraphExecutionState::DeviceIsCompatible(
+    Node* n, const Device* device) const {
+  const string& requested = n->def().device();
+  if (requested.empty()) {
+    // No constraint on this node, so any device is acceptable.
+    return Status::OK();
+  }
+
+  DeviceNameUtils::ParsedName pname;
+  if (!DeviceNameUtils::ParseFullName(requested, &pname)) {
+    return AttachDef(
+        errors::InvalidArgument("Malformed device specification '", requested,
+                                "'"),
+        n->def());
+  }
+
+  std::vector<Device*> candidates;
+  device_set_->FindMatchingDevices(pname, &candidates);
+  for (const Device* candidate : candidates) {
+    if (candidate == device) {
+      return Status::OK();
+    }
+  }
+
+  return AttachDef(
+      errors::InvalidArgument(
+          "Specified device '", requested,
+          "' not compatible with device of ref connection: ", device->name()),
+      n->def());
+}
+
+// Runs a SimplePlacer over "graph" using this state's device set, name ->
+// cost-id map and session options, and returns the placer's status.
+Status SimpleGraphExecutionState::SimplePlacement(Graph* graph) {
+  // TODO(mrry): Consider making the SimplePlacer cancelable.
+  return SimplePlacer(graph, device_set_, &node_name_to_cost_id_map_,
+                      session_options_)
+      .Run();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/simple_graph_execution_state.h b/tensorflow/core/distributed_runtime/simple_graph_execution_state.h
new file mode 100644
index 0000000000..6d065437d8
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/simple_graph_execution_state.h
@@ -0,0 +1,156 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_set.h"
+#include "tensorflow/core/distributed_runtime/build_graph_options.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+class SessionOptions;
+class StepStats;
+class Timeline;
+
+// Construction-time inputs for SimpleGraphExecutionState. The constructor
+// stores both pointers as given; both default to null.
+struct SimpleGraphExecutionStateOptions {
+ // Devices available for placement.
+ const DeviceSet* device_set = nullptr;
+ // Session configuration that is passed on to the placer.
+ const SessionOptions* session_options = nullptr;
+};
+
+// A ClientGraph is simply a sub-graph of the full graph as induced by
+// BuildGraphOptions.
+struct ClientGraph {
+  // The sub-graph itself, created against the op registry "ops".
+  Graph graph;
+  explicit ClientGraph(const OpRegistryInterface* ops) : graph(ops) {}
+  // Zero-initialized so that reading it before anyone assigns a version
+  // yields a defined value instead of an indeterminate one.
+  int32 placement_version = 0;
+};
+
+// SimpleGraphExecutionState is responsible for generating an
+// executable ClientGraph from the original GraphDef that specifies
+// the complete graph and from BuildGraphOptions which specifies
+// input/output nodes.
+//
+// An executable Graph differs from a GraphDef by being Placed,
+// meaning that each Node is assigned to a single Device in the
+// available set.
+//
+// When SimpleGraphExecutionState is first constructed it instantiates
+// a full Graph from the provided GraphDef, and places it, using only
+// the static device assignments from the GraphDef. Nodes without one
+// are currently placed in a very naive way. Since stateful Nodes cannot
+// be moved after initial placement, it is important that stateful
+// Nodes get sensible initial device assignments in the graph
+// definition.
+//
+// Subsequently, SimpleGraphExecutionState generates a ClientGraph on
+// demand, which is a sub-graph of the latest placement of the full
+// Graph. MasterSession uses such a ClientGraph to execute one or
+// more similar client requests.
+//
+// SimpleGraphExecutionState is thread-safe.
+
+class SimpleGraphExecutionState {
+ public:
+ // Stores "ops" and the pointers in "options"; none of them are owned.
+ // No graph is built until Create() and BuildGraph() are called.
+ SimpleGraphExecutionState(const OpRegistryInterface* ops,
+ const SimpleGraphExecutionStateOptions& options);
+
+ virtual ~SimpleGraphExecutionState();
+
+ // Initializes the SimpleGraphExecutionState with 'graph_def'. Can only be
+ // called once on an original SimpleGraphExecutionState. Callee may modify
+ // 'graph_def'.
+ Status Create(GraphDef* graph_def);
+
+ // Creates a new SimpleGraphExecutionState representing the
+ // concatenation of this graph, and the graph defined by
+ // "extension_def". The same name may not be used to define a node
+ // in both this graph and "extension_def".
+ //
+ // If successful, returns OK and the caller takes ownership of "*out".
+ // Otherwise returns an error and does not modify "*out".
+ //
+ // NOTE(mrry): This method respects the placement of stateful nodes
+ // in *this, but currently does not transfer any other placement
+ // or cost model information to the new graph.
+ Status Extend(const GraphDef& extension_def,
+ SimpleGraphExecutionState** out) const;
+
+ // Builds a ClientGraph (a sub-graph of the full graph as induced by
+ // the Node set specified in "options"). If successful, returns OK
+ // and the caller takes the ownership of "*out". Otherwise, returns
+ // an error.
+ Status BuildGraph(const BuildGraphOptions& options, ClientGraph** out);
+
+ // Returns OK if the named node is found in the placed full graph owned
+ // by this execution_state, and sets *out to the NodeDef for that node.
+ // It may not exist if name is of a Node added for a particular subgraph
+ // execution, e.g. a send, recv or feed node.
+ Status GlobalNodeDefByName(const string& name, NodeDef* out);
+
+ private:
+ mutable mutex mu_;
+
+ // Builds base_ from original_graph_def_ and runs the first placement.
+ Status InitBaseGraph() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Places a copy of "graph"; on success the copy replaces placed_.
+ Status PreliminaryPlace(const Graph& graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Records the device of stateful nodes in stateful_placements_.
+ void FreezeStatefulNodes(bool is_prelim) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Re-applies the recorded devices to the stateful nodes of "graph".
+ void PlaceStatefulNodes(Graph* graph) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ // Runs the placement algorithm on "graph" (currently SimplePlacement).
+ Status DoPlace(Graph* graph);
+ // Runs a SimplePlacer over "graph".
+ Status SimplePlacement(Graph* graph);
+ // Return an OK status if "n" can be assigned to "device".
+ Status DeviceIsCompatible(Node* n, const Device* device) const;
+
+ const OpRegistryInterface* const ops_; // Not owned
+ GraphDef original_graph_def_; // Set by Create(), not by the ctor.
+ const DeviceSet* device_set_; // Not owned
+ const SessionOptions* session_options_; // Not owned
+
+ // Original graph before we make any placement decisions.
+ // Null until BuildGraph() has initialized it.
+ Graph* base_ GUARDED_BY(mu_);
+
+ // Full graph, placed on the complete set of devices, as a whole.
+ // Null until the first successful placement.
+ Graph* placed_ GUARDED_BY(mu_);
+
+ // Map of placed stateful nodes, i.e. nodes for which is_stateful()
+ // is true, such as "params" and "queue" nodes. Once placed these
+ // nodes can not be moved to a different device. Maps node names to
+ // device names.
+ typedef std::unordered_map<string, string> PlaceMap;
+ PlaceMap stateful_placements_ GUARDED_BY(mu_);
+ // Stateful nodes seen by PlaceStatefulNodes() that had no recorded device.
+ std::vector<Node*> missing_stateful_placements_ GUARDED_BY(mu_);
+
+ // Map from node name to cost id for the full graph in placed_.
+ // NOTE(review): not annotated GUARDED_BY(mu_), although it is written by
+ // InitBaseGraph(), which requires mu_ -- confirm the intended locking.
+ NodeNameToCostIdMap node_name_to_cost_id_map_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_SIMPLE_GRAPH_EXECUTION_STATE_H_
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
new file mode 100644
index 0000000000..5c20636dea
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -0,0 +1,75 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/worker_interface.h" // for CallOptions
+#include "tensorflow/core/framework/device_attributes.pb.h" // for BusAdjacency
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+typedef std::function<void(const Status&)> StatusCallback;
+
+class ChannelCache;
+class StepStats;
+class WorkerInterface;
+
+// Abstract interface for obtaining WorkerInterface objects for remote tasks
+// and for querying device bus adjacency. The logging hooks at the bottom
+// have do-nothing default implementations.
+class WorkerCacheInterface {
+ public:
+ virtual ~WorkerCacheInterface() {}
+
+ // Updates *workers with strings naming the remote worker tasks to
+ // which open channels have been established.
+ virtual void ListWorkers(std::vector<string>* workers) = 0;
+
+ // If "target" names a remote task for which an RPC channel exists
+ // or can be constructed, returns a new WorkerInterface object
+ // wrapping that channel. Ownership passes to the caller.
+ // TODO(tucker): rename this to CreateWorker() or something that
+ // makes it more obvious this is a constructor that transfers
+ // ownership, not a cache lookup.
+ // NOTE(review): the method is already named CreateWorker(); the TODO
+ // above looks stale.
+ virtual WorkerInterface* CreateWorker(const string& target) = 0;
+
+ // Set *ba with the BusAdjacency of the specified remote device
+ // within its local environment. Returns true if the device bus
+ // affinity was set, using only locally cached data. Returns false
+ // if status data for that device was not available. Never blocks.
+ // TODO(mrry,tucker): Maybe remove.
+ virtual bool GetDeviceBusNonBlocking(const string& device,
+ BusAdjacency* ba) = 0;
+
+ // Set *ba with the BusAdjacency of the specified remote device
+ // within its local environment. Callback gets Status::OK if the
+ // device bus affinity was set.
+ // TODO(mrry,tucker): Maybe remove.
+ virtual void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
+ StatusCallback done) = 0;
+
+ // Start/stop logging activity.
+ // The default implementation does nothing.
+ virtual void SetLogging(bool active) {}
+
+ // Discard any saved log data.
+ // The default implementation does nothing.
+ virtual void ClearLogs() {}
+
+ // Return logs for the identified step in *ss. Any returned data will no
+ // longer be stored.
+ // The default implementation leaves *ss untouched and returns false.
+ virtual bool RetrieveLogs(int64 step_id, StepStats* ss) { return false; }
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_H_
diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.cc b/tensorflow/core/distributed_runtime/worker_cache_logger.cc
new file mode 100644
index 0000000000..bd523ae03a
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_cache_logger.cc
@@ -0,0 +1,110 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/worker_cache_logger.h"
+
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+// Maximum number of step_ids for which RPC logs can be maintained.
+// TODO(mrry): Make this configurable if necessary.
+const int32 kWorkerCacheLoggerLimit = 1 << 10;
+} // namespace
+
+// Adjusts the count of outstanding logging requests: "true" adds one,
+// "false" removes one. The count never drops below zero.
+void WorkerCacheLogger::SetLogging(bool v) {
+  mutex_lock l(count_mu_);
+  if (v) {
+    ++want_logging_count_;
+    return;
+  }
+  // If RPCs get cancelled, it may be possible for the count
+  // to go negative.  This should not be a fatal error, since
+  // logging is non-critical.
+  want_logging_count_ =
+      (want_logging_count_ > 0) ? want_logging_count_ - 1 : 0;
+}
+
+// Drops every stored step log. Takes mu_ and delegates to
+// ClearLogsWithLock(), which does the actual work.
+void WorkerCacheLogger::ClearLogs() {
+ mutex_lock l(mu_);
+ ClearLogsWithLock();
+}
+
+// Deletes the collector of every entry and then empties the map. Callers
+// in this file invoke it while holding mu_.
+void WorkerCacheLogger::ClearLogsWithLock() {
+  for (LogMap::iterator it = log_map_.begin(); it != log_map_.end(); ++it) {
+    delete it->second.collector;
+  }
+  log_map_.clear();
+}
+
+// Moves the stats stored for "step_id" into "*ss" and forgets the entry.
+// Returns false, leaving "*ss" untouched, when nothing is stored for that
+// step.
+bool WorkerCacheLogger::RetrieveLogs(int64 step_id, StepStats* ss) {
+  mutex_lock l(mu_);
+  LogMap::iterator found = log_map_.find(step_id);
+  if (found == log_map_.end()) {
+    return false;
+  }
+  // Hand the data over first, then release the collector and the map slot.
+  found->second.collector->Swap(ss);
+  delete found->second.collector;
+  log_map_.erase(found);
+  return true;
+}
+
+// Adds "ns" to the log of "step_id", creating the step's collector on first
+// use. If the number of tracked steps then exceeds kWorkerCacheLoggerLimit,
+// every stored log is discarded.
+void WorkerCacheLogger::Save(const string& device, int64 step_id,
+                             NodeExecStats* ns) {
+  mutex_lock l(mu_);
+  StepLog& step_log = log_map_[step_id];
+  if (step_log.collector == nullptr) {
+    step_log.collector = new StepStatsCollector(&step_log.step_stats);
+  }
+  step_log.collector->Save(device, ns);
+
+  const bool over_limit = log_map_.size() > kWorkerCacheLoggerLimit;
+  if (over_limit) {
+    // Something's gone wrong.  Just empty the cache.
+    ClearLogsWithLock();
+  }
+}
+
+// Builds a NodeExecStats record named "RecvTensor" describing one tensor
+// transfer and stores it under "dst_device" for "step_id" via Save().
+// The timeline label starts with the size: "[<n>B] " for small transfers,
+// "[<x.y>MB] " once the size reaches a tenth of a mebibyte.
+void WorkerCacheLogger::RecordRecvTensor(int64 step_id, int64 start_usecs,
+                                         int64 end_usecs,
+                                         const string& tensor_name,
+                                         const string& src_device,
+                                         const string& dst_device,
+                                         int64 bytes) {
+  string size_label;
+  if (bytes >= 0.1 * 1048576.0) {
+    size_label = strings::Printf("[%.1fMB] ", bytes / 1048576.0);
+  } else {
+    size_label = strings::StrCat("[", bytes, "B] ");
+  }
+
+  NodeExecStats* stats = new NodeExecStats;
+  stats->set_node_name("RecvTensor");
+  stats->set_timeline_label(strings::StrCat(size_label, tensor_name, " from ",
+                                            src_device, " to ", dst_device));
+
+  // Times are recorded relative to the start of the transfer.
+  stats->set_all_start_micros(start_usecs);
+  stats->set_op_start_rel_micros(0);
+  stats->set_op_end_rel_micros(end_usecs - start_usecs);
+
+  NodeOutput* output = stats->add_output();
+  output->set_slot(0);
+  // TODO(tucker): Maybe set the dimensions too, but then they'll
+  // need to be passed in.
+  output->mutable_tensor_description()
+      ->mutable_allocation_description()
+      ->set_requested_bytes(bytes);
+
+  Save(dst_device, step_id, stats);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker_cache_logger.h b/tensorflow/core/distributed_runtime/worker_cache_logger.h
new file mode 100644
index 0000000000..46ba8a33ba
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_cache_logger.h
@@ -0,0 +1,81 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+class StepStatsCollector;
+
+// WorkerCacheLogger is a thread-safe utility for use by a WorkerCache
+// to optionally log some selected RPC activity. A single instance
+// should be owned by a WorkerCache, for use by its RemoteWorker
+// instances.
+
+class WorkerCacheLogger {
+ public:
+  // Start/Stop logging activity.  This function increments/decrements
+  // a counter so that if two separate steps turn logging on/off,
+  // logging should be on for the union of the durations of both,
+  // regardless of relative timing.
+  void SetLogging(bool v);
+
+  // Discard any saved log data.
+  void ClearLogs();
+
+  // Return logs for the identified step in *ss.  Any returned data will no
+  // longer be stored.  Returns true iff *ss was modified.
+  bool RetrieveLogs(int64 step_id, StepStats* ss);
+
+  // Return true if there is any outstanding request for logging on
+  // the RPC channels.
+  bool LoggingActive() {
+    mutex_lock l(count_mu_);
+    return want_logging_count_ > 0;
+  }
+
+  // Generates a NodeExecStats record with the given data, and saves for
+  // later retrieval by RetrieveLogs().
+  void RecordRecvTensor(int64 step_id, int64 start_usecs, int64 end_usecs,
+                        const string& tensor_name, const string& src_device,
+                        const string& dst_device, int64 bytes);
+
+ private:
+  mutex count_mu_;
+  // Number of outstanding requests to keep logging on.  Initialized here
+  // because the class declares no constructor; without the initializer
+  // LoggingActive() would read an indeterminate value.
+  int32 want_logging_count_ GUARDED_BY(count_mu_) = 0;
+
+  // Per-step log state.
+  struct StepLog {
+    StepStats step_stats;
+    // Null until a collector is attached to this entry.  Explicitly
+    // initialized so a freshly created entry never holds a garbage pointer.
+    StepStatsCollector* collector = nullptr;
+  };
+  typedef std::unordered_map<int64, StepLog> LogMap;
+  mutex mu_;
+  LogMap log_map_ GUARDED_BY(mu_);
+
+  // Records "ns" in log_map_ under the given device and step.
+  void Save(const string& device, int64 step_id, NodeExecStats* ns);
+
+  // Same as ClearLogs(), for callers that already hold mu_.
+  void ClearLogsWithLock() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+};
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_LOGGER_H_
diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.cc b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
new file mode 100644
index 0000000000..62c73b5fd9
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_cache_partial.cc
@@ -0,0 +1,98 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/worker_cache_partial.h"
+
+#include "tensorflow/core/distributed_runtime/process_util.h"
+#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+
+bool WorkerCachePartial::GetDeviceBusNonBlocking(const string& device_name,
+                                                 BusAdjacency* ba) {
+  // Cache-only lookup; *ba is written only on a hit.
+  mutex_lock l(mu_);  // could use reader lock
+  auto entry = device_status_cache_.find(device_name);
+  if (entry == device_status_cache_.end()) {
+    return false;
+  }
+  *ba = entry->second.bus_adjacency();
+  return true;
+}
+
+// Looks up the bus adjacency of `device_name`, storing it in *ba and then
+// invoking `done` with the outcome.  A cache hit is answered on the calling
+// thread; a miss schedules one refresh attempt via SchedClosure.  `ba` must
+// remain valid until `done` has been called.
+void WorkerCachePartial::GetDeviceBusAsync(const string& device_name,
+                                           BusAdjacency* ba,
+                                           StatusCallback done) {
+  if (GetDeviceBusNonBlocking(device_name, ba)) {
+    done(Status::OK());
+    return;
+  }
+  // If cache entry was empty, make one try to fill it by RPC.
+  // The closure may run after this function has returned, so it captures its
+  // own copy of the device name; a captured reference to the caller's string
+  // could dangle by then.
+  SchedClosure([this, device_name, ba, done]() {
+    Status s = RefreshDeviceStatus(device_name);
+    if (s.ok() && !GetDeviceBusNonBlocking(device_name, ba)) {
+      // The refresh succeeded but the second lookup still missed; report
+      // which of the two failure cases the cache is in now.
+      mutex_lock lock(mu_);
+      if (device_status_cache_.find(device_name) ==
+          device_status_cache_.end()) {
+        s = errors::Unavailable("No known remote device: ", device_name);
+      } else {
+        s = errors::Internal("Failed to find bus_adjacency for ",
+                             device_name);
+      }
+    }
+    done(s);
+  });
+}
+
+// Asks the worker task named in `device_name` for its status and stores every
+// DeviceAttributes entry it reports in the cache.  Blocks until the GetStatus
+// call completes.
+//
+// Returns InvalidArgument if `device_name` cannot be split into task and
+// device, Internal if no worker can be created for the task, and otherwise
+// the status of the GetStatus call.
+Status WorkerCachePartial::RefreshDeviceStatus(const string& device_name) {
+  string task;
+  string device;
+  if (!DeviceNameUtils::SplitDeviceName(device_name, &task, &device)) {
+    // Bail out before CreateWorker: `task` is not meaningful for a name that
+    // failed to parse.
+    return errors::InvalidArgument("Bad device name to RefreshDeviceStatus: ",
+                                   device_name);
+  }
+  std::unique_ptr<WorkerInterface> rwi(CreateWorker(task));
+  if (!rwi) {
+    return errors::Internal("RefreshDeviceStatus, unknown worker task: ",
+                            task);
+  }
+
+  GetStatusRequest req;
+  GetStatusResponse resp;
+  Status s = rwi->GetStatus(&req, &resp);
+  if (s.ok()) {
+    mutex_lock lock(mu_);
+    for (const auto& dev_attr : resp.device_attributes()) {
+      device_status_cache_[dev_attr.name()] = dev_attr;
+    }
+  }
+  return s;
+}
+
+void WorkerCachePartial::FlushStatusCache() {
+  // Drops every cached DeviceAttributes entry while holding mu_.
+  mutex_lock l(mu_);
+  device_status_cache_.clear();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/worker_cache_partial.h b/tensorflow/core/distributed_runtime/worker_cache_partial.h
new file mode 100644
index 0000000000..5d8a56648d
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_cache_partial.h
@@ -0,0 +1,56 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
+
+#include <string>
+#include <unordered_map>
+
+#include "tensorflow/core/distributed_runtime/worker_cache.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+// Implements the part of the interface that caches and returns remote
+// device status attributes.
+class WorkerCachePartial : public WorkerCacheInterface {
+ public:
+ // Looks `device` up in the local status cache only. On a hit, stores the
+ // cached bus adjacency in *ba and returns true; otherwise returns false.
+ bool GetDeviceBusNonBlocking(const string& device, BusAdjacency* ba) override;
+
+ // Like GetDeviceBusNonBlocking, but on a cache miss makes one attempt to
+ // refresh the cache from the remote worker before giving up. The callback
+ // receives the resulting Status.
+ void GetDeviceBusAsync(const string& device, BusAdjacency* ba,
+ StatusCallback) override;
+
+ ~WorkerCachePartial() override {}
+
+ // Clear all entries from the DeviceStatus cache.
+ void FlushStatusCache();
+
+ private:
+ // Guards device_status_cache_.
+ mutex mu_;
+
+ // Issues a blocking GetStatus call to the remote task that hosts
+ // "device_name", and updates the cache with all the DeviceAttributes
+ // reported.
+ Status RefreshDeviceStatus(const string& device_name);
+
+ // Cached DeviceAttributes, keyed by the name each entry reports.
+ typedef std::unordered_map<string, DeviceAttributes> StatusMap;
+ StatusMap device_status_cache_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_CACHE_PARTIAL_H_
diff --git a/tensorflow/core/distributed_runtime/worker_env.h b/tensorflow/core/distributed_runtime/worker_env.h
new file mode 100644
index 0000000000..d26462570e
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_env.h
@@ -0,0 +1,62 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace thread {
+class ThreadPool;
+} // namespace thread
+
+class DeviceMgr;
+class Env;
+class GraphMgr;
+class RendezvousMgrInterface;
+class WorkerCacheInterface;
+
+// The worker environment class, which holds a bag of pointers to
+// per-worker singletons.
+//
+// WorkerEnv does not own its member pointers.
+struct WorkerEnv {
+ // NOTE(review): presumably the platform Env this worker runs against; no
+ // use of it is visible in this header - confirm.
+ Env* env = nullptr;
+
+ // The name of the worker. E.g., /job:mnist/replica:1/task:0.
+ string worker_name;
+
+ // Object from which WorkerInterface instances can be obtained.
+ WorkerCacheInterface* worker_cache = nullptr;
+
+ // device_mgr manages local devices (cpu and gpu). The WorkerService
+ // is the network interface for managed devices.
+ DeviceMgr* device_mgr = nullptr;
+
+ // graph_mgr keeps track of registered graphs of this worker.
+ GraphMgr* graph_mgr = nullptr;
+
+ // A set of rendezvous keyed by step ids.
+ RendezvousMgrInterface* rendezvous_mgr = nullptr;
+
+ // A pool of threads for scheduling compute work.
+ thread::ThreadPool* compute_pool = nullptr;
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
diff --git a/tensorflow/core/distributed_runtime/worker_interface.h b/tensorflow/core/distributed_runtime/worker_interface.h
new file mode 100644
index 0000000000..6e68d300e6
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/worker_interface.h
@@ -0,0 +1,129 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
+
+#include <functional>
+
+#include "tensorflow/core/distributed_runtime/call_options.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
+
+namespace tensorflow {
+
+// Status callback.
+typedef std::function<void(const Status&)> StatusCallback;
+
+// Allocator callback for out-of-band transfers.
+class TensorShape;
+typedef std::function<void*(size_t, const DataType&, const TensorShape&)>
+ TensorBufAllocator;
+
+// Interface for talking with the TensorFlow Worker service.
+class WorkerInterface {
+ public:
+ virtual ~WorkerInterface() {}
+
+ // Asynchronous entry points, one per Worker service method. Each takes a
+ // caller-provided request and response plus a `done` callback that is
+ // handed the Status of the call.
+ // NOTE(review): presumably `request` and `response` must stay valid until
+ // `done` has run - confirm against the implementations.
+ virtual void GetStatusAsync(const GetStatusRequest* request,
+ GetStatusResponse* response,
+ StatusCallback done) = 0;
+
+ virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
+ RegisterGraphResponse* response,
+ StatusCallback done) = 0;
+
+ virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
+ DeregisterGraphResponse* response,
+ StatusCallback done) = 0;
+
+ // Takes an additional CallOptions argument, unlike the methods above.
+ virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
+ RunGraphResponse* response,
+ StatusCallback done) = 0;
+
+ virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
+ CleanupGraphResponse* response,
+ StatusCallback done) = 0;
+
+ virtual void CleanupAllAsync(const CleanupAllRequest* request,
+ CleanupAllResponse* response,
+ StatusCallback done) = 0;
+
+ // Takes CallOptions and a TensorBufAllocator in addition to the usual
+ // request/response/done triple.
+ virtual void RecvTensorAsync(CallOptions* opts,
+ const RecvTensorRequest* request,
+ RecvTensorResponse* response,
+ TensorBufAllocator allocator,
+ StatusCallback done) = 0;
+
+ virtual void LoggingAsync(const LoggingRequest* request,
+ LoggingResponse* response, StatusCallback done) = 0;
+
+ virtual void TracingAsync(const TracingRequest* request,
+ TracingResponse* response, StatusCallback done) = 0;
+
+ // Blocking convenience wrappers: each issues the matching *Async call and
+ // waits for its callback before returning the Status. RunGraphAsync and
+ // RecvTensorAsync have no blocking wrapper here.
+ Status GetStatus(const GetStatusRequest* request,
+ GetStatusResponse* response) {
+ return CallAndWait(&ME::GetStatusAsync, request, response);
+ }
+
+ Status RegisterGraph(const RegisterGraphRequest* request,
+ RegisterGraphResponse* response) {
+ return CallAndWait(&ME::RegisterGraphAsync, request, response);
+ }
+
+ Status DeregisterGraph(const DeregisterGraphRequest* request,
+ DeregisterGraphResponse* response) {
+ return CallAndWait(&ME::DeregisterGraphAsync, request, response);
+ }
+
+ Status CleanupGraph(const CleanupGraphRequest* request,
+ CleanupGraphResponse* response) {
+ return CallAndWait(&ME::CleanupGraphAsync, request, response);
+ }
+
+ Status CleanupAll(const CleanupAllRequest* request,
+ CleanupAllResponse* response) {
+ return CallAndWait(&ME::CleanupAllAsync, request, response);
+ }
+
+ Status Logging(const LoggingRequest* request, LoggingResponse* response) {
+ return CallAndWait(&ME::LoggingAsync, request, response);
+ }
+
+ Status Tracing(const TracingRequest* request, TracingResponse* response) {
+ return CallAndWait(&ME::TracingAsync, request, response);
+ }
+
+ private:
+ // Shorthand used to form pointers to the *Async member functions.
+ typedef WorkerInterface ME;
+
+ // Invokes the member function `func` on this object with (req, resp,
+ // callback) and blocks on a Notification until the callback fires, then
+ // returns the Status that was passed to the callback.
+ template <typename Method, typename Req, typename Resp>
+ Status CallAndWait(Method func, const Req* req, Resp* resp) {
+ Status ret;
+ Notification n;
+ // `ret` and `n` are captured by reference; they stay alive because this
+ // frame does not return until WaitForNotification() below unblocks.
+ (this->*func)(req, resp, [&ret, &n](const Status& s) {
+ ret = s;
+ n.Notify();
+ });
+ n.WaitForNotification();
+ return ret;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_INTERFACE_H_
diff --git a/tensorflow/core/framework/load_library.cc b/tensorflow/core/framework/load_library.cc
index 92a3400185..bd21ac8e34 100644
--- a/tensorflow/core/framework/load_library.cc
+++ b/tensorflow/core/framework/load_library.cc
@@ -65,7 +65,7 @@ Status LoadLibrary(const char* library_filename, void** result,
string str;
GetOpList(&str);
char* str_buf = reinterpret_cast<char*>(operator new(str.length()));
- strncpy(str_buf, str.data(), str.length());
+ memcpy(str_buf, str.data(), str.length());
*buf = str_buf;
*len = str.length();
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index efee64a7a8..633441f31b 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -25,37 +25,58 @@ def tf_deps(deps, suffix):
return tf_deps
-def tf_proto_library(name, srcs = [], has_services = False,
- deps = [], visibility = [], testonly = 0,
- cc_api_version = 2, go_api_version = 2,
- java_api_version = 2,
- py_api_version = 2):
+def tf_proto_library_cc(name, srcs = [], has_services = None,
+ deps = [], visibility = [], testonly = 0,
+ cc_libs = [],
+ cc_stubby_versions = None,
+ cc_grpc_version = None,
+ cc_api_version = 2, go_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2):
native.filegroup(name=name + "_proto_srcs",
srcs=srcs + tf_deps(deps, "_proto_srcs"),
testonly=testonly,)
+ use_grpc_plugin = None
+ if cc_grpc_version:
+ use_grpc_plugin = True
cc_proto_library(name=name + "_cc",
srcs=srcs + tf_deps(deps, "_proto_srcs"),
deps=deps + ["//google/protobuf:cc_wkt_protos"],
- cc_libs = ["//google/protobuf:protobuf"],
+ cc_libs = cc_libs + ["//google/protobuf:protobuf"],
+ use_grpc_plugin = use_grpc_plugin,
testonly=testonly,
visibility=visibility,)
- py_proto_library(name=name + "_py",
- srcs=srcs + tf_deps(deps, "_proto_srcs"),
- srcs_version="PY2AND3",
- deps=deps + ["//google/protobuf:protobuf_python"],
- testonly=testonly,
- visibility=visibility,)
-
-def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0):
+def tf_proto_library_py(name, srcs=[], deps=[], visibility=[], testonly=0,
+ srcs_version="PY2AND3"):
py_proto_library(name = name + "_py",
srcs = srcs,
- srcs_version = "PY2AND3",
+ srcs_version = srcs_version,
deps = deps,
visibility = visibility,
testonly = testonly)
+def tf_proto_library(name, srcs = [], has_services = None,
+ deps = [], visibility = [], testonly = 0,
+ cc_libs = [],
+ cc_api_version = 2, go_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2):
+ tf_proto_library_cc(name=name,
+ srcs=srcs + tf_deps(deps, "_proto_srcs"),
+ deps=deps,
+ cc_libs=cc_libs,
+ testonly=testonly,
+ visibility=visibility,)
+
+ tf_proto_library_py(name=name,
+ srcs=srcs + tf_deps(deps, "_proto_srcs"),
+ srcs_version="PY2AND3",
+ deps=deps + ["//google/protobuf:protobuf_python"],
+ testonly=testonly,
+ visibility=visibility,)
+
def tf_additional_lib_srcs():
return [
"platform/default/*.h",
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
new file mode 100644
index 0000000000..e46581bdab
--- /dev/null
+++ b/tensorflow/core/protobuf/master.proto
@@ -0,0 +1,190 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+option java_outer_classname = "DistributedRuntimeProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+import "tensorflow/core/framework/config.proto";
+import "tensorflow/core/framework/device_attributes.proto";
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor.proto";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CreateSession method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CreateSessionRequest {
+ // The initial graph definition.
+ GraphDef graph_def = 1;
+
+ // Configuration options.
+ ConfigProto config = 2;
+}
+
+message CreateSessionResponse {
+ // The session handle to be used in subsequent calls for the created session.
+ //
+ // The client must arrange to call CloseSession with this returned
+ // session handle to close the session.
+ string session_handle = 1;
+
+ // The initial version number for the graph, to be used in the next call
+ // to ExtendSession.
+ int64 graph_version = 2;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// ExtendSession method request/response protos.
+//
+// The "graph_def" specifies a set of nodes to be added to the session's graph.
+//
+// A typical "graph_def" will contain:
+//
+// * Zero or more new nodes with names that do not exist in the server-side
+// graph. These will be added to the graph.
+//
+// PRECONDITION: The server-side current version is
+// req.current_graph_version.
+// None of the names in req.graph_def appeared in previous successful calls to
+// CreateSession or ExtendSession with the same session_handle.
+// POSTCONDITION: The server-side current version is resp.new_graph_version.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message ExtendSessionRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+
+ // REQUIRED: The nodes to be added to the session's graph. If any node has
+ // the same name as an existing node, the operation will fail with
+ // INVALID_ARGUMENT.
+ GraphDef graph_def = 2;
+
+ // REQUIRED: The version number of the graph to be extended. This will be
+ // tested against the current server-side version number, and the operation
+ // will fail with FAILED_PRECONDITION if they do not match.
+ int64 current_graph_version = 3;
+}
+
+message ExtendSessionResponse {
+ // TODO(mrry): Return something about the operation?
+
+ // The new version number for the extended graph, to be used in the next call
+ // to ExtendSession.
+ int64 new_graph_version = 4;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunStep method request/response protos.
+//
+// The caller should provide the feeds needed by the graph and specify
+// what nodes should be fetched.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// A pair of tensor name and tensor values.
+message NamedTensorProto {
+ // Name of the tensor.
+ string name = 1;
+
+ // The client can populate a TensorProto using a `tensorflow::Tensor`, or
+ // directly using the protobuf field accessors.
+ //
+ // The client specifies whether the returned tensor values should be
+ // filled tensor fields (float_val, int_val, etc.) or encoded in a
+ // compact form in tensor.tensor_content.
+ TensorProto tensor = 2;
+}
+
+message RunStepRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+
+ // Tensors to be fed in the step. Each feed is a named tensor.
+ repeated NamedTensorProto feed = 2;
+
+ // Fetches. A list of tensor names. The caller expects a tensor to
+ // be returned for each fetch[i] (see RunStepResponse.tensor). The
+ // order of specified fetches does not change the execution order.
+ repeated string fetch = 3;
+
+ // Target Nodes. A list of node names. The named nodes will be run
+ // to but their outputs will not be fetched.
+ repeated string target = 4;
+}
+
+message RunStepResponse {
+ // NOTE: The order of the returned tensors may or may not match
+ // the fetch order specified in RunStepRequest.
+ repeated NamedTensorProto tensor = 1;
+
+ // TODO(mrry): Optionally aggregate StepStats in some form here.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CloseSession method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CloseSessionRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+}
+
+message CloseSessionResponse {
+}
+
+message ResetRequest {
+ // A list of container names, which may be empty.
+ //
+ // If 'container' is not empty, releases resources in the given
+ // containers in all devices.
+ //
+ // If 'container' is empty, releases resources in the default
+ // container in all devices.
+ repeated string container = 1;
+}
+
+message ResetResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// ListDevices method request/response protos.
+//
+// Returns information about the TensorFlow devices that are available
+// to this master.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message ListDevicesRequest {
+}
+
+message ListDevicesResponse {
+ repeated DeviceAttributes local_device = 1;
+ repeated DeviceAttributes remote_device = 2;
+}
diff --git a/tensorflow/core/protobuf/master_service.proto b/tensorflow/core/protobuf/master_service.proto
new file mode 100644
index 0000000000..13b0a97b11
--- /dev/null
+++ b/tensorflow/core/protobuf/master_service.proto
@@ -0,0 +1,105 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow.grpc;
+option java_outer_classname = "MasterServiceProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+import "tensorflow/core/protobuf/master.proto";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// MasterService defines a TensorFlow service with which a client can
+// interact to execute a distributed TensorFlow computation.
+//
+// A master service keeps track of multiple "master sessions". Each
+// session encapsulates a computation graph and its associated state,
+// and typically corresponds to a single "client session" (e.g. a
+// `tensorflow::Session` instance).
+//
+// A session is responsible for the following:
+// * assigning each node to a device (locally or remotely) using a
+// placement algorithm. This may make decisions based on collected
+// statistics from the workers in the system (e.g., memory usage,
+// bandwidth consumption, etc.)
+//
+// * inserting intermediate nodes and edges to support cross-device
+// and cross-process data flows and resource management.
+//
+// * issuing commands to workers to execute the subgraphs associated
+// with those workers.
+//
+// Typically, a client carries out an iterative computation
+// (e.g. training) by invoking RPCs against the master in a
+// client-side loop. The client first creates a client session that
+// connects to a particular master (using gRPC for example). The
+// master creates a corresponding master session that is hosted on
+// the master and caches state between the client's invocations.
+//
+// After the session is established, the master returns an opaque
+// handle to the client that can be used to associate the client and
+// master sessions.
+//
+// The client may send an initial graph to the master in the
+// CreateSession call, and add nodes to the graph using ExtendSession.
+//
+// The most frequent operation a master performs is "RunStep", which
+// implements the `Session::Run()` API. It supports feeding in arguments,
+// executing a dataflow computation, and fetching results.
+//
+// Finally, when the client no longer needs the session, it should
+// close the session by invoking CloseSession, which allows the master
+// to reclaim resources associated with the session. The master may
+// implement a garbage collection scheme that closes sessions that
+// have been inactive for some time.
+//
+// For example, the following pseudo-code illustrates how a client
+// interacts with a master:
+//
+// stub = NewStub("/job:mnist/replica:0/task:0")
+// {handle} = stub->CreateSession({graph_def})
+// do {
+// stub->RunStep({handle, {feeds}, {fetches}})
+// // The client can evaluate a predicate locally, based on the
+// // result of `fetches`, to determine whether to terminate. For
+// // example, it might fetch the loss and evaluate whether it is less
+// // than some threshold.
+// } while (!should_stop({fetches}));
+// stub->CloseSession({handle})
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service MasterService {
+ // Creates a session.
+ rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse);
+
+ // Extends a session.
+ rpc ExtendSession(ExtendSessionRequest) returns (ExtendSessionResponse);
+
+ // Drives the graph computation.
+ rpc RunStep(RunStepRequest) returns (RunStepResponse);
+
+ // Closes a session.
+ rpc CloseSession(CloseSessionRequest) returns (CloseSessionResponse);
+
+ // List the devices usable by the master.
+ rpc ListDevices(ListDevicesRequest) returns (ListDevicesResponse);
+
+ // Close all existing sessions.
+ rpc Reset(ResetRequest) returns (ResetResponse);
+}
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
new file mode 100644
index 0000000000..bb01b65d8b
--- /dev/null
+++ b/tensorflow/core/protobuf/worker.proto
@@ -0,0 +1,311 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow;
+// option cc_enable_arenas = true;
+option java_outer_classname = "WorkerProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+import "google/protobuf/any.proto";
+import "tensorflow/core/framework/config.proto";
+import "tensorflow/core/framework/step_stats.proto";
+import "tensorflow/core/framework/device_attributes.proto";
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor.proto";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// GetStatus method request/response messages
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message GetStatusRequest {
+}
+
+message GetStatusResponse {
+ repeated DeviceAttributes device_attributes = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RegisterGraph method request/response messages
+//
+// For each session, after the master has placed every node on a device,
+// it partitions the whole graph into many subgraphs. All the nodes in
+// a subgraph are on the same worker, but potentially on many devices
+// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
+// master registers subgraphs for a worker before running any steps. A
+// successful registration returns a graph handle to be used in later
+// RunGraph requests.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RegisterGraphRequest {
+ // Subgraphs are scoped within one session.
+ string session_handle = 1;
+
+ // "graph_def" has the subgraph of nodes for this worker, with each node
+ // having its device_name filled in.
+ GraphDef graph_def = 2;
+
+ // True iff the graph (before partitioning) contains control flow nodes.
+ //
+ // As of 01/11/2015, this is no longer set by clients.
+ bool has_control_flow = 3 [deprecated = true];
+
+ // Configuration options for the session in which this graph was created.
+ GraphOptions graph_options = 4;
+}
+
+message RegisterGraphResponse {
+ // If the registration succeeds, returns an opaque graph_handle to
+ // the master. The master calls RunGraph with graph_handle to
+ // compute different steps.
+ string graph_handle = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// DeregisterGraph method request/response messages
+//
+// The master deregisters the given graph_handle when the graph is no
+// longer needed (e.g., the overall graph is re-scheduled and nodes
+// are re-placed).
+//
+// The worker deregisters a graph_handle automatically according to
+// a TTL-based policy in case of master restarts.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message DeregisterGraphRequest {
+ // REQUIRED: graph_handle must be returned by a RegisterGraph call
+ // to the same WorkerService.
+ string graph_handle = 1;
+}
+
+message DeregisterGraphResponse {
+ // TODO(mrry): Optionally add summary stats for the graph.
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CleanupAll method request/response messages
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CleanupAllRequest {
+ // A list of container names.
+ //
+ // If 'container' is not empty, releases resources in the given
+ // containers in all devices.
+ //
+ // If 'container' is empty, releases resources in the default
+ // container in all devices.
+ repeated string container = 1;
+}
+
+message CleanupAllResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunGraph request / response messages
+//
+// The worker executes all subgraphs registered under graph_handle.
+// RunGraph returns after the execution finishes or an error is
+// encountered.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// A pair of tensor name and tensor values.
+message NamedTensor {
+ // The name of the named tensor.
+ string key = 1;
+
+ // The value of the named tensor.
+ TensorProto val = 2;
+}
+
+// Options specific to the execution of a single step.
+message ExecutorOpts {
+ bool record_costs = 1;
+ bool record_timeline = 3;
+};
+
+message RunGraphRequest {
+ // REQUIRED: graph_handle must be returned by a RegisterGraph call
+ // to the same WorkerService.
+ string graph_handle = 1;
+
+ // A unique ID to distinguish different runs of the same graph.
+ //
+ // The master generates a globally unique `step_id` to distinguish
+ // different runs of the graph computation. Subgraphs communicate
+ // (e.g., send/recv ops) with each other using `step_id` to
+ // distinguish tensors generated by different runs.
+ int64 step_id = 2;
+
+ // Options for this step.
+ ExecutorOpts exec_opts = 5;
+
+ // Runs the graph.
+ //
+ // Sends the tensors in "send" into the graph before the run and
+ // fetches the keys into `RunGraphResponse.recv` after the run.
+ repeated NamedTensor send = 3;
+ repeated string recv_key = 4;
+}
+
+message RunGraphResponse {
+ // A list of tensors corresponding to those requested by
+ // `RunGraphRequest.recv_key`.
+ repeated NamedTensor recv = 1;
+
+ // If the request asked for execution stats, these are returned here.
+ StepStats step_stats = 2;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// CleanupGraph method request/response messages
+//
+// After the master receives RunGraph responses from all workers, the
+// master instructs every worker to clean up any remaining state of a
+// step (e.g. tensors buffered by a `Send` op but not picked up by
+// other workers). The master does not necessarily need to wait for
+// completion of CleanupGraph calls.
+//
+// Workers should clean up step states automatically according to a
+// TTL-based policy in case of master restarts.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message CleanupGraphRequest {
+ int64 step_id = 1;
+}
+
+message CleanupGraphResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RecvTensor method request/response messages
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RecvTensorRequest {
+ // The step in which the tensor will be produced.
+ //
+ // REQUIRED: This must eventually correspond to the `step_id` passed
+ // into a RunGraph call on the same WorkerService.
+ int64 step_id = 1;
+
+ // A key that identifies the tensor to be received.
+ string rendezvous_key = 2;
+
+ // If true, use an out-of-band DMA mechanism to transfer the
+ // received tensor.
+ bool dma_ok = 3;
+ // NIC bus preference on the request originator side
+ BusAdjacency client_bus_adjacency = 4;
+ // NIC bus preference on the request receiver side
+ BusAdjacency server_bus_adjacency = 5;
+}
+
+message RecvTensorResponse {
+ // The tensor as a proto.
+ TensorProto tensor = 1;
+
+ // If true, this tensor was the output of a dead node, and the
+ // content is invalid.
+ bool is_dead = 2;
+
+ // The time at which the tensor was available and started to be returned.
+ int64 send_start_micros = 3;
+
+ // Optional additional information about how to receive the tensor,
+ // in the event that `RecvTensorRequest.dma_ok` was true.
+ google.protobuf.Any transport_options = 4;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Logging method request/response messages
+//
+// NOTE(mrry): This feature is not supported in the open-source
+// version, and these messages are expected to change.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+// Out-of-band request to begin or end logging, or
+// to retrieve logs for particular steps.
+message LoggingRequest {
+ // If true, RPC logging will be activated.
+ bool rpc_logging = 1;
+
+ // If true, discard any saved logging data (for all steps).
+ bool clear = 2;
+
+ // When set, requests all saved log data pertaining to the step.
+ // Any log data retrieved is eliminated from the store and cannot be
+ // retrieved again.
+ repeated int64 fetch_step_id = 3;
+}
+
+message LabeledStepStats {
+ int64 step_id = 1;
+ StepStats step_stats = 2;
+}
+
+message LoggingResponse {
+ repeated LabeledStepStats step = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Tracing method request/response messages
+//
+// NOTE(mrry): This feature is not supported in the open-source
+// version, and these messages are expected to change.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message TraceOpts {
+ // Length of the trace to be taken, in seconds.
+ double duration = 1;
+ // If true, capture step profile locally in each worker. Currently
+ // unimplemented.
+ bool use_step_profiler = 2;
+ // If true, capture kernel events from each worker.
+ bool use_kernel_profiler = 3;
+ // If true, capture extended profiling events from TensorFlow process.
+ bool use_extended_profiler = 4;
+ // If true, capture GPU profiling events locally on each
+ // machine. Currently unimplemented.
+ bool use_gpu_profiler = 5;
+ // If true, collect sampled profile events. Currently unimplemented.
+ bool use_sample_profiler = 6;
+}
+
+// Out-of-band request to configure distributed tracing.
+message TracingRequest {
+ TraceOpts options = 1;
+}
+
+message TracingResponse {
+}
diff --git a/tensorflow/core/protobuf/worker_service.proto b/tensorflow/core/protobuf/worker_service.proto
new file mode 100644
index 0000000000..2699e639db
--- /dev/null
+++ b/tensorflow/core/protobuf/worker_service.proto
@@ -0,0 +1,67 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package tensorflow.grpc;
+option java_outer_classname = "WorkerServiceProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.distruntime";
+
+import "tensorflow/core/protobuf/worker.proto";
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// WorkerService defines a TensorFlow service that executes dataflow
+// graphs on a set of local devices, on behalf of a MasterService.
+//
+// A worker service keeps track of multiple "registered graphs". Each
+// registered graph is a subgraph of a client's graph, corresponding to
+// only the nodes that should execute on this worker (and any
+// additional nodes necessary for inter-process communication using
+// the `RecvTensor` method).
+//
+////////////////////////////////////////////////////////////////////////////////
+
+service WorkerService {
+ // See worker.proto for details.
+ rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
+
+ // See worker.proto for details.
+ rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
+
+ // See worker.proto for details.
+ rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);
+
+ // See worker.proto for details.
+ rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);
+
+ // See worker.proto for details.
+ rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);
+
+ // See worker.proto for details.
+ rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);
+
+ // See worker.proto for details.
+ rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
+ // RecvTensor Method
+ }
+
+ // See worker.proto for details.
+ rpc Logging(LoggingRequest) returns (LoggingResponse);
+
+ // See worker.proto for details.
+ rpc Tracing(TracingRequest) returns (TracingResponse);
+}
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 2f45b07484..58d00d3c7f 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -920,6 +920,7 @@ tf_py_wrap_cc(
":py_record_writer_lib",
":python_op_gen",
":tf_session_helper",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//util/python:python_headers",
],
)
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 96d61a327c..16d4f287ab 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -750,8 +750,9 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEqual(c_list[1], out[1].decode('utf-8'))
def testInvalidTargetFails(self):
- with self.assertRaisesRegexp(RuntimeError,
- 'Registered factories are {DIRECT_SESSION}'):
+ with self.assertRaisesRegexp(
+ RuntimeError,
+ 'No session factory registered for the given session options.'):
session.Session('INVALID_TARGET')
def testFetchByNameDifferentStringTypes(self):
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index bc5ba95348..9db78bb13a 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -135,7 +135,7 @@ have varying scale, and to aid generalization.
@@l2_normalize
@@local_response_normalization
@@sufficient_statistics
-@@aggregate_moments
+@@normalize_moments
@@moments
## Losses
@@ -561,7 +561,7 @@ def sufficient_statistics(x, axes, shift=True, keep_dims=False, name=None):
return counts, m_ss, v_ss, shift_value
-def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
+def normalize_moments(counts, mean_ss, variance_ss, shift, name=None):
"""Calculate the mean and variance of based on the sufficient statistics.
Args:
@@ -577,7 +577,7 @@ def aggregate_moments(counts, mean_ss, variance_ss, shift, name=None):
Returns:
Two `Tensor` objects: `mean` and `variance`.
"""
- with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "aggregate"):
+ with ops.op_scope([counts, mean_ss, variance_ss, shift], name, "normalize"):
divisor = math_ops.inv(counts, name="divisor")
if shift is not None:
shifted_mean = math_ops.mul(mean_ss, divisor, name="shifted_mean")
@@ -620,7 +620,7 @@ def moments(x, axes, name=None, keep_dims=False):
axes,
keep_dims=keep_dims,
name=name)
- return aggregate_moments(counts, m_ss, v_ss, shift, name=name)
+ return normalize_moments(counts, m_ss, v_ss, shift, name=name)
def batch_normalization(x,
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 317a074830..30c7976909 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -826,19 +826,19 @@ class SufficientStatisticsTest(tf.test.TestCase):
self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)
-class AggregateMomentsTest(tf.test.TestCase):
+class NormalizeMomentsTest(tf.test.TestCase):
- def _npAggregateMoments(self, counts, mean_ss, variance_ss, shift):
+ def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
mean = mean_ss / counts
variance = variance_ss / counts - mean * mean
if shift is not None:
mean += shift
return mean, variance
- def _opAggregateMoments(self, counts, mean_ss, variance_ss, shift):
- return tf.nn.aggregate_moments(counts, mean_ss, variance_ss, shift)
+ def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
+ return tf.nn.normalize_moments(counts, mean_ss, variance_ss, shift)
- def _testAggregateMoments(self, shape, shift):
+ def _testNormalizeMoments(self, shape, shift):
counts = np.ones([1]).astype(np.float32)
mean_ss = np.random.random_sample(shape).astype(np.float32)
variance_ss = np.random.random_sample(shape).astype(np.float32)
@@ -847,7 +847,7 @@ class AggregateMomentsTest(tf.test.TestCase):
shift_v = np.random.random_sample(shape).astype(np.float32)
else:
shift_v = None
- npm, npv = self._npAggregateMoments(counts, mean_ss, variance_ss, shift_v)
+ npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu) as sess:
tf_counts = tf.constant(counts, name="counts")
@@ -857,16 +857,16 @@ class AggregateMomentsTest(tf.test.TestCase):
tf_shift_v = tf.constant(shift_v, name="shift")
else:
tf_shift_v = None
- opm, opv = self._opAggregateMoments(tf_counts, tf_mean_ss,
+ opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
tf_variance_ss, tf_shift_v)
tfm, tfv = sess.run([opm, opv])
self.assertAllClose(npm, tfm, atol=0.000001)
self.assertAllClose(npv, tfv, atol=0.000001)
- def testAggregateMoments(self):
+ def testNormalizeMoments(self):
for shift in [True, False]:
- self._testAggregateMoments([3], shift)
- self._testAggregateMoments([2, 3], shift)
+ self._testNormalizeMoments([3], shift)
+ self._testNormalizeMoments([2, 3], shift)
class MomentsTest(tf.test.TestCase):
@@ -971,15 +971,15 @@ class MomentsTest(tf.test.TestCase):
"""Make sure the output names are stable."""
with self.test_session():
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=False)
- self.assertEquals(mean.op.name, "moments/aggregate/mean")
- self.assertEquals(var.op.name, "moments/aggregate/variance")
+ self.assertEquals(mean.op.name, "moments/normalize/mean")
+ self.assertEquals(var.op.name, "moments/normalize/variance")
def testOutputNamesKeep(self):
"""Make sure the output names are stable."""
with self.test_session():
mean, var = tf.nn.moments(tf.constant([1]), [0], keep_dims=True)
- self.assertEquals(mean.op.name, "moments/aggregate/mean")
- self.assertEquals(var.op.name, "moments/aggregate/variance")
+ self.assertEquals(mean.op.name, "moments/normalize/mean")
+ self.assertEquals(var.op.name, "moments/normalize/variance")
class ComputeSampledLogitsTest(tf.test.TestCase):