aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-04-06 17:17:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 17:19:59 -0700
commit5e11bbacaffdf7bc4a9363301de6a0755f95e9c0 (patch)
tree48f37585cd3b01c71eaced8724be21151374264d
parentddf54d1c24a2b4dcfd8eb52d21dc1f393785f1e9 (diff)
Open sourcing proto/rpc ops.
PiperOrigin-RevId: 191962572
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt6
-rw-r--r--tensorflow/contrib/cmake/tf_core_ops.cmake3
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake3
-rw-r--r--tensorflow/contrib/makefile/tf_op_files.txt2
-rw-r--r--tensorflow/contrib/proto/BUILD16
-rw-r--r--tensorflow/contrib/proto/__init__.py28
-rw-r--r--tensorflow/contrib/proto/python/ops/BUILD44
-rw-r--r--tensorflow/contrib/proto/python/ops/decode_proto_op.py25
-rw-r--r--tensorflow/contrib/proto/python/ops/encode_proto_op.py25
-rw-r--r--tensorflow/contrib/rpc/BUILD13
-rw-r--r--tensorflow/contrib/rpc/__init__.py28
-rw-r--r--tensorflow/contrib/rpc/python/ops/BUILD24
-rw-r--r--tensorflow/contrib/rpc/python/ops/rpc_op.py26
-rw-r--r--tensorflow/core/BUILD9
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt116
-rw-r--r--tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt81
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt108
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt123
-rw-r--r--tensorflow/core/distributed_runtime/rpc/BUILD30
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc213
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h59
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc34
-rw-r--r--tensorflow/core/kernels/BUILD47
-rw-r--r--tensorflow/core/kernels/decode_proto_op.cc1011
-rw-r--r--tensorflow/core/kernels/encode_proto_op.cc591
-rw-r--r--tensorflow/core/kernels/rpc_op.cc129
-rw-r--r--tensorflow/core/ops/decode_proto_ops.cc67
-rw-r--r--tensorflow/core/ops/encode_proto_ops.cc49
-rw-r--r--tensorflow/core/ops/rpc_ops.cc81
-rw-r--r--tensorflow/core/util/proto/BUILD62
-rw-r--r--tensorflow/core/util/proto/decode.h592
-rw-r--r--tensorflow/core/util/proto/descriptor_pool_registry.cc45
-rw-r--r--tensorflow/core/util/proto/descriptor_pool_registry.h76
-rw-r--r--tensorflow/core/util/proto/descriptor_pool_registry_test.cc43
-rw-r--r--tensorflow/core/util/proto/descriptors.cc85
-rw-r--r--tensorflow/core/util/proto/descriptors.h42
-rw-r--r--tensorflow/core/util/proto/local_descriptor_pool_registration.cc39
-rw-r--r--tensorflow/core/util/rpc/BUILD48
-rw-r--r--tensorflow/core/util/rpc/call_container.h90
-rw-r--r--tensorflow/core/util/rpc/rpc_factory.cc53
-rw-r--r--tensorflow/core/util/rpc/rpc_factory.h70
-rw-r--r--tensorflow/core/util/rpc/rpc_factory_registry.cc44
-rw-r--r--tensorflow/core/util/rpc/rpc_factory_registry.h72
-rw-r--r--tensorflow/core/util/rpc/rpc_factory_registry_test.cc41
-rw-r--r--tensorflow/python/BUILD1
45 files changed, 4394 insertions, 0 deletions
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 02c456c199..8e83b4e176 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -354,6 +354,9 @@ tensorflow/contrib/periodic_resample
tensorflow/contrib/periodic_resample/python
tensorflow/contrib/periodic_resample/python/ops
tensorflow/contrib/predictor
+tensorflow/contrib/proto
+tensorflow/contrib/proto/python
+tensorflow/contrib/proto/python/ops
tensorflow/contrib/quantization
tensorflow/contrib/quantization/python
tensorflow/contrib/quantize
@@ -382,6 +385,9 @@ tensorflow/contrib/rnn/ops
tensorflow/contrib/rnn/python
tensorflow/contrib/rnn/python/kernel_tests
tensorflow/contrib/rnn/python/ops
+tensorflow/contrib/rpc
+tensorflow/contrib/rpc/python
+tensorflow/contrib/rpc/python/ops
tensorflow/contrib/saved_model
tensorflow/contrib/saved_model/python
tensorflow/contrib/saved_model/python/saved_model
diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake
index 092a48bc6b..e558691de4 100644
--- a/tensorflow/contrib/cmake/tf_core_ops.cmake
+++ b/tensorflow/contrib/cmake/tf_core_ops.cmake
@@ -25,6 +25,8 @@ set(tf_op_lib_names
"cudnn_rnn_ops"
"data_flow_ops"
"dataset_ops"
+ "decode_proto_ops"
+ "encode_proto_ops"
"functional_ops"
"image_ops"
"io_ops"
@@ -40,6 +42,7 @@ set(tf_op_lib_names
"random_ops"
"remote_fused_graph_ops"
"resource_variable_ops"
+ "rpc_ops"
"script_ops"
"sdca_ops"
"set_ops"
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index fae45ead5c..1a5ec34844 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -330,6 +330,8 @@ GENERATE_PYTHON_OP_LIB("ctc_ops")
GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops")
GENERATE_PYTHON_OP_LIB("data_flow_ops")
GENERATE_PYTHON_OP_LIB("dataset_ops")
+GENERATE_PYTHON_OP_LIB("decode_proto_ops")
+GENERATE_PYTHON_OP_LIB("encode_proto_ops")
GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")
@@ -343,6 +345,7 @@ GENERATE_PYTHON_OP_LIB("random_ops")
GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops"
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py)
GENERATE_PYTHON_OP_LIB("resource_variable_ops")
+GENERATE_PYTHON_OP_LIB("rpc_ops")
GENERATE_PYTHON_OP_LIB("script_ops")
GENERATE_PYTHON_OP_LIB("sdca_ops")
GENERATE_PYTHON_OP_LIB("set_ops")
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index b6acf71b9d..0bc4c5d473 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -301,3 +301,5 @@ tensorflow/core/kernels/warn_about_ints.cc
tensorflow/core/kernels/segment_reduction_ops.cc
tensorflow/core/kernels/batch_util.cc
tensorflow/core/ops/audio_ops.cc
+tensorflow/core/kernels/decode_proto_op.cc
+tensorflow/core/kernels/encode_proto_op.cc
diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD
new file mode 100644
index 0000000000..046652cbc5
--- /dev/null
+++ b/tensorflow/contrib/proto/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "proto",
+ srcs = [
+ "__init__.py",
+ ],
+ deps = [
+ "//tensorflow/contrib/proto/python/ops:decode_proto_op_py",
+ "//tensorflow/contrib/proto/python/ops:encode_proto_op_py",
+ ],
+)
diff --git a/tensorflow/contrib/proto/__init__.py b/tensorflow/contrib/proto/__init__.py
new file mode 100644
index 0000000000..bc5a49de78
--- /dev/null
+++ b/tensorflow/contrib/proto/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to proto.
+
+@@decode_proto
+@@encode_proto
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.proto.python.ops.decode_proto_op import decode_proto
+from tensorflow.contrib.proto.python.ops.encode_proto_op import encode_proto
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/proto/python/ops/BUILD b/tensorflow/contrib/proto/python/ops/BUILD
new file mode 100644
index 0000000000..f17065477e
--- /dev/null
+++ b/tensorflow/contrib/proto/python/ops/BUILD
@@ -0,0 +1,44 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+)
+
+py_library(
+ name = "decode_proto_op_py",
+ srcs = ["decode_proto_op.py"],
+ deps = [
+ ":gen_decode_proto_op_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_decode_proto_op_py",
+ out = "gen_decode_proto_op.py",
+ deps = [
+ "//tensorflow/core:decode_proto_ops_op_lib",
+ ],
+)
+
+py_library(
+ name = "encode_proto_op_py",
+ srcs = ["encode_proto_op.py"],
+ deps = [
+ ":gen_encode_proto_op_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_encode_proto_op_py",
+ out = "gen_encode_proto_op.py",
+ deps = [
+ "//tensorflow/core:encode_proto_ops_op_lib",
+ ],
+)
diff --git a/tensorflow/contrib/proto/python/ops/decode_proto_op.py b/tensorflow/contrib/proto/python/ops/decode_proto_op.py
new file mode 100644
index 0000000000..7dc000ebe4
--- /dev/null
+++ b/tensorflow/contrib/proto/python/ops/decode_proto_op.py
@@ -0,0 +1,25 @@
+# =============================================================================
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=wildcard-import,unused-import
+"""Protocol Buffer decoding from tensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.proto.python.ops.gen_decode_proto_op import decode_proto_v2 as decode_proto
+from tensorflow.python.framework import ops
+ops.NotDifferentiable("DecodeProtoV2")
diff --git a/tensorflow/contrib/proto/python/ops/encode_proto_op.py b/tensorflow/contrib/proto/python/ops/encode_proto_op.py
new file mode 100644
index 0000000000..ac12198b2e
--- /dev/null
+++ b/tensorflow/contrib/proto/python/ops/encode_proto_op.py
@@ -0,0 +1,25 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=wildcard-import,unused-import
+"""Protocol Buffer encoding from tensors."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.proto.python.ops.gen_encode_proto_op import encode_proto
+from tensorflow.python.framework import ops
+
+ops.NotDifferentiable("EncodeProto")
diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD
new file mode 100644
index 0000000000..597f18c771
--- /dev/null
+++ b/tensorflow/contrib/rpc/BUILD
@@ -0,0 +1,13 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "rpc",
+ srcs = [
+ "__init__.py",
+ ],
+ deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"],
+)
diff --git a/tensorflow/contrib/rpc/__init__.py b/tensorflow/contrib/rpc/__init__.py
new file mode 100644
index 0000000000..c65c1a05de
--- /dev/null
+++ b/tensorflow/contrib/rpc/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ops and modules related to RPC.
+
+@@rpc
+@@try_rpc
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rpc.python.ops.rpc_op import rpc
+from tensorflow.contrib.rpc.python.ops.rpc_op import try_rpc
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/rpc/python/ops/BUILD b/tensorflow/contrib/rpc/python/ops/BUILD
new file mode 100644
index 0000000000..84d2a1832f
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/ops/BUILD
@@ -0,0 +1,24 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+
+py_library(
+ name = "rpc_op_py",
+ srcs = ["rpc_op.py"],
+ deps = [
+ ":gen_rpc_op_py",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_rpc_op_py",
+ out = "gen_rpc_op.py",
+ deps = [
+ "//tensorflow/core:rpc_ops_op_lib",
+ ],
+)
diff --git a/tensorflow/contrib/rpc/python/ops/rpc_op.py b/tensorflow/contrib/rpc/python/ops/rpc_op.py
new file mode 100644
index 0000000000..e1b6c41137
--- /dev/null
+++ b/tensorflow/contrib/rpc/python/ops/rpc_op.py
@@ -0,0 +1,26 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+# pylint: disable=wildcard-import,unused-import
+"""RPC communication."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.rpc.python.ops.gen_rpc_op import rpc
+from tensorflow.contrib.rpc.python.ops.gen_rpc_op import try_rpc
+from tensorflow.python.framework import ops
+ops.NotDifferentiable("Rpc")
+ops.NotDifferentiable("TryRpc")
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 7d5ae1c5b5..1eebeb3995 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -637,6 +637,8 @@ tf_gen_op_libs(
"ctc_ops",
"data_flow_ops",
"dataset_ops",
+ "decode_proto_ops",
+ "encode_proto_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -653,6 +655,7 @@ tf_gen_op_libs(
"random_ops",
"remote_fused_graph_ops",
"resource_variable_ops",
+ "rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
"set_ops",
@@ -751,6 +754,8 @@ cc_library(
":cudnn_rnn_ops_op_lib",
":data_flow_ops_op_lib",
":dataset_ops_op_lib",
+ ":decode_proto_ops_op_lib",
+ ":encode_proto_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -767,6 +772,7 @@ cc_library(
":random_ops_op_lib",
":remote_fused_graph_ops_op_lib",
":resource_variable_ops_op_lib",
+ ":rpc_ops_op_lib",
":scoped_allocator_ops_op_lib",
":script_ops_op_lib",
":sdca_ops_op_lib",
@@ -893,6 +899,8 @@ cc_library(
"//tensorflow/core/kernels:cudnn_rnn_kernels",
"//tensorflow/core/kernels:data_flow",
"//tensorflow/core/kernels:dataset_ops",
+ "//tensorflow/core/kernels:decode_proto_op",
+ "//tensorflow/core/kernels:encode_proto_op",
"//tensorflow/core/kernels:fake_quant_ops",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:functional_ops",
@@ -914,6 +922,7 @@ cc_library(
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:required",
"//tensorflow/core/kernels:resource_variable_ops",
+ "//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
"//tensorflow/core/kernels:set_kernels",
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
new file mode 100644
index 0000000000..c8152f53c4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
@@ -0,0 +1,116 @@
+op {
+ graph_op_name: "DecodeProtoV2"
+ in_arg {
+ name: "bytes"
+ description: <<END
+Tensor of serialized protos with shape `batch_shape`.
+END
+ }
+ out_arg {
+ name: "sizes"
+ description: <<END
+Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+Each entry is the number of values found for the corresponding field.
+Optional fields may have 0 or 1 values.
+END
+ }
+ out_arg {
+ name: "values"
+ description: <<END
+List of tensors containing values for the corresponding field.
+`values[i]` has datatype `output_types[i]`
+and shape `[batch_shape, max(sizes[...,i])]`.
+END
+ }
+ attr {
+ name: "message_type"
+ description: <<END
+Name of the proto message type to decode.
+END
+ }
+ attr {
+ name: "field_names"
+ description: <<END
+List of strings containing proto field names.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+List of TF types to use for the respective field in field_names.
+END
+ }
+ attr {
+ name: "descriptor_source"
+ description: <<END
+Either the special value `local://` or a path to a file containing
+a serialized `FileDescriptorSet`.
+END
+ }
+ attr {
+ name: "message_format"
+ description: <<END
+Either `binary` or `text`.
+END
+ }
+ attr {
+ name: "sanitize"
+ description: <<END
+Whether to sanitize the result or not.
+END
+ }
+ summary: <<END
+The op extracts fields from a serialized protocol buffers message into tensors.
+END
+ description: <<END
+The `decode_proto` op extracts fields from a serialized protocol buffers
+message into tensors. The fields in `field_names` are decoded and converted
+to the corresponding `output_types` if possible.
+
+A `message_type` name must be provided to give context for the field
+names. The actual message descriptor can be looked up either in the
+linked-in descriptor pool or a filename provided by the caller using
+the `descriptor_source` attribute.
+
+Each output tensor is a dense tensor. This means that it is padded to
+hold the largest number of repeated elements seen in the input
+minibatch. (The shape is also padded by one to prevent zero-sized
+dimensions). The actual repeat counts for each example in the
+minibatch can be found in the `sizes` output. In many cases the output
+of `decode_proto` is fed immediately into tf.squeeze if missing values
+are not a concern. When using tf.squeeze, always pass the squeeze
+dimension explicitly to avoid surprises.
+
+For the most part, the mapping between Proto field types and
+TensorFlow dtypes is straightforward. However, there are a few
+special cases:
+
+- A proto field that contains a submessage or group can only be converted
+to `DT_STRING` (the serialized submessage). This is to reduce the
+complexity of the API. The resulting string can be used as input
+to another instance of the decode_proto op.
+
+- TensorFlow lacks support for unsigned integers. The ops represent uint64
+types as a `DT_INT64` with the same twos-complement bit pattern
+(the obvious way). Unsigned int32 values can be represented exactly by
+specifying type `DT_INT64`, or using twos-complement if the caller
+specifies `DT_INT32` in the `output_types` attribute.
+
+The `descriptor_source` attribute selects a source of protocol
+descriptors to consult when looking up `message_type`. This may be a
+filename containing a serialized `FileDescriptorSet` message,
+or the special value `local://`, in which case only descriptors linked
+into the code will be searched; the filename can be on any filesystem
+accessible to TensorFlow.
+
+You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+The `local://` database only covers descriptors linked into the
+code via C++ libraries, not Python imports. You can link in a proto descriptor
+by creating a cc_library target with alwayslink=1.
+
+Both binary and text proto serializations are supported, and can be
+chosen using the `message_format` attribute.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
new file mode 100644
index 0000000000..fdbe47f236
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt
@@ -0,0 +1,81 @@
+op {
+ graph_op_name: "EncodeProto"
+ in_arg {
+ name: "sizes"
+ description: <<END
+Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+List of tensors containing values for the corresponding field.
+END
+ }
+ out_arg {
+ name: "bytes"
+ description: <<END
+Tensor of serialized protos with shape `batch_shape`.
+END
+ }
+ attr {
+ name: "message_type"
+ description: <<END
+Name of the proto message type to encode.
+END
+ }
+ attr {
+ name: "field_names"
+ description: <<END
+List of strings containing proto field names.
+END
+ }
+ attr {
+ name: "Tinput_types"
+ description: <<END
+The input types.
+END
+ }
+ summary: <<END
+The op serializes protobuf messages provided in the input tensors.
+END
+ description: <<END
+The types of the tensors in `values` must match the schema for the
+fields specified in `field_names`. All the tensors in `values` must
+have a common shape prefix, *batch_shape*.
+
+The `sizes` tensor specifies repeat counts for each field. The repeat
+count (last dimension) of each tensor in `values` must be greater
+than or equal to the corresponding repeat count in `sizes`.
+
+A `message_type` name must be provided to give context for the field
+names. The actual message descriptor can be looked up either in the
+linked-in descriptor pool or a filename provided by the caller using
+the `descriptor_source` attribute.
+
+The `descriptor_source` attribute selects a source of protocol
+descriptors to consult when looking up `message_type`. This may be a
+filename containing a serialized `FileDescriptorSet` message,
+or the special value `local://`, in which case only descriptors linked
+into the code will be searched; the filename can be on any filesystem
+accessible to TensorFlow.
+
+You can build a `descriptor_source` file using the `--descriptor_set_out`
+and `--include_imports` options to the protocol compiler `protoc`.
+
+The `local://` database only covers descriptors linked into the
+code via C++ libraries, not Python imports. You can link in a proto descriptor
+by creating a cc_library target with alwayslink=1.
+
+There are a few special cases in the value mapping:
+
+Submessage and group fields must be pre-serialized as TensorFlow strings.
+
+TensorFlow lacks support for unsigned int64s, so they must be
+represented as `tf.int64` with the same twos-complement bit pattern
+(the obvious way).
+
+Unsigned int32 values can be represented exactly with `tf.int64`, or
+with sign wrapping if the input is of type `tf.int32`.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt
new file mode 100644
index 0000000000..344ef191fd
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt
@@ -0,0 +1,108 @@
+op {
+ graph_op_name: "Rpc"
+ in_arg {
+ name: "address"
+ description: <<END
+`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `method` and `request`.
+END
+ }
+ in_arg {
+ name: "method"
+ description: <<END
+`0-D` or `1-D`. The method address on the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `request`.
+END
+ }
+ in_arg {
+ name: "request"
+ description: <<END
+`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `method`.
+END
+ }
+ out_arg {
+ name: "response"
+ description: <<END
+Same shape as `request`. Serialized proto strings: the rpc responses.
+END
+ }
+ attr {
+ name: "protocol"
+ description: <<END
+RPC protocol to use. Empty string means use the default protocol.
+Options include 'grpc'.
+END
+ }
+ attr {
+ name: "fail_fast"
+ description: <<END
+`boolean`. If `true` (default), then failures to connect
+(i.e., the server does not immediately respond) cause an RPC failure.
+END
+ }
+ attr {
+ name: "timeout_in_ms"
+ description: <<END
+`int`. If `0` (default), then the kernel will run the RPC
+request and only time out if the RPC deadline passes or the session times out.
+If this value is greater than `0`, then the op will raise an exception if
+the RPC takes longer than `timeout_in_ms`.
+END
+ }
+ summary: <<END
+Perform batches of RPC requests.
+END
+ description: <<END
+This op asynchronously performs either a single RPC request, or a batch
+of requests. RPC requests are defined by three main parameters:
+
+ - `address` (the host+port or BNS address of the request)
+ - `method` (the RPC method name for the request)
+ - `request` (the serialized proto string, or vector of strings,
+ of the RPC request argument).
+
+For example, if you have an RPC service running on port localhost:2345,
+and its interface is configured with the following proto declaration:
+
+```
+service MyService {
+ rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+ }
+};
+```
+
+then call this op with arguments:
+
+```
+address = "localhost:2345"
+method = "MyService/MyMethod"
+```
+
+The `request` tensor is a string tensor representing serialized `MyRequestProto`
+strings; and the output string tensor `response` will have the same shape
+and contain (upon successful completion) corresponding serialized
+`MyResponseProto` strings.
+
+For example, to send a single, empty, `MyRequestProto`, call
+this op with `request = ""`. To send 5 **parallel** empty requests,
+call this op with `request = ["", "", "", "", ""]`.
+
+More generally, one can create a batch of `MyRequestProto` serialized protos
+from regular batched tensors using the `encode_proto` op, and convert
+the response `MyResponseProto` serialized protos to batched tensors
+using the `decode_proto` op.
+
+**NOTE** Working with serialized proto strings is faster than instantiating
+actual proto objects in memory, so no performance degradation is expected
+compared to writing custom kernels for this workflow.
+
+If the connection fails or the remote worker returns an error
+status, the op reraises this exception locally.
+
+See the `TryRpc` op if you prefer to handle RPC failures manually in the graph.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt
new file mode 100644
index 0000000000..bded00e83c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt
@@ -0,0 +1,123 @@
+op {
+ graph_op_name: "TryRpc"
+ in_arg {
+ name: "address"
+ description: <<END
+`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `method` and `request`.
+END
+ }
+ in_arg {
+ name: "method"
+ description: <<END
+`0-D` or `1-D`. The method address on the RPC server.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `request`.
+END
+ }
+ in_arg {
+ name: "request"
+ description: <<END
+`0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+If this tensor has more than 1 element, then multiple parallel rpc requests
+are sent. This argument broadcasts with `address` and `method`.
+END
+ }
+ out_arg {
+ name: "response"
+ description: <<END
+Same shape as `request`. Serialized proto strings: the rpc responses.
+END
+ }
+ out_arg {
+ name: "status_code"
+ description: <<END
+Same shape as `request`. Values correspond to tensorflow Status enum codes.
+END
+ }
+ out_arg {
+ name: "status_message"
+ description: <<END
+Same shape as `request`. Values correspond to Status messages
+returned from the RPC calls.
+END
+ }
+ attr {
+ name: "protocol"
+ description: <<END
+RPC protocol to use. Empty string means use the default protocol.
+Options include 'grpc'.
+END
+ }
+ attr {
+ name: "fail_fast"
+ description: <<END
+`boolean`. If `true` (default), then failures to connect
+(i.e., the server does not immediately respond) cause an RPC failure.
+END
+ }
+ attr {
+ name: "timeout_in_ms"
+ description: <<END
+`int`. If `0` (default), then the kernel will run the RPC
+request and only time out if the RPC deadline passes or the session times out.
+If this value is greater than `0`, then the op will raise an exception if
+the RPC takes longer than `timeout_in_ms`.
+END
+ }
+ summary: <<END
+Perform batches of RPC requests.
+END
+ description: <<END
+This op asynchronously performs either a single RPC request, or a batch
+of requests. RPC requests are defined by three main parameters:
+
+ - `address` (the host+port or BNS address of the request)
+ - `method` (the method name for the request)
+ - `request` (the serialized proto string, or vector of strings,
+ of the RPC request argument).
+
+For example, if you have an RPC service running on port localhost:2345,
+and its interface is configured with the following proto declaration:
+
+```
+service MyService {
+ rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+ }
+};
+```
+
+then call this op with arguments:
+
+```
+address = "localhost:2345"
+method = "MyService/MyMethod"
+```
+
+The `request` tensor is a string tensor representing serialized `MyRequestProto`
+strings; and the output string tensor `response` will have the same shape
+and contain (upon successful completion) corresponding serialized
+`MyResponseProto` strings.
+
+For example, to send a single, empty, `MyRequestProto`, call
+this op with `request = ""`. To send 5 **parallel** empty requests,
+call this op with `request = ["", "", "", "", ""]`.
+
+More generally, one can create a batch of `MyRequestProto` serialized protos
+from regular batched tensors using the `encode_proto` op, and convert
+the response `MyResponseProto` serialized protos to batched tensors
+using the `decode_proto` op.
+
+**NOTE** Working with serialized proto strings is faster than instantiating
+actual proto objects in memory, so no performance degradation is expected
+compared to writing custom kernels for this workflow.
+
+Unlike the standard `Rpc` op, if the connection fails or the remote worker
+returns an error status, this op does **not** reraise the exception.
+Instead, the `status_code` and `status_message` entry for the corresponding RPC
+call is set with the error returned from the RPC call. The `response` tensor
+will contain valid response values for those minibatch entries whose RPCs did
+not fail; the rest of the entries will have empty strings.
+END
+}
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index 9c655bfa31..d3478dfc38 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -499,3 +499,33 @@ tf_cuda_cc_test(
"//tensorflow/core/kernels:variable_ops",
],
)
+
+# RPCFactory implementation that issues Rpc/TryRpc op calls through
+# gRPC generic stubs (see grpc_rpc_factory.h).
+cc_library(
+    name = "grpc_rpc_factory",
+    srcs = [
+        "grpc_rpc_factory.cc",
+    ],
+    hdrs = ["grpc_rpc_factory.h"],
+    deps = [
+        ":grpc_state",
+        ":grpc_util",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core/util/rpc:call_container",
+        "//tensorflow/core/util/rpc:rpc_factory",
+    ],
+)
+
+# Registers the gRPC factory in the RPC factory registry via a static
+# initializer; alwayslink keeps the registration from being dropped by
+# the linker when nothing references its symbols directly.
+cc_library(
+    name = "grpc_rpc_factory_registration",
+    srcs = [
+        "grpc_rpc_factory_registration.cc",
+    ],
+    deps = [
+        ":grpc_rpc_factory",
+        "//tensorflow/core/util/rpc:rpc_factory",
+        "//tensorflow/core/util/rpc:rpc_factory_registry",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
new file mode 100644
index 0000000000..d004abd1c1
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -0,0 +1,213 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/util/rpc/call_container.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
+
+namespace tensorflow {
+
+namespace {
+// Tracks a single in-flight gRPC call made on behalf of one minibatch
+// entry. The request/response/status pointers alias rows of the op's
+// input/output tensors; the CallContainer keeps those tensors alive
+// until every call has completed.
+class GrpcCall {
+ public:
+  explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
+                    const string* request_msg, string* response_msg,
+                    int32* status_code, string* status_message)
+      : container_(container),
+        index_(index),
+        try_rpc_(try_rpc),
+        request_msg_(request_msg),
+        response_msg_(response_msg),
+        status_code_(status_code),
+        status_message_(status_message) {}
+
+  // Asks gRPC to cancel the in-flight call; completion is still
+  // reported through Done().
+  void StartCancel() { call_opts_.StartCancel(); }
+
+  // Invoked exactly once when the RPC completes. For TryRpc
+  // (try_rpc_ == true) a non-OK status is recorded into the per-entry
+  // status outputs instead of failing the whole op.
+  void Done(const Status& s) {
+    DCHECK(container_ != nullptr);
+    if (!s.ok() && try_rpc_) {
+      DCHECK(status_code_ != nullptr);
+      DCHECK(status_message_ != nullptr);
+      *status_code_ = s.code();
+      *status_message_ = s.error_message();
+    }
+    container_->Done(s, index_);
+  }
+
+  const string& request() const { return *request_msg_; }
+  string* response() const { return response_msg_; }
+  CallOptions* call_opts() { return &call_opts_; }
+
+ private:
+  CallContainer<GrpcCall>* const container_;
+  const int index_;
+  bool try_rpc_;
+  CallOptions call_opts_;
+  const string* request_msg_;
+  string* response_msg_;
+  // int32, not int: matches the constructor parameter and the int32
+  // output-tensor element this pointer aliases (the previous `int*`
+  // declaration only happened to match on platforms where int == int32).
+  int32* status_code_;
+  string* status_message_;
+};
+
+} // namespace
+
+GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
+                               int64 timeout_in_ms)
+    : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) {
+  // TODO(ebrevdo): Investigate possible performance improvements by
+  // replacing this thread with a threadpool.
+  //
+  // Drain the completion queue on one dedicated thread: every completed
+  // event carries a GrpcClientCQTag whose OnCompleted() finishes the
+  // corresponding call. Next() returns false only after Shutdown() is
+  // called on the queue, which ends the loop and lets the thread exit.
+  auto poll_completion_queue = [this]() {
+    void* raw_tag;
+    bool event_ok;
+    while (completion_queue_.Next(&raw_tag, &event_ok)) {
+      static_cast<GrpcClientCQTag*>(raw_tag)->OnCompleted(event_ok);
+    }
+  };
+  polling_thread_ = ctx->env()->StartThread(
+      ThreadOptions(), "rpc_op_grpc_factory", std::move(poll_completion_queue));
+}
+
+GrpcRPCFactory::~GrpcRPCFactory() {
+  // The amount of time we wait depends on several parameters, including:
+  //   - the value of the fail_fast attribute.
+  //   - the timeout option of the rpc call in the proto declaration.
+  //   - the network roundtrip time and service's execution time.
+  //
+  // If a connection is made but the service doesn't ever respond, and
+  // there is no timeout option set for this rpc call, then it is
+  // possible the RPC request will wait forever.
+  //
+  // Shutdown() makes completion_queue_.Next() return false once the
+  // queue drains, which ends the polling thread's loop; deleting the
+  // Thread then blocks until that function has returned.
+  completion_queue_.Shutdown();
+  delete polling_thread_;
+}
+
+// Launches `num_elements` asynchronous RPCs, one per minibatch entry.
+// Scalar (size-1) address/method/request tensors are broadcast across
+// all entries. Completion is reported through `done` by the
+// self-deleting CallContainer once every call finishes.
+void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
+                          const Tensor& address_t, const Tensor& method_t,
+                          const Tensor& request_t, const bool try_rpc,
+                          Tensor* response_t, Tensor* status_code_t,
+                          Tensor* status_message_t,
+                          AsyncOpKernel::DoneCallback done) {
+  auto address = address_t.flat<string>();
+  auto method = method_t.flat<string>();
+  auto request = request_t.flat<string>();
+
+  // Stubs are maintained by the GrpcRPCFactory class and will be
+  // deleted when the class is destroyed.
+  ::grpc::GenericStub* singleton_stub = nullptr;
+  if (address.size() == 1) {
+    // Single shared address: resolve the stub once instead of per entry.
+    singleton_stub = GetOrCreateStubForAddress(address(0));
+  }
+  auto get_stub = [&address, this,
+                   singleton_stub](int64 ix) -> ::grpc::GenericStub* {
+    return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
+                                : singleton_stub;
+  };
+  auto get_method_ptr = [&method](int64 ix) -> const string* {
+    return (method.size() > 1) ? &(method(ix)) : &(method(0));
+  };
+  auto get_request_ptr = [&request](int64 ix) -> const string* {
+    return (request.size() > 1) ? &(request(ix)) : &(request(0));
+  };
+
+  if (try_rpc) {
+    // In this case status_code will never be set in the response,
+    // so we just set it to OK.
+    DCHECK(status_code_t != nullptr);
+    status_code_t->flat<int32>().setConstant(
+        static_cast<int>(errors::Code::OK));
+  }
+
+  CancellationManager* cm = ctx->cancellation_manager();
+  CancellationToken cancellation_token = cm->get_cancellation_token();
+
+  // This object will delete itself when done.
+  auto* container =
+      new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
+                                  std::move(done), cancellation_token);
+
+  auto response = response_t->flat<string>();
+  int32* status_code_ptr = nullptr;
+  string* status_message_ptr = nullptr;
+  if (try_rpc) {
+    // Raw pointers into the (preallocated) status output tensors; each
+    // GrpcCall writes its own entry, so there is no sharing between calls.
+    status_code_ptr = status_code_t->flat<int32>().data();
+    status_message_ptr = status_message_t->flat<string>().data();
+  }
+  for (int i = 0; i < num_elements; ++i) {
+    container->calls()->emplace_back(
+        container, i, try_rpc, get_request_ptr(i), &response(i),
+        (try_rpc) ? &status_code_ptr[i] : nullptr,
+        (try_rpc) ? &status_message_ptr[i] : nullptr);
+  }
+
+  int i = 0;
+  for (GrpcCall& call : *(container->calls())) {
+    // This object will delete itself when done.
+    // NOTE: capturing `call` by reference is safe because the calls
+    // vector is fully populated above and never resized afterwards.
+    new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
+                         call.request(), call.response(),
+                         /*done=*/[&call](const Status& s) { call.Done(s); },
+                         call.call_opts(), fail_fast_, timeout_in_ms_);
+    ++i;
+  }
+
+  // Need to register this callback after all the RPCs are in
+  // flight; otherwise we may try to cancel an RPC *before* it
+  // launches, which is a no-op, and then fall into a deadlock.
+  bool is_cancelled = !cm->RegisterCallback(
+      cancellation_token, [container]() { container->StartCancel(); });
+
+  if (is_cancelled) {
+    ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
+    // container's reference counter will take care of calling done().
+    container->StartCancel();
+  }
+}
+
+// Returns the cached generic stub for `address`, creating (and caching)
+// one on first use. Thread-safe; stubs live until the factory is
+// destroyed.
+::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
+    const string& address) {
+  mutex_lock lock(mu_);
+
+  StubPtr& slot = stubs_[address];
+  if (slot == nullptr) {
+    slot.reset(new ::grpc::GenericStub(CreateChannelForAddress(address)));
+  }
+  return slot.get();
+}
+
+// Builds an insecure gRPC channel to `address` with op-friendly channel
+// arguments. Virtual so subclasses can substitute their own transport.
+GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
+    const string& address) {
+  ::grpc::ChannelArguments channel_args;
+  // Remove the default cap on message size.
+  channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
+                      std::numeric_limits<int32>::max());
+
+  // Set a standard backoff timeout of 1s instead of the
+  // (sometimes default) 20s.
+  channel_args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
+
+  return ::grpc::CreateCustomChannel(/*target=*/address,
+                                     ::grpc::InsecureChannelCredentials(),
+                                     channel_args);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
new file mode 100644
index 0000000000..34ec235aaf
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+// RPCFactory backed by gRPC generic stubs. One instance owns a
+// completion queue, a polling thread that drains it, and a cache of
+// per-address stubs.
+class GrpcRPCFactory : public RPCFactory {
+ public:
+  explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
+                          int64 timeout_in_ms);
+
+  // Explicit destructor to control destruction order.
+  ~GrpcRPCFactory() override;
+
+  // Launches num_elements asynchronous RPCs; `done` is invoked once all
+  // of them complete (see grpc_rpc_factory.cc for details).
+  void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t,
+            const Tensor& method_t, const Tensor& request_t, const bool try_rpc,
+            Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t,
+            AsyncOpKernel::DoneCallback done) override;
+
+ protected:
+  typedef std::shared_ptr<::grpc::Channel> ChannelPtr;
+  // Overridable channel construction hook for subclasses/tests.
+  virtual ChannelPtr CreateChannelForAddress(const string& address);
+
+ private:
+  // Returns a cached stub for `address`, creating one on first use.
+  ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
+
+  bool fail_fast_;
+  int64 timeout_in_ms_;
+  ::grpc::CompletionQueue completion_queue_;
+  Thread* polling_thread_;  // Owned.
+
+  mutex mu_;
+  typedef std::unique_ptr<::grpc::GenericStub> StubPtr;
+  std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc
new file mode 100644
index 0000000000..b884489378
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc
@@ -0,0 +1,34 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+namespace {
+
+// Used for adding the grpc factory to the RPC factory registry.
+// Used for adding the grpc factory to the RPC factory registry.
+struct Value {
+  // Matches the registry's factory-function signature; the caller takes
+  // ownership of the returned GrpcRPCFactory.
+  static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
+                              int64 timeout_in_ms) {
+    return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms);
+  }
+};
+
+// Registers the factory under the "grpc" protocol name at static-init time.
+REGISTER_RPC_FACTORY("grpc", Value::Function);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1857d8d655..783de6af88 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5121,6 +5121,9 @@ filegroup(
"summary_interface.*",
"summary_kernels.*",
"spectrogram_convert_test_data.cc",
+ "decode_proto_op.cc",
+ "encode_proto_op.cc",
+ "rpc_op.cc",
# Excluded due to experimental status:
"debug_ops.*",
"scatter_nd_op*",
@@ -6153,6 +6156,50 @@ tf_kernel_library(
],
)
+# Kernel that parses serialized protos into tensors via WireFormatLite.
+tf_kernel_library(
+    name = "decode_proto_op",
+    srcs = [
+        "decode_proto_op.cc",
+    ],
+    deps = [
+        "//tensorflow/core:decode_proto_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/util/proto:decode",
+        "//tensorflow/core/util/proto:descriptors",
+        "//third_party/eigen3",
+    ],
+)
+
+# Kernel that serializes batched tensors into proto wire format.
+tf_kernel_library(
+    name = "encode_proto_op",
+    srcs = ["encode_proto_op.cc"],
+    deps = [
+        "//tensorflow/core:encode_proto_ops_op_lib",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/util/proto:descriptors",
+        "//third_party/eigen3",
+    ],
+)
+
+# Rpc/TryRpc kernels; actual transports plug in via the RPC factory
+# registry (e.g. the grpc factory in distributed_runtime/rpc).
+tf_kernel_library(
+    name = "rpc_op",
+    srcs = [
+        "rpc_op.cc",
+    ],
+    deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:rpc_ops_op_lib",
+        "//tensorflow/core/util/rpc:call_container",
+        "//tensorflow/core/util/rpc:rpc_factory",
+        "//tensorflow/core/util/rpc:rpc_factory_registry",
+        "//third_party/eigen3",
+    ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc
new file mode 100644
index 0000000000..b4e5b776ed
--- /dev/null
+++ b/tensorflow/core/kernels/decode_proto_op.cc
@@ -0,0 +1,1011 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// DecodeProto is a TensorFlow Op which extracts arbitrary fields
+// from protos serialized as strings.
+//
+// See docs in ../ops/decode_proto_op.cc.
+//
+// This implementation reads the serialized format using a handful of
+// calls from the WireFormatLite API used by generated proto code.
+// WireFormatLite is marked as an "internal" proto API but is widely
+// used in practice and highly unlikely to change.
+// This will be much faster than the previous implementation based on
+// constructing a temporary dynamic message in memory and using the
+// proto reflection api to read it.
+// It can be used with any proto whose descriptors are available at
+// runtime but should be competitive in speed with approaches that
+// compile in the proto definitions.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/proto/decode.h"
+#include "tensorflow/core/util/proto/descriptors.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::MakeUnique;
+using ::tensorflow::protobuf::Descriptor;
+using ::tensorflow::protobuf::DescriptorPool;
+using ::tensorflow::protobuf::DynamicMessageFactory;
+using ::tensorflow::protobuf::FieldDescriptor;
+using ::tensorflow::protobuf::Message;
+using ::tensorflow::protobuf::TextFormat;
+using ::tensorflow::protobuf::internal::WireFormatLite;
+using ::tensorflow::protobuf::io::CodedInputStream;
+
+const bool kFailOnDecodeError = true;
+
+// Returns true if the proto field type can be converted to the
+// tensorflow::DataType.
+// Returns true if the proto field type can be converted to the
+// tensorflow::DataType.
+bool CheckOutputType(FieldDescriptor::Type field_type, DataType output_type) {
+  // WireFormatLite::FieldType mirrors FieldDescriptor::Type
+  // value-for-value, so this cast is safe. Switching on the
+  // WireFormatLite enum (instead of mixing the two enum types) lets the
+  // compiler statically verify that every case below is covered.
+  switch (static_cast<WireFormatLite::FieldType>(field_type)) {
+    case WireFormatLite::TYPE_DOUBLE:
+      return output_type == tensorflow::DT_DOUBLE;
+    case WireFormatLite::TYPE_FLOAT:
+      return output_type == tensorflow::DT_FLOAT ||
+             output_type == tensorflow::DT_DOUBLE;
+    case WireFormatLite::TYPE_INT64:
+      return output_type == tensorflow::DT_INT64;
+    case WireFormatLite::TYPE_UINT64:
+      return output_type == tensorflow::DT_INT64;
+    case WireFormatLite::TYPE_INT32:
+      return output_type == tensorflow::DT_INT32;
+    case WireFormatLite::TYPE_FIXED64:
+      return output_type == tensorflow::DT_INT64;
+    case WireFormatLite::TYPE_FIXED32:
+      return output_type == tensorflow::DT_INT32 ||
+             output_type == tensorflow::DT_INT64;
+    case WireFormatLite::TYPE_BOOL:
+      return output_type == tensorflow::DT_BOOL;
+    case WireFormatLite::TYPE_STRING:
+      return output_type == tensorflow::DT_STRING;
+    case WireFormatLite::TYPE_GROUP:
+      return output_type == tensorflow::DT_STRING;
+    case WireFormatLite::TYPE_MESSAGE:
+      return output_type == tensorflow::DT_STRING;
+    case WireFormatLite::TYPE_BYTES:
+      return output_type == tensorflow::DT_STRING;
+    case WireFormatLite::TYPE_UINT32:
+      return output_type == tensorflow::DT_INT32 ||
+             output_type == tensorflow::DT_INT64;
+    case WireFormatLite::TYPE_ENUM:
+      return output_type == tensorflow::DT_INT32;
+    case WireFormatLite::TYPE_SFIXED32:
+      return output_type == tensorflow::DT_INT32;
+    case WireFormatLite::TYPE_SFIXED64:
+      return output_type == tensorflow::DT_INT64;
+    case WireFormatLite::TYPE_SINT32:
+      return output_type == tensorflow::DT_INT32;
+    case WireFormatLite::TYPE_SINT64:
+      return output_type == tensorflow::DT_INT64;
+    // default: intentionally omitted in order to enable static checking.
+  }
+  // Unreachable for any in-range enum value. Present to avoid undefined
+  // behavior (falling off the end of a value-returning function) if an
+  // out-of-range value is ever passed, and to satisfy -Wreturn-type.
+  return false;
+}
+
+// A FieldInfo holds a handful of information from the FieldDescriptor
+// and user attributes.
+// A FieldInfo caches the handful of FieldDescriptor properties the
+// parser needs, together with the user-specified output slot. The
+// cache exists because profiling showed hotspots in FieldDescriptor
+// accessor calls.
+struct FieldInfo {
+  FieldInfo(const FieldDescriptor* field_desc, int user_index)
+      : output_index(user_index),
+        // The wire format library defines the same constants used in
+        // descriptor.proto, so this static_cast is guaranteed to stay
+        // in sync. We need the schema-level field type here because the
+        // wire format alone cannot describe the contents of a packed
+        // repeated field: it carries enough information to skip the
+        // whole field, but not enough to parse what is inside. For that
+        // we go to the schema.
+        type(static_cast<WireFormatLite::FieldType>(field_desc->type())),
+        number(field_desc->number()),
+        is_repeated(field_desc->is_repeated()) {}
+
+  // Disable copy and move.
+  FieldInfo(const FieldInfo&) = delete;
+  FieldInfo& operator=(const FieldInfo&) = delete;
+
+  // Index into the field_names / output_types attributes and into the
+  // output tensor list. Internally fields are sorted by wire number for
+  // fast lookup, which in general differs from the user-supplied order.
+  int output_index = -1;
+
+  // Cached copies of the relevant `FieldDescriptorProto` data.
+  // (FieldDescriptor->type() alone accounted for ~6% of a CPU profile.)
+  WireFormatLite::FieldType type;
+  int number;
+  bool is_repeated;
+};
+
+// A CountCollector counts sizes of repeated and optional fields in a proto.
+//
+// Each field is tracked by a single CountCollector instance. The
+// instance manages a single count, which is stored as a pointer (it
+// is intended to be a reference to the `sizes` output which is being
+// filled in). The pointer is passed in at initialization.
+//
+// Counting is done as a separate pass in order to allocate output tensors
+// all at once. This allows the TensorFlow runtime to optimize allocation
+// for the consumer, while removing the need for copying inside this op.
+// After this pass, the DenseCollector class (below) gathers the data:
+// It is more complex and provides better motivation for the API here.
+// A CountCollector counts sizes of repeated and optional fields in a proto.
+//
+// Each field is tracked by a single CountCollector instance. The
+// instance manages a single count, which is stored as a pointer (it
+// is intended to be a reference to the `sizes` output which is being
+// filled in). The pointer is passed in at initialization.
+//
+// Counting is done as a separate pass in order to allocate output tensors
+// all at once. This allows the TensorFlow runtime to optimize allocation
+// for the consumer, while removing the need for copying inside this op.
+// After this pass, the DenseCollector class (below) gathers the data:
+// It is more complex and provides better motivation for the API here.
+class CountCollector {
+ public:
+  // Default constructor allows the collector to be a vector element.
+  CountCollector() = default;
+
+  // The count may be stored inside an Eigen Tensor to eliminate copying.
+  explicit CountCollector(int32* count) : count_ptr_(count) {}
+
+  // Reads (in this case counts) a single value.
+  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
+    // Only repeated fields can have count > 1; non-repeated fields are
+    // clamped to a single (last-wins) value.
+    if (*count_ptr_ == 0 || field.is_repeated) {
+      (*count_ptr_)++;
+    }
+    // We expect a wire type based on the schema field_type, to allow
+    // a little more checking.
+    if (!SkipValue(input, field)) {
+      return errors::DataLoss("ReadValue: Failed skipping field when counting");
+    }
+    return Status::OK();
+  }
+
+  // Reads (in this case counts) a length-delimited list of values.
+  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
+                          size_t buf_size) {
+    if (buf_size == 0) {
+      return Status::OK();
+    }
+
+    const void* tmpbuf;
+    int unused_max_buf_size;
+
+    input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
+    // This is safe because the underlying storage for the CodedInputStream is
+    // owned by the input tensor. If it were a Cord or file-backed stream this
+    // pointer would go stale after the bytes were skipped.
+    const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);
+
+    // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
+    // so the bounds check on buf_size inside Skip() is critical, and
+    // must be done before scanning the contents.
+    if (!input->Skip(buf_size)) {
+      return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
+    }
+
+    // Dispatch to the appropriately typed field reader based on the
+    // schema type. String-like and message types can never legally
+    // appear in a packed field, so they are reported as data loss.
+    Status st;
+    switch (field.type) {
+      case WireFormatLite::TYPE_DOUBLE:
+        st = CountPackedFixed<double>(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_FLOAT:
+        st = CountPackedFixed<float>(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_INT64:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_UINT64:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_INT32:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_FIXED64:
+        st = CountPackedFixed<uint64>(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_FIXED32:
+        st = CountPackedFixed<uint32>(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_BOOL:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_STRING:
+        st = errors::DataLoss("TYPE_STRING encountered as packed");
+        break;
+      case WireFormatLite::TYPE_GROUP:
+        st = errors::DataLoss("TYPE_GROUP encountered as packed");
+        break;
+      case WireFormatLite::TYPE_MESSAGE:
+        st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
+        break;
+      case WireFormatLite::TYPE_BYTES:
+        st = errors::DataLoss("TYPE_BYTES encountered as packed");
+        break;
+      case WireFormatLite::TYPE_UINT32:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_ENUM:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_SFIXED32:
+        st = CountPackedFixed<int32>(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_SFIXED64:
+        st = CountPackedFixed<int64>(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_SINT32:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      case WireFormatLite::TYPE_SINT64:
+        st = CountPackedVarint(buf, buf_size);
+        break;
+      // default: intentionally omitted in order to enable static checking.
+    }
+    if (!st.ok()) {
+      return st;
+    }
+
+    // Non-repeated fields keep at most one value (last-wins semantics).
+    if (!field.is_repeated && *count_ptr_ > 1) {
+      *count_ptr_ = 1;
+    }
+    return Status::OK();
+  }
+
+ private:
+  // Skips a length-delimited value.
+  static bool SkipBytes(CodedInputStream* input) {
+    uint32 length;
+    if (!input->ReadVarint32(&length)) {
+      return false;
+    }
+    return input->Skip(length);
+  }
+
+  // Counts the number of packed varints in an array.
+  // The end of a varint is signaled by a value < 0x80,
+  // so counting them requires parsing the bytestream.
+  // It is the caller's responsibility to ensure that len > 0.
+  Status CountPackedVarint(const uint8* buf, size_t len) {
+    const uint8* bound = buf + len;
+    int count;
+
+    // The last byte in a valid encoded varint is guaranteed to have
+    // the high bit unset. We rely on this property to prevent
+    // ReadVarint64FromArray from going out of bounds, so validate
+    // the end of the buf before scanning anything.
+    if (bound[-1] & 0x80) {
+      return errors::DataLoss("Corrupt packed varint");
+    }
+
+    // Now we can trust ReadVarint64FromArray to stay in bounds.
+    for (count = 0; buf < bound; ++count) {
+      uint64 temp;
+      bool ok;
+      buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
+      if (!ok) {
+        return errors::DataLoss("Corrupt packed varint");
+      }
+    }
+
+    *count_ptr_ += count;
+    return Status::OK();
+  }
+
+  // Counts the number of fixed-size values in a packed field.
+  // This can be done without actually parsing anything.
+  template <typename T>
+  Status CountPackedFixed(const uint8* unused_buf, size_t len) {
+    int count = len / sizeof(T);
+    if (count * sizeof(T) != len) {
+      return errors::DataLoss(
+          "Illegal data length for packed fixed-size type: ", len);
+    }
+    // Reuse `count` rather than recomputing len / sizeof(T).
+    *count_ptr_ += count;
+    return Status::OK();
+  }
+
+  // Skips a single value in the input stream.
+  // Dispatches to the appropriately typed field skipper based on the
+  // schema type tag.
+  // This is not as permissive as just handling the wire type.
+  static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
+    uint32 tmp32;
+    protobuf_uint64 tmp64;
+    switch (field.type) {
+      case WireFormatLite::TYPE_DOUBLE:
+        return input->ReadLittleEndian64(&tmp64);
+      case WireFormatLite::TYPE_FLOAT:
+        return input->ReadLittleEndian32(&tmp32);
+      case WireFormatLite::TYPE_INT64:
+        return input->ReadVarint64(&tmp64);
+      case WireFormatLite::TYPE_UINT64:
+        return input->ReadVarint64(&tmp64);
+      case WireFormatLite::TYPE_INT32:
+        return input->ReadVarint32(&tmp32);
+      case WireFormatLite::TYPE_FIXED64:
+        return input->ReadLittleEndian64(&tmp64);
+      case WireFormatLite::TYPE_FIXED32:
+        return input->ReadLittleEndian32(&tmp32);
+      case WireFormatLite::TYPE_BOOL:
+        return input->ReadVarint32(&tmp32);
+      case WireFormatLite::TYPE_STRING:
+        return SkipBytes(input);
+      case WireFormatLite::TYPE_GROUP:
+        return WireFormatLite::SkipField(
+            input, WireFormatLite::MakeTag(
+                       field.number, WireFormatLite::WIRETYPE_START_GROUP));
+      case WireFormatLite::TYPE_MESSAGE:
+        return SkipBytes(input);
+      case WireFormatLite::TYPE_BYTES:
+        return SkipBytes(input);
+      case WireFormatLite::TYPE_UINT32:
+        return input->ReadVarint32(&tmp32);
+      case WireFormatLite::TYPE_ENUM:
+        return input->ReadVarint32(&tmp32);
+      case WireFormatLite::TYPE_SFIXED32:
+        return input->ReadLittleEndian32(&tmp32);
+      case WireFormatLite::TYPE_SFIXED64:
+        return input->ReadLittleEndian64(&tmp64);
+      case WireFormatLite::TYPE_SINT32:
+        return input->ReadVarint32(&tmp32);
+      case WireFormatLite::TYPE_SINT64:
+        return input->ReadVarint64(&tmp64);
+      // default: intentionally omitted in order to enable static checking.
+    }
+    // Unreachable for any valid FieldType value; avoids undefined
+    // behavior (falling off the end of a value-returning function) if a
+    // corrupt type tag ever slips through, and satisfies -Wreturn-type.
+    return false;
+  }
+
+  int32* count_ptr_ = nullptr;
+};
+
+// A DenseCollector accumulates values from a proto into a tensor.
+//
+// There is an instance of DenseCollector for each field of each
+// proto. The DenseCollector deserializes the value from the wire
+// directly into the preallocated output Tensor.
+//
+// This class is named DenseCollector because in the future there should
+// be a SparseCollector that accumulates field data into sparse tensors if
+// the user requests it.
+class DenseCollector {
+ public:
+ // Default constructor allows the collector to be a vector element.
+ DenseCollector() = default;
+
+ // A DenseCollector applies to one field of a serialized message.
+ DenseCollector(uint8* datap, DataType dtype, int max_repeat_count)
+ : datap_(datap), dtype_(dtype), max_repeat_count_(max_repeat_count) {}
+
+ // Reads a value from the input stream and stores it.
+ //
+ // Always inlining gave a ~50% speedup on microbenchmarks at one point.
+ // TODO(nix): try removing it to see if that still holds.
+ // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
+ Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
+ // For required and optional fields, we overwrite values[0] with
+ // the latest one in the wire stream.
+ // See https://developers.google.com/protocol-buffers/docs/encoding#optional
+ // Only for repeated fields do we advance the next_repeat_index_ past 1.
+ // TODO(nix): to handle oneof we must also zero out any previous values
+ // seen on the wire.
+ int32 index = 0;
+ if (field.is_repeated) {
+ index = next_repeat_index_;
+ }
+ next_repeat_index_ = index + 1;
+
+ return internal::ReadValue(input, field.type, field.number, dtype_, index,
+ datap_);
+ }
+
+ // Reads and stores a length-delimited list of values.
+ Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
+ const size_t buf_size) {
+ const void* buf;
+ int unused_max_buf_size;
+ input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
+ // This is safe because the underlying storage for the CodedInputStream is
+ // owned by the input tensor. If it were a Cord or file-backed stream this
+ // pointer would go stale after the bytes were skipped.
+ if (!input->Skip(buf_size)) {
+ return errors::DataLoss(
+ "ReadPackedValues: Skipping packed field failed. Field tag: ",
+ field.number);
+ }
+
+ // Setting stride=0 causes new values to overwrite old ones for
+ // non-repeated fields.
+ const int stride = field.is_repeated ? 1 : 0;
+
+ if (next_repeat_index_ >= max_repeat_count_) {
+ return errors::DataLoss(
+ "ReadPackedValues: Tried to write more entries than allowed. "
+ "Field tag: ",
+ field.number, ", Max entries allowed: ", max_repeat_count_);
+ } else {
+ return internal::ReadPackedFromArray(buf, buf_size, field.type,
+ field.number, dtype_, stride,
+ &next_repeat_index_, datap_);
+ }
+ }
+
+ // Fills in any missing values in the output array with defaults.
+ // Dispatches to the appropriately typed field default based on the
+ // runtime type tag.
+ Status FillWithDefaults() {
+ switch (dtype_) {
+ case DataType::DT_FLOAT:
+ return FillDefault<float>();
+ case DataType::DT_DOUBLE:
+ return FillDefault<double>();
+ case DataType::DT_INT32:
+ return FillDefault<int32>();
+ case DataType::DT_UINT8:
+ return FillDefault<uint8>();
+ case DataType::DT_INT8:
+ return FillDefault<int8>();
+ case DataType::DT_STRING:
+ return FillDefault<string>();
+ case DataType::DT_INT64:
+ return FillDefault<int64>();
+ case DataType::DT_BOOL:
+ return FillDefault<bool>();
+ default:
+ // There are many tensorflow dtypes not handled here, but they
+ // should not come up unless type casting is added to the Op.
+ // Chaining with tf.cast() should do the right thing until then.
+ return errors::DataLoss(
+ "Failed filling defaults in unknown tf::DataType");
+ }
+ }
+
+ private:
+  // Value-initializes the tail of the output row — every slot in
+  // [next_repeat_index_, max_repeat_count_) that no wire value reached.
+  // next_repeat_index_ counts the values already parsed for this field.
+  template <class T>
+  Status FillDefault() {
+    T* const values = reinterpret_cast<T*>(datap_);
+    for (int slot = next_repeat_index_; slot < max_repeat_count_; ++slot) {
+      values[slot] = T();
+    }
+    return Status::OK();
+  }
+
+  // Number of values written so far for the current field; doubles as the
+  // next write position within the output row.
+  int32 next_repeat_index_ = 0;
+
+  // This is a pointer to data_[message_index_].
+  // There is no bounds checking at this level: we computed the max
+  // repeat size for each field in CountCollector and use the same
+  // code to traverse it here, so we are guaranteed not to be called
+  // for more items than we have allocated space.
+  void* const datap_ = nullptr;
+
+  // TensorFlow dtype of the output tensor backing datap_.
+  const DataType dtype_ = DataType::DT_INVALID;
+  // Capacity, in values, of the output row pointed to by datap_.
+  const int max_repeat_count_ = 0;
+};
+
+class DecodeProtoOp : public OpKernel {
+ public:
+  // Resolves the message descriptor and the requested fields, validates the
+  // field/output-type pairing, and caches a prototype message plus parsing
+  // options for use in Compute().
+  explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
+    string descriptor_source;
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("descriptor_source", &descriptor_source));
+
+    // We always get back a desc_pool, but we may not own it. If we own it,
+    // owned_desc_pool_ will be filled in.
+    DescriptorPool const* desc_pool;
+    OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
+                                              &desc_pool, &owned_desc_pool_));
+
+    string message_type;
+    OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
+    // Bug fix: message_type_ is used in decode-error LOG(WARNING) messages
+    // but was never assigned, so those logs printed an empty type name.
+    message_type_ = message_type;
+
+    const Descriptor* message_desc =
+        desc_pool->FindMessageTypeByName(message_type);
+    OP_REQUIRES(context, message_desc != nullptr,
+                errors::InvalidArgument("No descriptor found for message type ",
+                                        message_type));
+
+    std::vector<string> field_names;
+    OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
+    std::vector<DataType> output_types;
+    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
+    OP_REQUIRES(
+        context, field_names.size() == output_types.size(),
+        errors::InvalidArgument("field_names and output_types attributes must "
+                                "have the same length"));
+
+    // Gather the field descriptors and check that requested output types match.
+
+    int field_index = 0;
+    std::vector<const FieldDescriptor*> field_descs;
+    for (const string& name : field_names) {
+      auto fd = message_desc->FindFieldByName(name);
+      OP_REQUIRES(context, fd != nullptr,
+                  errors::InvalidArgument("Unknown field: ", name,
+                                          " in message type ", message_type));
+      OP_REQUIRES(context,
+                  CheckOutputType(fd->type(), output_types[field_index]),
+                  // Many TensorFlow types don't have corresponding proto types
+                  // and the user will get an error if they are requested. It
+                  // would be nice to allow conversions here, but tf.cast
+                  // already exists so we don't duplicate the functionality.
+                  // Known unhandled types:
+                  // DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
+                  // DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
+                  errors::InvalidArgument("Unexpected output type for ",
+                                          fd->full_name(), ": ", fd->cpp_type(),
+                                          " to ", output_types[field_index]));
+
+      field_index++;
+      field_descs.push_back(fd);
+    }
+
+    // Internally we want the field_descs sorted by their number on the wire.
+    // But the output tensors are allocated in the order given by the caller.
+    // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
+    std::vector<int> output_indices;
+    output_indices.reserve(field_names.size());
+    for (int i = 0; i < field_names.size(); i++) {
+      output_indices.push_back(i);
+    }
+    // Capture by reference: copying the descriptor vector into the closure
+    // was unnecessary.
+    std::sort(output_indices.begin(), output_indices.end(),
+              [&field_descs](int a, int b) {
+                return field_descs[a]->number() < field_descs[b]->number();
+              });
+
+    // Now store the fields in sorted order.
+    for (int i = 0; i < field_names.size(); i++) {
+      fields_.push_back(MakeUnique<FieldInfo>(field_descs[output_indices[i]],
+                                              output_indices[i]));
+    }
+
+    message_prototype_ = message_factory_.GetPrototype(message_desc);
+    OP_REQUIRES(context, message_prototype_ != nullptr,
+                errors::InvalidArgument("Couldn't get prototype message: ",
+                                        message_desc->full_name()));
+    string format;
+    OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
+    OP_REQUIRES(
+        context, format == "binary" || format == "text",
+        errors::InvalidArgument("format must be one of binary or text"));
+    is_binary_ = format == "binary";
+
+    // Enable the initial protobuf sanitizer, which is much
+    // more expensive than the decoder.
+    // TODO(nix): Remove this once the fast decoder
+    // has passed security review.
+    OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
+  }
+
+  // Decodes a batch of serialized protos. Output 0 is the per-message,
+  // per-field repeat-count tensor; outputs 1..N are the field value tensors,
+  // sized to the maximum repeat count observed across the batch.
+  void Compute(OpKernelContext* ctx) override {
+    const Tensor& buf_tensor = ctx->input(0);
+    int message_count = buf_tensor.NumElements();
+    OP_REQUIRES(ctx, message_count >= 1,
+                errors::InvalidArgument(
+                    "Bufs argument must contain at least one value"));
+
+    int field_count = fields_.size();
+
+    // Save the argument shape for later, then flatten the input
+    // Tensor since we are working componentwise. We will restore
+    // the same shape in the returned Tensor.
+    const TensorShape& shape_prefix = buf_tensor.shape();
+
+    TensorShape sizes_shape = shape_prefix;
+    sizes_shape.AddDim(field_count);
+    Tensor* sizes_tensor = nullptr;
+    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
+
+    // This is used to allocate binary bufs if used. It serves only
+    // to define memory ownership.
+    std::vector<string> tmp_binary_bufs(message_count);
+
+    // These are the actual buffers to use, which may be in tmp_binary_bufs
+    // or may be pointers into the buf_tensor. Either way they are not owned
+    // here.
+    std::vector<const string*> bufs;
+
+    if (is_binary_ && !sanitize_) {
+      // Fast path.
+      for (int mi = 0; mi < message_count; ++mi) {
+        const string* buf = &buf_tensor.flat<string>()(mi);
+        bufs.push_back(buf);
+      }
+    } else {
+      // We will have to allocate a copy, either to convert from text to
+      // binary or to sanitize a binary proto.
+      for (int mi = 0; mi < message_count; ++mi) {
+        ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
+                           &tmp_binary_bufs[mi]);
+        if (!ctx->status().ok()) {
+          return;
+        }
+        bufs.push_back(&tmp_binary_bufs[mi]);
+      }
+    }
+
+    // Walk through all the strings in the input tensor, counting
+    // the number of fields in each.
+    // We can't allocate our actual output Tensor until we know the
+    // maximum repeat count, so we do a first pass through the serialized
+    // proto just counting fields.
+    // We always allocate at least one value so that optional fields
+    // are populated with default values - this avoids a TF
+    // conditional when handling the output data.
+    // The caller can distinguish between real data and defaults
+    // using the repeat count matrix that is returned by decode_proto.
+    std::vector<int32> max_sizes(field_count, 1);
+    for (int mi = 0; mi < message_count; ++mi) {
+      CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
+      if (!ctx->status().ok()) {
+        return;
+      }
+    }
+
+    // Allocate the output tensors now that we've seen the max size.
+    // (Cleanup: an unused `flat_shape` local and a stale commented-out
+    // allocate_output call were removed here.)
+    // TODO(nix): Use allocate_output_or_forward_input for the largest
+    // output tensor. This can avoid one large allocation by re-using
+    // the memory of the input tensor.
+    std::vector<Tensor*> outputs(field_count);
+    for (int fi = 0; fi < field_count; ++fi) {
+      TensorShape out_shape = shape_prefix;
+      out_shape.AddDim(max_sizes[fi]);
+
+      // Surprisingly we don't specify the types from the output_types
+      // attribute: that is done for us based on the Op declaration:
+      //  REGISTER_OP(...)
+      //    .Attr("output_types: list(type) >= 0")
+      //    .Output("values: output_types")
+      OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
+                                               out_shape, &outputs[fi]));
+    }
+
+    // Make the second pass through the serialized proto, decoding
+    // into preallocated tensors.
+    AccumulateFields(ctx, bufs, outputs);
+  }
+
+ private:
+  // Re-encodes one input message as a trusted binary serialization.
+  // Used both to translate text-format inputs to binary and to sanitize
+  // untrusted binary inputs by round-tripping them through a full parser.
+  void ReserializeMessage(OpKernelContext* ctx, const string& buf,
+                          string* binary_buf) {
+    std::unique_ptr<Message> message(message_prototype_->New());
+    OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));
+
+    if (!is_binary_) {
+      // Translate a text-format proto into a message object.
+      OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
+                  errors::DataLoss("Unable to parse text protobuf"));
+    } else {
+      // Sanitize a binary proto by parsing it with the trusted (but very
+      // slow) protobuf library before re-serializing.
+      OP_REQUIRES(ctx, message->ParseFromString(buf),
+                  errors::DataLoss("Unable to parse binary protobuf"));
+    }
+
+    OP_REQUIRES(ctx, message->SerializeToString(binary_buf),
+                errors::DataLoss("Unable to reserialize text proto as binary"));
+  }
+
+  // Counts the occurrences of each requested field in one serialized
+  // message, writing the per-field counts into row `message_index` of
+  // sizes_tensor and updating the running per-field maxima in max_sizes.
+  void CountFields(OpKernelContext* ctx, int message_index, const string& buf,
+                   Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
+    int field_count = fields_.size();
+
+    CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
+                           buf.size());
+
+    std::vector<int32> field_sizes(field_count, 0);
+    std::vector<CountCollector> counters;
+    counters.reserve(field_count);
+    for (int i = 0; i < field_count; i++) {
+      counters.emplace_back(&field_sizes[i]);
+    }
+
+    Status st = Collect(&input, &counters);
+    if (st.ok() && !input.ConsumedEntireMessage()) {
+      st = errors::DataLoss("CountFields: Failed to consume entire buffer");
+    }
+    if (kFailOnDecodeError) {
+      OP_REQUIRES_OK(ctx, st);  // NOLINT
+    }
+    if (!st.ok()) {
+      // This code suppresses the corrupt proto, treating it as empty
+      // to avoid crashing the process.
+      LOG(WARNING) << "Proto counting error for message type " << message_type_
+                   << ": " << st;
+      // Bug fix: zero the counts and FALL THROUGH so the zeros are written
+      // into sizes_tensor below. The previous early return left this row of
+      // the sizes output uninitialized (allocate_output does not zero).
+      for (int fi = 0; fi < field_count; fi++) {
+        field_sizes[fi] = 0;
+      }
+    }
+
+    // Update the size tensor and max repeat size for each field.
+    auto sizes = sizes_tensor->flat_inner_dims<int32>();
+    for (int fi = 0; fi < field_count; fi++) {
+      int32 size = field_sizes[fi];
+      sizes(message_index, fields_[fi]->output_index) = size;
+      if ((*max_sizes)[fi] < size) {
+        (*max_sizes)[fi] = size;
+      }
+    }
+  }
+
+  // Parse fields from a serialized message into preallocated tensors.
+  // Second pass: bufs and outputs were sized by the CountFields pass.
+  void AccumulateFields(OpKernelContext* ctx,
+                        const std::vector<const string*>& bufs,
+                        std::vector<Tensor*> outputs) {
+    // Per-output-tensor addressing info so DenseCollector can write
+    // directly into the right row of the flat buffer.
+    struct TensorInfo {
+      explicit TensorInfo(Tensor* tensor) {
+        // Note that we can decode only max_repeat_count values before overflow.
+        // No other bounds checking is done for repeated fields. For
+        // optional fields there is a check to make sure that only the last
+        // value on the wire appears in the output tensor.
+        dtype = tensor->dtype();
+        last_dim_size = tensor->dim_size(tensor->dims() - 1);
+
+        if (dtype != DT_STRING) {
+          const int element_size = DataTypeSize(dtype);
+          CHECK_GT(element_size, 0);
+          stride = last_dim_size * element_size;
+
+          const int64 flatshape[1] = {tensor->NumElements() * element_size};
+          data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
+        } else {
+          // DataTypeSize() returns 0 for string types.
+          stride = last_dim_size * sizeof(string);
+          data = reinterpret_cast<uint8*>(tensor->flat<string>().data());
+        }
+      }
+
+      DataType dtype;
+      int last_dim_size;
+      int stride;  // bytes per message row
+      uint8* data;
+    };
+
+    int field_count = fields_.size();
+
+    std::vector<TensorInfo> tensors;
+    tensors.reserve(field_count);
+    for (int fi = 0; fi < field_count; fi++) {
+      tensors.emplace_back(outputs[fi]);
+    }
+
+    for (int message_index = 0; message_index < bufs.size(); ++message_index) {
+      const string& buf = *bufs[message_index];
+
+      std::vector<DenseCollector> collectors;
+      collectors.reserve(field_count);
+      for (const TensorInfo& info : tensors) {
+        collectors.emplace_back(info.data + message_index * info.stride,
+                                info.dtype, info.last_dim_size);
+      }
+
+      // Fill in output tensors from the wire.
+      CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
+                             buf.size());
+      Status st = Collect(&input, &collectors);
+      if (st.ok() && !input.ConsumedEntireMessage()) {
+        st = errors::DataLoss(
+            "AccumulateFields: Failed to consume entire buffer");
+      }
+      if (kFailOnDecodeError) {
+        OP_REQUIRES_OK(ctx, st);  // NOLINT
+      }
+      if (!st.ok()) {
+        // This code suppresses the corrupt proto, treating it as empty
+        // to avoid crashing training.
+        // Bug fix: the message previously said "counting error", a copy-paste
+        // from CountFields; this is the parsing pass.
+        LOG(WARNING) << "Proto parsing error for message type "
+                     << message_type_ << ": " << st;
+      }
+
+      // Fill the remainder of the dense outputs with default values.
+      for (auto& collector : collectors) {
+        OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
+      }
+    }
+  }
+
+  // Maps a wire field number to its index in fields_, returning false when
+  // the number is not one of the requested fields.
+  bool LookupField(int field_number, int* field_index) {
+    // Linear scan from the back, matching the original reverse iteration.
+    // TODO(nix): this could be sped up with binary search, but we are
+    // already way off the fastpath at this point. If you see a hotspot
+    // here, somebody is sending you very inefficient protos.
+    int fi = fields_.size();
+    while (--fi >= 0) {
+      if (fields_[fi]->number == field_number) {
+        *field_index = fi;
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Traverses a serialized protobuf, dispatching values to the collectors.
+  // collectors is parallel to fields_ (both sorted by field number).
+  // NOTE(review): this reads fields_[0] unconditionally, so it assumes
+  // fields_ is non-empty — confirm the Op rejects empty field_names or
+  // guard before calling.
+  template <class CollectorClass>
+  Status Collect(CodedInputStream* input,
+                 std::vector<CollectorClass>* collectors) {
+    int last_good_field_index = -1;
+    bool fields_disordered = false;
+    int prev_field_number = -1;
+    int field_number = -1;
+    int last_good_field_number = -1;
+    int next_good_field_number = fields_[0]->number;
+
+    // The 'tag' variable should always be treated as tainted.
+    for (uint32 tag = input->ReadTag();
+         tag != 0 && WireFormatLite::GetTagWireType(tag) !=
+                         WireFormatLite::WIRETYPE_END_GROUP;
+         tag = input->ReadTag(), prev_field_number = field_number) {
+      field_number = WireFormatLite::GetTagFieldNumber(tag);
+      const FieldInfo* field = nullptr;
+
+      // This takes advantage of the sorted field numbers in most serialized
+      // protos: it tries the next expected field first rather than doing
+      // a lookup by field number.
+      // TODO(nix): haberman@ suggests a hybrid approach with a lookup table
+      // for small field numbers and a hash table for larger ones. This would
+      // be a simpler approach that should offer comparable speed in most
+      // cases.
+      if (field_number == last_good_field_number) {
+        // Repeated occurrence of the field we matched last time.
+        field = fields_[last_good_field_index].get();
+      } else {
+        if (field_number < prev_field_number) {
+          // Field numbers went backwards: the sorted-order fast path is no
+          // longer valid for the rest of this message.
+          fields_disordered = true;
+        }
+
+        // If fields are out of order, fall back to slow lookup.
+        if (fields_disordered) {
+          int field_index;
+          if (LookupField(field_number, &field_index)) {
+            field = fields_[field_index].get();
+            last_good_field_index = field_index;
+          }
+        } else {
+          // If we see a field that is past the next field we want,
+          // it was empty. Look for the one after that.
+          // Repeat until we run out of fields that we care about.
+          while (field_number >= next_good_field_number) {
+            if (field_number == next_good_field_number) {
+              last_good_field_number = field_number;
+              field = fields_[last_good_field_index + 1].get();
+            }
+
+            // Start looking for the field after the current one.
+            ++last_good_field_index;
+            if (last_good_field_index < fields_.size() - 1) {
+              next_good_field_number =
+                  fields_[last_good_field_index + 1]->number;
+            } else {
+              // Saw something past the last field we care about.
+              // Continue parsing the message just in case there
+              // are disordered fields later, but any remaining
+              // ordered fields will have no effect.
+              next_good_field_number = INT_MAX;
+            }
+          }
+        }
+      }
+
+      if (!field) {
+        // Unknown and unrequested fields are skipped.
+        if (!WireFormatLite::SkipField(input, tag)) {
+          return errors::DataLoss("Failed skipping unrequested field");
+        }
+        continue;
+      }
+
+      // last_good_field_index is the collector slot for `field` on both
+      // the ordered and the disordered path.
+      Status st = CollectField(*field, WireFormatLite::GetTagWireType(tag),
+                               input, &(*collectors)[last_good_field_index]);
+      if (!st.ok()) {
+        return st;
+      }
+    }
+    return Status::OK();
+  }
+
+  // Collects the value(s) carried by a single field occurrence on the wire.
+  template <class CollectorClass>
+  Status CollectField(const FieldInfo& field,
+                      WireFormatLite::WireType wire_type,
+                      CodedInputStream* input, CollectorClass* collector) {
+    // The wire format library defines the same constants used in
+    // descriptor.proto, so this static_cast-free lookup is safe: they are
+    // guaranteed to stay in sync. We consult the schema because the wire
+    // format alone says nothing about what is inside a packed repeated
+    // blob — it is enough to skip the field but not to parse it.
+    const WireFormatLite::WireType schema_wire_type =
+        WireFormatLite::WireTypeForFieldType(field.type);
+
+    // A length-delimited payload for a field whose schema type is NOT
+    // length-delimited is a packed repeated primitive field.
+    const bool is_packed =
+        wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
+        schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED;
+    if (is_packed) {
+      // SkipField would swallow the whole blob without letting us count
+      // the values, so the collector scans them itself.
+      int length;
+      if (!input->ReadVarintSizeAsInt(&length)) {
+        return errors::DataLoss("CollectField: Failed reading packed size");
+      }
+      return collector->ReadPackedValues(input, field, length);
+    }
+
+    if (wire_type != schema_wire_type) {
+      // Wire type disagrees with the schema: skip the malformed field.
+      if (!WireFormatLite::SkipField(
+              input, WireFormatLite::MakeTag(field.number, wire_type))) {
+        return errors::DataLoss(
+            "CollectField: Failed skipping malformed field");
+      }
+      return Status::OK();
+    }
+
+    // Ordinary values, including strings, bytes, and submessages.
+    return collector->ReadValue(input, field);
+  }
+
+  // Fully-qualified message type name, kept for log messages.
+  string message_type_;
+  // Note that fields are sorted by increasing field number,
+  // which is not in general the order given by the user-specified
+  // field_names and output_types Op attributes.
+  std::vector<std::unique_ptr<const FieldInfo>> fields_;
+
+  // owned_desc_pool_ is null when using descriptor_source=local;
+  // otherwise it owns the pool the constructor resolved.
+  std::unique_ptr<DescriptorPool> owned_desc_pool_;
+  DynamicMessageFactory message_factory_;
+  // Prototype for sanitizing/reserializing inputs; owned by
+  // message_factory_ (DynamicMessageFactory::GetPrototype semantics).
+  const Message* message_prototype_;
+
+  // True if decoding binary format, false if decoding text format.
+  bool is_binary_;
+
+  // True if the protos should be sanitized before parsing.
+  // Enables the initial protobuf sanitizer, which is much
+  // more expensive than the decoder. The flag defaults to true
+  // but can be set to false for trusted sources.
+  // TODO(nix): flip the default to false when the fast decoder
+  // has passed security review.
+  bool sanitize_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
+};
+
+// CPU-only kernel registration for the DecodeProtoV2 op.
+REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
+                        DecodeProtoOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc
new file mode 100644
index 0000000000..3b02ae52a2
--- /dev/null
+++ b/tensorflow/core/kernels/encode_proto_op.cc
@@ -0,0 +1,591 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// EncodeProto is a TensorFlow Op which serializes tensors into
+// arbitrary protobufs.
+//
+// See the docstring in ../ops/encode_proto_op.cc for usage of the op.
+//
+// This implementation writes the serialized format using a handful of
+// calls from the WireFormatLite API.
+
+#include <memory>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/proto/descriptors.h"
+
+namespace tensorflow {
+namespace {
+
+using ::tensorflow::protobuf::Descriptor;
+using ::tensorflow::protobuf::DescriptorPool;
+using ::tensorflow::protobuf::FieldDescriptor;
+using ::tensorflow::protobuf::internal::WireFormatLite;
+using ::tensorflow::protobuf::io::CodedOutputStream;
+using ::tensorflow::protobuf::io::StringOutputStream;
+
+// Computes the total serialized size for a packed repeated field.
+// For fixed-size types this can just multiply, but for variable-sized
+// types it has to iterate through the values in the tensor.
+// Each explicit specialization below pairs a proto wire type (FieldType)
+// with the TensorFlow tensor element type (TensorT) holding the values.
+template <WireFormatLite::FieldType FieldType, typename TensorT>
+size_t TotalPackedSize(const Tensor& input, int message_index, int size);
+
+// Fixed-width wire types: payload size is simply count * wire size.
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
+                                                            int message_index,
+                                                            int size) {
+  return size * WireFormatLite::kDoubleSize;
+}
+
+// A double tensor written to a float field still serializes kFloatSize
+// bytes per value (each value is narrowed to float when written).
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
+                                                           int message_index,
+                                                           int size) {
+  return size * WireFormatLite::kFloatSize;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
+                                                          int message_index,
+                                                          int size) {
+  return size * WireFormatLite::kFloatSize;
+}
+
+// Varint wire types: the encoded size depends on each value, so these
+// specializations sum the per-value sizes for the given message row.
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input,
+                                                          int message_index,
+                                                          int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int64>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::Int64Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input,
+                                                           int message_index,
+                                                           int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int64>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::UInt64Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
+                                                          int message_index,
+                                                          int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int32>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::Int32Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input,
+                                                            int message_index,
+                                                            int size) {
+  return size * WireFormatLite::kFixed64Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input,
+                                                            int message_index,
+                                                            int size) {
+  return size * WireFormatLite::kFixed32Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input,
+                                                            int message_index,
+                                                            int size) {
+  return size * WireFormatLite::kFixed32Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
+                                                        int message_index,
+                                                        int size) {
+  return size * WireFormatLite::kBoolSize;
+}
+
+// TensorFlow has no unsigned dtypes, so uint32 fields may arrive as
+// either int64 or int32 tensors; both specializations are provided.
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input,
+                                                           int message_index,
+                                                           int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int64>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::UInt32Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input,
+                                                           int message_index,
+                                                           int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int32>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::UInt32Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
+                                                         int message_index,
+                                                         int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int32>();
+  for (int64 i = 0; i < size; i++) {
+    data_size +=
+        WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
+    const Tensor& input, int message_index, int size) {
+  return size * WireFormatLite::kSFixed32Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>(
+    const Tensor& input, int message_index, int size) {
+  return size * WireFormatLite::kSFixed64Size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
+                                                           int message_index,
+                                                           int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int32>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::SInt32Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+template <>
+size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input,
+                                                           int message_index,
+                                                           int size) {
+  size_t data_size = 0;
+  auto input_t = input.flat_inner_dims<int64>();
+  for (int64 i = 0; i < size; i++) {
+    data_size += WireFormatLite::SInt64Size(
+        input_t(static_cast<int64>(message_index), i));
+  }
+  return data_size;
+}
+
+// Writes a possibly repeated primitive field.
+// TensorFlow does not have unsigned types, so we decode them to signed and
+// encode them back to unsigned.
+// Writer is a WireFormatLite::Write*NoTag function matching ProtoT.
+template <typename TensorT, typename ProtoT,
+          WireFormatLite::FieldType FieldType,
+          void Writer(ProtoT, CodedOutputStream*)>
+void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
+                int message_index, int size, CodedOutputStream* output) {
+  auto wire_type = WireFormatLite::WireTypeForFieldType(
+      WireFormatLite::FieldType(field_desc.type()));
+
+  auto input_t = input.flat_inner_dims<TensorT>();
+  if (field_desc.options().packed()) {
+    // Write the tag for the packed field.
+    WireFormatLite::WriteTag(field_desc.number(),
+                             WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
+
+    // Write the total packed length.
+    size_t data_size =
+        TotalPackedSize<FieldType, TensorT>(input, message_index, size);
+    // NOTE(review): WriteVarint32 silently truncates data_size if a single
+    // packed field exceeds 4GiB — confirm upstream limits rule that out.
+    output->WriteVarint32(data_size);
+
+    // Write individual values.
+    for (int64 i = 0; i < size; i++) {
+      // Note implicit cast from signed to unsigned.
+      const ProtoT& value = input_t(static_cast<int64>(message_index), i);
+      Writer(value, output);
+    }
+  } else {
+    // Unpacked: every value carries its own tag.
+    for (int64 i = 0; i < size; i++) {
+      WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
+
+      // Note implicit cast from signed to unsigned.
+      const ProtoT& value = input_t(static_cast<int64>(message_index), i);
+      Writer(value, output);
+    }
+  }
+}
+
+// Writes a possibly repeated string, bytes, or message field.
+// Writer is a WireFormatLite tag-writing function such as WriteString.
+template <typename T, void Writer(int, const T&, CodedOutputStream*)>
+void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
+                      int message_index, int size, CodedOutputStream* output) {
+  const int field_number = field_desc.number();
+  auto values = input.flat_inner_dims<T>();
+  const int64 row = static_cast<int64>(message_index);
+  for (int64 i = 0; i < size; i++) {
+    // TODO(nix): there doesn't seem to be an inlined version of
+    // WireFormatLite::WriteString or its relatives, which might allow a
+    // small speedup.
+    Writer(field_number, values(row, i), output);
+  }
+}
+
+// Writes a group field. Groups are treated like submessages but are
+// tag-delimited instead of length-delimited; WireFormatLite handles that
+// case differently, so the framing is emitted by hand here.
+void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
+                int message_index, int size, CodedOutputStream* output) {
+  const int field_number = field_desc.number();
+  auto values = input.flat_inner_dims<string>();
+  const int64 row = static_cast<int64>(message_index);
+  for (int64 i = 0; i < size; i++) {
+    const string& serialized = values(row, i);
+    WireFormatLite::WriteTag(field_number,
+                             WireFormatLite::WIRETYPE_START_GROUP, output);
+    // WriteRaw (not WriteString) so that no length prefix is emitted.
+    output->WriteRaw(serialized.data(), serialized.size());
+    WireFormatLite::WriteTag(field_number,
+                             WireFormatLite::WIRETYPE_END_GROUP, output);
+  }
+}
+
+// Writes a (possibly repeated) field into an output stream.
+// It is the caller's responsibility to ensure that the type of
+// the input tensor is compatible with the type of the proto
+// field descriptor, and that (message_index, size-1) is within
+// bounds.
+void WriteField(const FieldDescriptor& field_desc, const Tensor& input,
+                int message_index, int size, CodedOutputStream* output) {
+  DataType tf_type = input.dtype();
+
+  switch (field_desc.type()) {
+    case WireFormatLite::TYPE_DOUBLE:
+      return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
+                        WireFormatLite::WriteDoubleNoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_FLOAT:
+      // float fields accept either float or double tensors (narrowing).
+      switch (tf_type) {
+        case DataType::DT_FLOAT:
+          return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
+                            WireFormatLite::WriteFloatNoTag>(
+              field_desc, input, message_index, size, output);
+        case DataType::DT_DOUBLE:
+          return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
+                            WireFormatLite::WriteFloatNoTag>(
+              field_desc, input, message_index, size, output);
+        default:
+          // NOTE(review): incompatible dtypes are silently dropped here;
+          // presumably the caller validated with IsCompatibleType — confirm.
+          return;
+      }
+    case WireFormatLite::TYPE_INT64:
+      return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64,
+                        WireFormatLite::WriteInt64NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_UINT64:
+      return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
+                        WireFormatLite::WriteUInt64NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_INT32:
+      return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
+                        WireFormatLite::WriteInt32NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_FIXED64:
+      return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
+                        WireFormatLite::WriteFixed64NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_FIXED32:
+      // Unsigned 32-bit fields may come from int64 or int32 tensors.
+      switch (tf_type) {
+        case DataType::DT_INT64:
+          return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32,
+                            WireFormatLite::WriteFixed32NoTag>(
+              field_desc, input, message_index, size, output);
+        case DataType::DT_INT32:
+          return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32,
+                            WireFormatLite::WriteFixed32NoTag>(
+              field_desc, input, message_index, size, output);
+        default:
+          return;
+      }
+    case WireFormatLite::TYPE_BOOL:
+      return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
+                        WireFormatLite::WriteBoolNoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_STRING:
+      return WriteVarLenField<string, WireFormatLite::WriteString>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_GROUP:
+      return WriteGroup(field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_MESSAGE:
+      // Submessages arrive pre-serialized in string tensors, so they are
+      // written exactly like bytes.
+      return WriteVarLenField<string, WireFormatLite::WriteBytes>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_BYTES:
+      return WriteVarLenField<string, WireFormatLite::WriteBytes>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_UINT32:
+      switch (tf_type) {
+        case DataType::DT_INT64:
+          return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32,
+                            WireFormatLite::WriteUInt32NoTag>(
+              field_desc, input, message_index, size, output);
+        case DataType::DT_INT32:
+          return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32,
+                            WireFormatLite::WriteUInt32NoTag>(
+              field_desc, input, message_index, size, output);
+        default:
+          return;
+      }
+    case WireFormatLite::TYPE_ENUM:
+      return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
+                        WireFormatLite::WriteEnumNoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_SFIXED32:
+      return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
+                        WireFormatLite::WriteSFixed32NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_SFIXED64:
+      return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
+                        WireFormatLite::WriteSFixed64NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_SINT32:
+      return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
+                        WireFormatLite::WriteSInt32NoTag>(
+          field_desc, input, message_index, size, output);
+    case WireFormatLite::TYPE_SINT64:
+      return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64,
+                        WireFormatLite::WriteSInt64NoTag>(
+          field_desc, input, message_index, size, output);
+      // default: intentionally omitted in order to enable static checking.
+  }
+}
+
+// Checks that a Protobuf field is compatible with a TensorFlow datatype.
+// This is separated from WriteField to lift it out of the inner loop.
+bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) {
+ switch (field_desc.type()) {
+ case WireFormatLite::TYPE_DOUBLE:
+ return tf_type == DataType::DT_DOUBLE;
+ case WireFormatLite::TYPE_FLOAT:
+ return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE;
+ case WireFormatLite::TYPE_INT64:
+ case WireFormatLite::TYPE_SFIXED64:
+ case WireFormatLite::TYPE_SINT64:
+ return tf_type == DataType::DT_INT64;
+ case WireFormatLite::TYPE_UINT64:
+ return tf_type == DataType::DT_INT64;
+ case WireFormatLite::TYPE_INT32:
+ case WireFormatLite::TYPE_ENUM:
+ case WireFormatLite::TYPE_SFIXED32:
+ case WireFormatLite::TYPE_SINT32:
+ return tf_type == DataType::DT_INT32;
+ case WireFormatLite::TYPE_FIXED64:
+ return tf_type == DataType::DT_INT64;
+ case WireFormatLite::TYPE_FIXED32:
+ case WireFormatLite::TYPE_UINT32:
+ return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32;
+ case WireFormatLite::TYPE_BOOL:
+ return tf_type == DataType::DT_BOOL;
+ case WireFormatLite::TYPE_STRING:
+ case WireFormatLite::TYPE_GROUP:
+ case WireFormatLite::TYPE_MESSAGE:
+ case WireFormatLite::TYPE_BYTES:
+ return tf_type == DataType::DT_STRING;
+ // default: intentionally omitted in order to enable static checking.
+ }
+ return false;
+}
+
+class EncodeProtoOp : public OpKernel {
+ public:
+ explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
+ string descriptor_source;
+ OP_REQUIRES_OK(context,
+ context->GetAttr("descriptor_source", &descriptor_source));
+ // We always get back a desc_pool, but we may not own it. If we own it,
+ // owned_desc_pool_ will be filled in.
+ DescriptorPool const* desc_pool;
+ OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
+ &desc_pool, &owned_desc_pool_));
+
+ string message_type;
+ OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
+ const Descriptor* message_desc =
+ desc_pool->FindMessageTypeByName(message_type);
+ OP_REQUIRES(context, message_desc != nullptr,
+ errors::InvalidArgument("No descriptor found for message type ",
+ message_type));
+
+ OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
+
+ // Gather the field descriptors for the given field_names.
+ field_descs_.resize(field_names_.size());
+ for (int i = 0; i < field_names_.size(); i++) {
+ const string& name = field_names_[i];
+ auto field_desc = message_desc->FindFieldByName(name);
+ OP_REQUIRES(context, field_desc != nullptr,
+ errors::InvalidArgument("Unknown field: ", name,
+ " in message type ", message_type));
+
+ field_descs_[i] = field_desc;
+ }
+
+ // Build a list of indices into field_descs sorted by increasing
+ // field_number. This will be used to output fields in sorted order,
+ // which is strongly encouraged when serializing protobufs.
+ sorted_field_index_.resize(field_names_.size());
+ // Start with the fields sorted by current index.
+ for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
+ // Then sort the field indices by their proto field number.
+ std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
+ [this](int a, int b) -> bool {
+ return field_descs_[a]->number() < field_descs_[b]->number();
+ });
+ }
+
+ void Compute(OpKernelContext* cx) override {
+ const Tensor* sizes_tensor;
+ OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor));
+
+ OpInputList values;
+ OP_REQUIRES_OK(cx, cx->input_list("values", &values));
+
+ OP_REQUIRES(cx, field_descs_.size() == values.size(),
+ errors::InvalidArgument(
+ "Length of inputs list must match field_names"));
+
+ // Check the arguments for consistency.
+ TensorShape common_prefix;
+ int message_count;
+ for (int i = 0; i < field_descs_.size(); i++) {
+ const Tensor& v = values[i];
+
+ // The type of each value tensor must match the corresponding field.
+ OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()),
+ errors::InvalidArgument(
+ "Incompatible type for field " + field_names_[i] +
+ ". Saw dtype: ",
+ DataTypeString(v.dtype()),
+ " but field type is: ", field_descs_[i]->type_name()));
+
+ // All value tensors must have the same shape prefix (i.e. batch size).
+ TensorShape shape_prefix = v.shape();
+ shape_prefix.RemoveDim(shape_prefix.dims() - 1);
+
+ // Do some initialization on the first input value. The rest will
+ // have to match this one.
+ if (i == 0) {
+ OP_REQUIRES(cx, v.dims() >= 1,
+ errors::InvalidArgument(
+ "Expected value to be at least a vector, saw shape: ",
+ v.shape().DebugString()));
+ common_prefix = shape_prefix;
+ message_count = common_prefix.num_elements();
+ } else {
+ OP_REQUIRES(cx, shape_prefix == common_prefix,
+ errors::InvalidArgument(
+ "Values must match up to the last dimension"));
+ }
+ }
+
+ TensorShape expected_sizes_shape = common_prefix;
+ expected_sizes_shape.AddDim(field_descs_.size());
+
+ OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape,
+ errors::InvalidArgument(
+ "sizes should be batch_size + [len(field_names)]. Saw: ",
+ sizes_tensor->shape().DebugString(),
+ " but expected: ", expected_sizes_shape.DebugString()));
+
+ auto sizes = sizes_tensor->flat_inner_dims<int32>();
+
+ for (int i = 0; i < field_descs_.size(); ++i) {
+ const Tensor& v = values[i];
+ int max_size = v.dim_size(v.dims() - 1);
+
+      // The last dimension of each value tensor must be greater than or
+      // equal to the corresponding size in the sizes tensor, for every
+      // message in the batch.
+ for (int message_index = 0; message_index < message_count;
+ message_index++) {
+ OP_REQUIRES(
+ cx, sizes(message_index, i) <= max_size,
+ errors::InvalidArgument(
+ "Size to write must not be larger than value tensor; but saw: ",
+ sizes(message_index, i), " > ", max_size, " at message ",
+ message_index, " field ", i));
+ }
+ }
+
+ // This pointer is owned by the context.
+ Tensor* output_tensor;
+ OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor));
+
+ auto bufs = output_tensor->flat<string>();
+ for (int message_index = 0; message_index < message_count;
+ message_index++) {
+ // TODO(nix): possibly optimize allocation here by calling
+ // bufs(message_index).reserve(DEFAULT_BUF_SIZE);
+ StringOutputStream output_string(&bufs(message_index));
+ CodedOutputStream out(&output_string);
+ // Write fields in ascending field_number order.
+ for (int i : sorted_field_index_) {
+ auto& field_desc = *field_descs_[i];
+ const Tensor& v = values[i];
+ int size = sizes(message_index, i);
+ if (!size) continue;
+ WriteField(field_desc, v, message_index, size, &out);
+ }
+ }
+ }
+
+ private:
+ std::vector<string> field_names_;
+ std::vector<const FieldDescriptor*> field_descs_;
+
+  // owned_desc_pool_ is null when using descriptor_source=local.
+ std::unique_ptr<DescriptorPool> owned_desc_pool_;
+
+ // Contains indices into field_names_, sorted by field number since
+ // that's the order of writing.
+ std::vector<int> sorted_field_index_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/rpc_op.cc b/tensorflow/core/kernels/rpc_op.cc
new file mode 100644
index 0000000000..2447ef5040
--- /dev/null
+++ b/tensorflow/core/kernels/rpc_op.cc
@@ -0,0 +1,129 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// RpcOp is a TensorFlow op that sends and receives arbitrary messages.
+//
+// See docs in ../ops/rpc_op.cc.
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/rpc/call_container.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+
+class RpcOp : public AsyncOpKernel {
+ public:
+ explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_));
+ OP_REQUIRES(context, !protocol_.empty(),
+ errors::InvalidArgument("protocol must be non-empty."));
+ bool fail_fast;
+ OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast));
+ int64 timeout_in_ms;
+ OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms));
+
+ RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn =
+ RPCFactoryRegistry::Global()->Get(protocol_);
+ OP_REQUIRES(context, rpc_factory_fn != nullptr,
+ errors::InvalidArgument("The protocol ", protocol_,
+ " was not recognized."));
+
+ rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms));
+ }
+
+ ~RpcOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor& address_t = ctx->input(0);
+ const Tensor& method_t = ctx->input(1);
+ const Tensor& request_t = ctx->input(2);
+
+ OP_REQUIRES_ASYNC(
+ ctx, address_t.dims() == 0 || address_t.dims() == 1,
+ errors::InvalidArgument("address must be a scalar or vector."), done);
+ OP_REQUIRES_ASYNC(
+ ctx, method_t.dims() == 0 || method_t.dims() == 1,
+ errors::InvalidArgument("method must be a scalar or vector."), done);
+ OP_REQUIRES_ASYNC(
+ ctx, request_t.dims() == 0 || request_t.dims() == 1,
+ errors::InvalidArgument("request must be a scalar or vector."), done);
+
+ TensorShape output_shape({});
+ for (const Tensor& t : {address_t, method_t, request_t}) {
+ if (t.dims() == 1) {
+ OP_REQUIRES_ASYNC(
+ ctx,
+ output_shape.dims() == 0 ||
+ output_shape.dim_size(0) == t.dim_size(0),
+ errors::InvalidArgument(
+ "Input vector shapes don't match: ", output_shape.DebugString(),
+ " vs. ", t.shape().DebugString()),
+ done);
+ output_shape = t.shape();
+ }
+ }
+
+ Tensor* response_t;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(0, output_shape, &response_t), done);
+
+ const bool try_rpc = (ctx->num_outputs() > 1);
+
+ Tensor* status_code_t = nullptr;
+ Tensor* status_message_t = nullptr;
+ if (try_rpc) {
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(1, output_shape, &status_code_t), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->allocate_output(2, output_shape, &status_message_t), done);
+ }
+
+ if (request_t.NumElements() == 0) {
+ // Special case, we finished early!
+ done();
+ return;
+ }
+
+ int64 num_elements = output_shape.num_elements();
+
+ rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t,
+ try_rpc, response_t, status_code_t, status_message_t,
+ std::move(done));
+ }
+
+ private:
+ string protocol_;
+ std::unique_ptr<RPCFactory> rpc_factory_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(RpcOp);
+};
+
+REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp);
+REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/decode_proto_ops.cc b/tensorflow/core/ops/decode_proto_ops.cc
new file mode 100644
index 0000000000..3f6fb2f582
--- /dev/null
+++ b/tensorflow/core/ops/decode_proto_ops.cc
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
+
+REGISTER_OP("DecodeProtoV2")
+ .Input("bytes: string")
+ .Attr("message_type: string")
+ .Attr("field_names: list(string)")
+ .Attr("output_types: list(type) >= 0")
+ .Attr("descriptor_source: string = 'local://'")
+ .Attr("message_format: string = 'binary'")
+ .Attr("sanitize: bool = false")
+ .Output("sizes: int32")
+ .Output("values: output_types")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+
+ std::vector<tensorflow::DataType> output_types;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types));
+
+ ShapeHandle sizes;
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(input, c->Vector(output_types.size()), &sizes));
+ c->set_output(0, sizes);
+
+ // TODO(nix): to do the best possible job of shape inference, we
+ // should examine the proto descriptors here in order to set shape
+ // indices to 1 instead of unknown for optional or required fields.
+ // Any general-purpose code will have to handle the unknown case,
+ // but there might be XLA code that could be sped up with the additional
+ // knowledge.
+ for (int i = 0; i < output_types.size(); ++i) {
+ ShapeHandle values;
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(input, c->Vector(c->UnknownDim()), &values));
+ c->set_output(i + 1, values);
+ }
+
+ return Status::OK();
+ });
+
+// TODO(nix): Consider adding an additional input argument that truncates
+// repeated fields to a maximum count. For now this could be done by passing
+// the output through tf.slice.
+
+// TODO(nix): define missing value behavior.
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/encode_proto_ops.cc b/tensorflow/core/ops/encode_proto_ops.cc
new file mode 100644
index 0000000000..f5ec3056e3
--- /dev/null
+++ b/tensorflow/core/ops/encode_proto_ops.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
+
+REGISTER_OP("EncodeProto")
+ .Input("sizes: int32")
+ .Input("values: Tinput_types")
+ .Attr("field_names: list(string)")
+ .Attr("message_type: string")
+ .Attr("descriptor_source: string = 'local://'")
+ .Attr("Tinput_types: list(type)")
+ .Output("bytes: string")
+ .SetShapeFn([](InferenceContext* c) {
+ int first_field_index = 1;
+ int num_fields = c->num_inputs() - 1;
+
+ ShapeHandle output;
+ for (int i = num_fields - 1; i >= 0; --i) {
+ ShapeHandle input = c->input(first_field_index + i);
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input));
+ ShapeHandle inner;
+ TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &inner));
+ TF_RETURN_IF_ERROR(c->Merge(inner, output, &output));
+ }
+
+ c->set_output(0, output);
+ return Status::OK();
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/rpc_ops.cc b/tensorflow/core/ops/rpc_ops.cc
new file mode 100644
index 0000000000..72fda5e6eb
--- /dev/null
+++ b/tensorflow/core/ops/rpc_ops.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using tensorflow::shape_inference::DimensionHandle;
+using tensorflow::shape_inference::InferenceContext;
+using tensorflow::shape_inference::ShapeHandle;
+
+Status RpcShapeOp(InferenceContext* c, bool try_rpc) {
+ ShapeHandle address;
+ ShapeHandle method;
+ ShapeHandle request;
+ ShapeHandle output;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address));
+ if (c->Rank(address) == 1) {
+ TF_RETURN_IF_ERROR(c->Merge(output, address, &output));
+ }
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method));
+ if (c->Rank(method) == 1) {
+ TF_RETURN_IF_ERROR(c->Merge(output, method, &output));
+ }
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request));
+ if (c->Rank(request) == 1) {
+ TF_RETURN_IF_ERROR(c->Merge(output, request, &output));
+ }
+ if (!c->RankKnown(output)) {
+ output = request;
+ }
+ c->set_output(0, output); // response
+ if (try_rpc) {
+ c->set_output(1, output); // status_code
+ c->set_output(2, output); // status_message
+ }
+ return Status::OK();
+}
+
+REGISTER_OP("Rpc")
+ .Input("address: string")
+ .Input("method: string")
+ .Input("request: string")
+ .Attr("protocol: string = ''")
+ .Attr("fail_fast: bool = true")
+ .Attr("timeout_in_ms: int = 0")
+ .Output("response: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ return RpcShapeOp(c, /*try_rpc=*/false);
+ });
+
+REGISTER_OP("TryRpc")
+ .Input("address: string")
+ .Input("method: string")
+ .Input("request: string")
+ .Attr("protocol: string = ''")
+ .Attr("fail_fast: bool = true")
+ .Attr("timeout_in_ms: int = 0")
+ .Output("response: string")
+ .Output("status_code: int32")
+ .Output("status_message: string")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ return RpcShapeOp(c, /*try_rpc=*/true);
+ });
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD
new file mode 100644
index 0000000000..ade14ed162
--- /dev/null
+++ b/tensorflow/core/util/proto/BUILD
@@ -0,0 +1,62 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "decode",
+ hdrs = ["decode.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "descriptors",
+ srcs = ["descriptors.cc"],
+ hdrs = ["descriptors.h"],
+ deps = [
+ ":descriptor_pool_registry",
+ ":local_descriptor_pool_registration",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "descriptor_pool_registry",
+ srcs = ["descriptor_pool_registry.cc"],
+ hdrs = ["descriptor_pool_registry.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "descriptor_pool_registry_test",
+ srcs = ["descriptor_pool_registry_test.cc"],
+ deps = [
+ ":descriptor_pool_registry",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+# Depending on this target adds support for using the special
+# value "local://" (or "") for descriptor source, in which case
+# descriptors linked into the code will be searched.
+cc_library(
+ name = "local_descriptor_pool_registration",
+ srcs = ["local_descriptor_pool_registration.cc"],
+ deps = [
+ ":descriptor_pool_registry",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h
new file mode 100644
index 0000000000..74634a356a
--- /dev/null
+++ b/tensorflow/core/util/proto/decode.h
@@ -0,0 +1,592 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Inline functions for parsing the protocol buffers wire format.
+//
+// These functions have been optimized at the expense of safety.
+// They are broken out into a separate file for readability but are
+// not intended for use by clients other than the decode_proto op.
+//
+// The calling code in the decode_proto op does some fairly
+// complicated things to ensure that this code is called
+// safely. Changes to this code should be thoroughly fuzz tested.
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace internal {
+
+using tensorflow::protobuf::internal::WireFormatLite;
+using tensorflow::protobuf::io::CodedInputStream;
+using tensorflow::protobuf::io::CodedOutputStream;
+using tensorflow::protobuf::io::StringOutputStream;
+
+// Converts a uint64 to an int64 without loss of information.
+// Unsigned values greater than INT64_MAX are represented as
+// negative numbers by wrapping (same as twos-complement bit equivalence).
+inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) {
+ // For a detailed explanation of why this works to wrap unsigned ints, see
+ // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
+ // Both if tests should be optimized out.
+ if (unsigned_value <= INT64_MAX) {
+ return static_cast<int64>(unsigned_value);
+ }
+ // The C++ spec allows an architecture where this test is required.
+ if (unsigned_value >= INT64_MIN) {
+ return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN;
+ }
+ return 0; // This should never occur.
+}
+
+// Converts a uint32 to an int32 without loss of information.
+// Unsigned values greater than INT_MAX are represented as
+// negative numbers by wrapping (same as twos-complement bit equivalence).
+inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) {
+ // For a detailed explanation of why this works to wrap unsigned ints, see
+ // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior
+ // Both if tests should be optimized out.
+ if (unsigned_value <= INT_MAX) {
+ return static_cast<int32>(unsigned_value);
+ }
+ // The C++ spec allows an architecture where this test is required.
+ if (unsigned_value >= INT_MIN) {
+ return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN;
+ }
+ return 0; // This should never occur.
+}
+
+// Reads a single varint64 from a byte array.
+// It is the caller's responsibility to ensure that there is enough
+// space in the buffer.
+// The ok value will be set to false if the buffer does not contain
+// a valid varint.
+inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
+ uint64* value);
+
+// Reads a single varint32 from a byte array.
+// It is the caller's responsibility to ensure that there is enough
+// space in the buffer.
+// The ok value will be set to false if the buffer does not contain
+// a valid varint.
+// This is slightly less efficient than the private version in
+// coded_stream.cc but we duplicate less code by calling
+// the 64 bit version instead of copying the code.
+inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok,
+ uint32* value) {
+ uint64 tmp;
+ const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp);
+ *value = tmp & 0xffffffff;
+ return buf;
+}
+
+// Reads a single proto field value from a byte array into an array.
+// The array is part of a Tensor that was allocated by the caller
+// with type TensorType, while DeclaredType is the proto field type.
+template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
+const uint8* ReadFromArray(const uint8* buf, TensorType* value);
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int32>(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
+ const uint8* buf, int64* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = WrapUnsignedAsSigned64(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, int64* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = WrapUnsignedAsSigned32(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>(
+ const uint8* buf, int64* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int64>(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = WireFormatLite::ZigZagDecode32(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
+ const uint8* buf, int64* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = WireFormatLite::ZigZagDecode64(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, int64* value) {
+ uint32 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
+ WireFormatLite::TYPE_FIXED32>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, int32* value) {
+ uint32 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
+ WireFormatLite::TYPE_FIXED32>(
+ buf, &temp);
+ *value = WrapUnsignedAsSigned32(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
+ const uint8* buf, int64* value) {
+ protobuf_uint64 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
+ WireFormatLite::TYPE_FIXED64>(
+ buf, &temp);
+ *value = WrapUnsignedAsSigned64(temp);
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
+ const uint8* buf, int32* value) {
+ return WireFormatLite::ReadPrimitiveFromArray<int32,
+ WireFormatLite::TYPE_SFIXED32>(
+ buf, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>(
+ const uint8* buf, int64* value) {
+ protobuf_int64 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64,
+ WireFormatLite::TYPE_SFIXED64>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
+ const uint8* buf, float* value) {
+ return WireFormatLite::ReadPrimitiveFromArray<float,
+ WireFormatLite::TYPE_FLOAT>(
+ buf, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
+ const uint8* buf, double* value) {
+ return WireFormatLite::ReadPrimitiveFromArray<double,
+ WireFormatLite::TYPE_DOUBLE>(
+ buf, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>(
+ const uint8* buf, bool* value) {
+ uint64 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
+ *value = temp != 0;
+ return buf;
+}
+
+template <>
+inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>(
+ const uint8* buf, int* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int>(temp);
+ return buf;
+}
+
+// Reads packed values from an array.
+// Stride is set to 1 for repeated fields, and 0 for non-repeated fields
+// (where any value overwrites previous values).
+//
+// Returns the number of elements consumed times `stride` (i.e. the amount
+// by which the caller should advance its output index). With stride == 0
+// every decoded value lands in the same slot, so the last one wins.
+template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
+inline int ReadPackedPrimitives(const void* bufp, const size_t len,
+                                const int index, const int stride,
+                                void* datap) {
+  const uint8* buf = reinterpret_cast<const uint8*>(bufp);
+  const uint8* bound = buf + len;
+  TensorType* data = reinterpret_cast<TensorType*>(datap) + index;
+  int count;
+
+  // This could overrun the bound by stride-1. This is defended
+  // against in the caller, where it ensures that the input buffer
+  // contains complete values.
+  for (count = 0; buf < bound; count += stride) {
+    buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count);
+  }
+  return count;
+}
+
+// Reads a primitive value field from a serialized proto.
+// The value is parsed from the serialized format, then static_cast
+// to the desired type for TensorFlow and stored.
+//
+// ValueType is the on-the-wire protobuf type; TensorType is the element
+// type of the output tensor. `data` points at the tensor buffer and
+// `index` is the flat element offset to write.
+template <class ValueType, class TensorType,
+          enum WireFormatLite::FieldType DeclaredType>
+inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) {
+  ValueType v;
+  if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) {
+    return errors::DataLoss("Failed reading primitive");
+  }
+
+  // Implicit conversion ValueType -> TensorType (e.g. uint32 -> int64).
+  reinterpret_cast<TensorType*>(data)[index] = v;
+  return Status::OK();
+}
+
+// Reads a string, submessage, or other variable-length field from a
+// serialized proto.
+// May read all or part of a repeated field.
+// `datap` is the base of a string tensor buffer; the bytes are stored at
+// element `index`.
+inline Status ReadBytes(CodedInputStream* input, int index, void* datap) {
+  string* data = reinterpret_cast<string*>(datap) + index;
+  if (!WireFormatLite::ReadBytes(input, data)) {
+    return errors::DataLoss("Failed reading bytes");
+  }
+  return Status::OK();
+}
+
+// Reads a tag-delimited field (TYPE_GROUP) from a serialized proto,
+// as a bytestring.
+// The group's raw bytes (everything up to the matching END_GROUP tag) are
+// captured verbatim into the string tensor element at `index`.
+inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
+                             int index, void* datap) {
+  // WireFormatLite::SkipField has an option to emit the
+  // skipped bytes to an output stream. We could do better by implementing our
+  // own scanner but this is simpler for now.
+  // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying
+  // on input->IsFlat() == true and using input->GetDirectBufferPointer()
+  // with input->CurrentPosition().
+  string* data = reinterpret_cast<string*>(datap) + index;
+  StringOutputStream string_stream(data);
+  CodedOutputStream out(&string_stream);
+  // The START_GROUP tag has already been consumed by the caller, so it is
+  // reconstructed here to tell SkipField which group to scan for.
+  if (!WireFormatLite::SkipField(
+          input,
+          WireFormatLite::MakeTag(field_number,
+                                  WireFormatLite::WIRETYPE_START_GROUP),
+          &out)) {
+    return errors::DataLoss("Failed reading group");
+  }
+  return Status::OK();
+}
+
+// Reads a single field value from a CodedInputStream into a tensor.
+//
+// `field_type` is the declared proto schema type; `dtype` is the TensorFlow
+// output type chosen by the op. For wire types that can widen (FLOAT,
+// FIXED32, UINT32) the dtype selects which conversion to apply; any
+// unsupported (field_type, dtype) pairing falls through to a DataLoss error
+// because type checking is assumed to have happened earlier.
+// `field_number` is only used for TYPE_GROUP (to reconstruct the group tag).
+inline Status ReadValue(CodedInputStream* input,
+                        WireFormatLite::FieldType field_type, int field_number,
+                        DataType dtype, int index, void* datap) {
+  // Dispatch to the appropriately typed field reader based on the
+  // schema type.
+  switch (field_type) {
+    case WireFormatLite::TYPE_DOUBLE:
+      return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
+          input, index, datap);
+    case WireFormatLite::TYPE_FLOAT:
+      if (dtype == DataType::DT_FLOAT) {
+        return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
+            input, index, datap);
+      }
+      if (dtype == DataType::DT_DOUBLE) {
+        return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
+            input, index, datap);
+      }
+      // Any case that reaches this point should have triggered an error
+      // already.
+      return errors::DataLoss("Failed reading TYPE_FLOAT");
+    case WireFormatLite::TYPE_INT64:
+      return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
+          input, index, datap);
+    case WireFormatLite::TYPE_UINT64:
+      return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>(
+          input, index, datap);
+    case WireFormatLite::TYPE_INT32:
+      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
+          input, index, datap);
+    case WireFormatLite::TYPE_FIXED64:
+      // NOTE: stored as signed int64; values above INT64_MAX will appear
+      // negative after the implicit conversion.
+      return ReadPrimitive<protobuf_uint64, int64,
+                           WireFormatLite::TYPE_FIXED64>(input, index, datap);
+    case WireFormatLite::TYPE_FIXED32:
+      if (dtype == DataType::DT_INT64) {
+        return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>(
+            input, index, datap);
+      }
+      if (dtype == DataType::DT_INT32) {
+        return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>(
+            input, index, datap);
+      }
+      // Any case that reaches this point should have triggered an error
+      // already.
+      return errors::DataLoss("Failed reading TYPE_FIXED32");
+    case WireFormatLite::TYPE_BOOL:
+      return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
+                                                                  datap);
+    case WireFormatLite::TYPE_STRING:
+      return ReadBytes(input, index, datap);
+    case WireFormatLite::TYPE_GROUP:
+      return ReadGroupBytes(input, field_number, index, datap);
+    case WireFormatLite::TYPE_MESSAGE:
+      // Submessages are surfaced as their serialized bytes.
+      return ReadBytes(input, index, datap);
+    case WireFormatLite::TYPE_BYTES:
+      return ReadBytes(input, index, datap);
+    case WireFormatLite::TYPE_UINT32:
+      if (dtype == DataType::DT_INT64) {
+        return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>(
+            input, index, datap);
+      }
+      if (dtype == DataType::DT_INT32) {
+        return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>(
+            input, index, datap);
+      }
+      // Any case that reaches this point should have triggered an error
+      // already.
+      return errors::DataLoss("Failed reading TYPE_UINT32");
+    case WireFormatLite::TYPE_ENUM:
+      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
+          input, index, datap);
+    case WireFormatLite::TYPE_SFIXED32:
+      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
+          input, index, datap);
+    case WireFormatLite::TYPE_SFIXED64:
+      return ReadPrimitive<protobuf_int64, int64,
+                           WireFormatLite::TYPE_SFIXED64>(input, index, datap);
+    case WireFormatLite::TYPE_SINT32:
+      return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
+          input, index, datap);
+    case WireFormatLite::TYPE_SINT64:
+      return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
+          input, index, datap);
+      // default: intentionally omitted in order to enable static checking.
+  }
+  // Unreachable.
+  return errors::DataLoss("Failed reading unknown wire type");
+}
+
+// Reads and stores a length-delimited list of values.
+//
+// `buf`/`buf_size` delimit the packed payload. `*index` is the flat output
+// position; it is advanced by the number of slots consumed (elements *
+// stride). Packed encoding is only legal for primitive types, so the
+// length-delimited types (STRING/GROUP/MESSAGE/BYTES) are rejected.
+inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
+                                  const WireFormatLite::FieldType field_type,
+                                  const int field_number, const DataType dtype,
+                                  const int stride, int* index, void* data) {
+  // Dispatch to the appropriately typed field reader based on the
+  // schema type.
+  switch (field_type) {
+    case WireFormatLite::TYPE_DOUBLE:
+      *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_FLOAT:
+      *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_INT64:
+      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_UINT64:
+      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_INT32:
+      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_FIXED64:
+      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_FIXED32:
+      // As in ReadValue, the unsigned 32-bit wire types may widen to int64
+      // or truncate into int32 depending on the requested output dtype.
+      if (dtype == DataType::DT_INT64) {
+        *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>(
+            buf, buf_size, *index, stride, data);
+        return Status::OK();
+      }
+      if (dtype == DataType::DT_INT32) {
+        *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>(
+            buf, buf_size, *index, stride, data);
+        return Status::OK();
+      }
+      // Any case that reaches this point should have triggered an error
+      // already.
+      return errors::DataLoss("Failed reading TYPE_FIXED32");
+    case WireFormatLite::TYPE_BOOL:
+      *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_STRING:
+    case WireFormatLite::TYPE_GROUP:
+    case WireFormatLite::TYPE_MESSAGE:
+    case WireFormatLite::TYPE_BYTES:
+      return errors::DataLoss("Non-primitive type encountered as packed");
+    case WireFormatLite::TYPE_UINT32:
+      if (dtype == DataType::DT_INT64) {
+        *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>(
+            buf, buf_size, *index, stride, data);
+        return Status::OK();
+      }
+      if (dtype == DataType::DT_INT32) {
+        *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>(
+            buf, buf_size, *index, stride, data);
+        return Status::OK();
+      }
+      // Any case that reaches this point should have triggered an error
+      // already.
+      return errors::DataLoss("Failed reading TYPE_UINT32");
+    case WireFormatLite::TYPE_ENUM:
+      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+    case WireFormatLite::TYPE_SFIXED32:
+      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+
+    case WireFormatLite::TYPE_SFIXED64:
+      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+
+    case WireFormatLite::TYPE_SINT32:
+      *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+
+    case WireFormatLite::TYPE_SINT64:
+      *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
+          buf, buf_size, *index, stride, data);
+      return Status::OK();
+      // default: intentionally omitted in order to enable static checking.
+  }
+  // Unreachable.
+  return errors::DataLoss("Failed reading unknown wire type");
+}
+
+// Reads a varint from the given buffer, write it to *value, and return the
+// new buffer pointer.
+// This was copied from coded_stream.cc where it is private.
+// Important: This routine may read as much as kMaxVarintBytes from
+// the buffer. It is the caller's responsibility to make sure that there is
+// enough space in the buffer.
+//
+// NOTE(review): this function is called earlier in this header (e.g. by
+// ReadFromArray<bool>); confirm a forward declaration exists above this
+// chunk, otherwise the header will not compile.
+inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok,
+                                          uint64* value) {
+  const uint8* ptr = buffer;
+  uint32 b;
+
+  // Splitting into 32-bit pieces gives better performance on 32-bit
+  // processors.
+  // Each byte contributes its low 7 bits; the high bit flags continuation.
+  // The "part -= 0x80 << k" lines remove the continuation bit folded in by
+  // the preceding unconditional add.
+  uint32 part0 = 0, part1 = 0, part2 = 0;
+
+  b = *(ptr++);
+  part0 = b;
+  if (!(b & 0x80)) goto done;
+  part0 -= 0x80;
+  b = *(ptr++);
+  part0 += b << 7;
+  if (!(b & 0x80)) goto done;
+  part0 -= 0x80 << 7;
+  b = *(ptr++);
+  part0 += b << 14;
+  if (!(b & 0x80)) goto done;
+  part0 -= 0x80 << 14;
+  b = *(ptr++);
+  part0 += b << 21;
+  if (!(b & 0x80)) goto done;
+  part0 -= 0x80 << 21;
+  b = *(ptr++);
+  part1 = b;
+  if (!(b & 0x80)) goto done;
+  part1 -= 0x80;
+  b = *(ptr++);
+  part1 += b << 7;
+  if (!(b & 0x80)) goto done;
+  part1 -= 0x80 << 7;
+  b = *(ptr++);
+  part1 += b << 14;
+  if (!(b & 0x80)) goto done;
+  part1 -= 0x80 << 14;
+  b = *(ptr++);
+  part1 += b << 21;
+  if (!(b & 0x80)) goto done;
+  part1 -= 0x80 << 21;
+  b = *(ptr++);
+  part2 = b;
+  if (!(b & 0x80)) goto done;
+  part2 -= 0x80;
+  b = *(ptr++);
+  part2 += b << 7;
+  if (!(b & 0x80)) goto done;
+  // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0.
+
+  // We have overrun the maximum size of a varint (10 bytes). Assume
+  // the data is corrupt.
+  *ok = false;
+  return ptr;
+
+done:
+  *ok = true;
+  // Reassemble: part0 holds bits 0-27, part1 bits 28-55, part2 bits 56-63.
+  *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) |
+           (static_cast<uint64>(part2) << 56);
+  return ptr;
+}
+
+} // namespace internal
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.cc b/tensorflow/core/util/proto/descriptor_pool_registry.cc
new file mode 100644
index 0000000000..5f0423f76b
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptor_pool_registry.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+
+namespace tensorflow {
+
+// Returns the process-wide registry singleton. The instance is
+// intentionally leaked (never deleted) to avoid destruction-order issues
+// at shutdown; function-local static init is thread-safe in C++11.
+DescriptorPoolRegistry* DescriptorPoolRegistry::Global() {
+  static DescriptorPoolRegistry* registry = new DescriptorPoolRegistry;
+  return registry;
+}
+
+// Looks up the factory registered for `source`; returns nullptr when no
+// factory is registered. The returned pointer aliases the map entry and
+// remains valid as long as the registry is not mutated.
+DescriptorPoolRegistry::DescriptorPoolFn* DescriptorPoolRegistry::Get(
+    const string& source) {
+  auto found = fns_.find(source);
+  if (found == fns_.end()) return nullptr;
+  return &found->second;
+}
+
+// Registers `pool_fn` as the descriptor-pool factory for `source`.
+// CHECK-fails (crashes) if a factory for `source` is already present, so
+// each source string may be registered exactly once per process.
+void DescriptorPoolRegistry::Register(
+    const string& source,
+    const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
+  auto existing = Get(source);
+  CHECK_EQ(existing, nullptr)
+      << "descriptor pool for source: " << source << " already registered";
+  // Insert by value. The previous pair<const string&, DescriptorPoolFn>
+  // temporary held a reference to `source` before converting to the map's
+  // value_type; constructing the value_type directly is simpler and safer.
+  fns_.insert({source, pool_fn});
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.h b/tensorflow/core/util/proto/descriptor_pool_registry.h
new file mode 100644
index 0000000000..66c20e9e41
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptor_pool_registry.h
@@ -0,0 +1,76 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
+
+#include <functional>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// Process-wide registry mapping a descriptor-source string (e.g. a URI
+// scheme) to a factory that produces a protobuf DescriptorPool.
+// Not synchronized: Register() is expected to run only during static
+// initialization (via REGISTER_DESCRIPTOR_POOL) — TODO confirm.
+class DescriptorPoolRegistry {
+ public:
+  // Factory signature: fills in `desc_pool` and, when the caller should
+  // take ownership, also `owned_desc_pool`.
+  typedef std::function<Status(
+      tensorflow::protobuf::DescriptorPool const** desc_pool,
+      std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool)>
+      DescriptorPoolFn;
+
+  // Returns a pointer to a global DescriptorPoolRegistry object.
+  static DescriptorPoolRegistry* Global();
+
+  // Returns a pointer to a descriptor pool function for the given source.
+  DescriptorPoolFn* Get(const string& source);
+
+  // Registers a descriptor pool factory.
+  void Register(const string& source, const DescriptorPoolFn& pool_fn);
+
+ private:
+  // Keyed by source string; ordered map (requires <map>).
+  std::map<string, DescriptorPoolFn> fns_;
+};
+
+namespace descriptor_pool_registration {
+
+// Helper whose constructor performs the registration; instantiated as a
+// file-scope static by REGISTER_DESCRIPTOR_POOL so registration happens
+// during static initialization.
+class DescriptorPoolRegistration {
+ public:
+  DescriptorPoolRegistration(
+      const string& source,
+      const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) {
+    DescriptorPoolRegistry::Global()->Register(source, pool_fn);
+  }
+};
+
+} // namespace descriptor_pool_registration
+
+#define REGISTER_DESCRIPTOR_POOL(source, pool_fn) \
+ REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(__COUNTER__, source, pool_fn)
+
+#define REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(ctr, source, pool_fn) \
+ REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn)
+
+#define REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn) \
+ static descriptor_pool_registration::DescriptorPoolRegistration \
+ descriptor_pool_registration_fn_##ctr(source, pool_fn)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_
diff --git a/tensorflow/core/util/proto/descriptor_pool_registry_test.cc b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc
new file mode 100644
index 0000000000..a6899998ab
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// Minimal no-op factory used only to exercise registration/lookup; it
+// never touches its output parameters.
+struct Value {
+  static Status Function(
+      tensorflow::protobuf::DescriptorPool const** desc_pool,
+      std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    return Status::OK();
+  }
+};
+
+REGISTER_DESCRIPTOR_POOL("TEST POOL 1", Value::Function);
+REGISTER_DESCRIPTOR_POOL("TEST POOL 2", Value::Function);
+} // namespace
+
+// Verifies that lookups miss for unregistered sources and hit for both
+// statically registered test pools.
+TEST(DescriptorPoolRegistryTest, TestBasic) {
+  EXPECT_EQ(DescriptorPoolRegistry::Global()->Get("NON-EXISTENT"), nullptr);
+  auto pool1 = DescriptorPoolRegistry::Global()->Get("TEST POOL 1");
+  EXPECT_NE(pool1, nullptr);
+  auto pool2 = DescriptorPoolRegistry::Global()->Get("TEST POOL 2");
+  EXPECT_NE(pool2, nullptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc
new file mode 100644
index 0000000000..271c85efd8
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptors.cc
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/reader_op_kernel.h"
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+
+#include "tensorflow/core/util/proto/descriptors.h"
+
+namespace tensorflow {
+namespace {
+
+// Build a `DescriptorPool` from the named file or URI. The file or URI
+// must be available to the current TensorFlow environment.
+//
+// The file must contain a serialized `FileDescriptorSet`. See
+// `GetDescriptorPool()` for more information.
+// Reads `filename` via `env`, parses it as a serialized FileDescriptorSet,
+// and builds a fresh DescriptorPool from its FileDescriptorProtos, which is
+// returned through `owned_desc_pool` (caller owns it).
+// Returns the underlying filesystem error, InvalidArgument on parse or
+// BuildFile failure, or OK.
+Status GetDescriptorPoolFromFile(
+    tensorflow::Env* env, const string& filename,
+    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+  Status st = env->FileExists(filename);
+  if (!st.ok()) {
+    return st;
+  }
+
+  // Read and parse the FileDescriptorSet.
+  tensorflow::protobuf::FileDescriptorSet descs;
+  std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf;
+  st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf);
+  if (!st.ok()) {
+    return st;
+  }
+  if (!descs.ParseFromArray(buf->data(), buf->length())) {
+    return errors::InvalidArgument(
+        "descriptor_source contains invalid FileDescriptorSet: ", filename);
+  }
+
+  // Build a DescriptorPool from the FileDescriptorSet.
+  // BuildFile requires each file's dependencies to have been added first;
+  // a nullptr result typically means the set is incomplete or out of order.
+  owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool());
+  for (const auto& filedesc : descs.file()) {
+    if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) {
+      return errors::InvalidArgument(
+          "Problem loading FileDescriptorProto (missing dependencies?): ",
+          filename);
+    }
+  }
+  return Status::OK();
+}
+
+} // namespace
+
+// Resolves `descriptor_source` to a DescriptorPool.
+// First consults the DescriptorPoolRegistry for a registered factory;
+// otherwise treats the source as a filename containing a serialized
+// FileDescriptorSet. On the file path, the pool is owned by the caller via
+// `owned_desc_pool` and `*desc_pool` aliases it.
+// Fix: the original assigned `*desc_pool = owned_desc_pool->get();` both
+// inside the status.ok() guard and unconditionally afterwards — the second,
+// unconditional copy made the guard dead code and wrote the output even on
+// failure. Only the guarded assignment is kept, so `*desc_pool` is left
+// untouched when an error is returned.
+Status GetDescriptorPool(
+    tensorflow::Env* env, string const& descriptor_source,
+    tensorflow::protobuf::DescriptorPool const** desc_pool,
+    std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+  // Attempt to lookup the pool in the registry.
+  auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source);
+  if (pool_fn != nullptr) {
+    return (*pool_fn)(desc_pool, owned_desc_pool);
+  }
+
+  // If there is no pool function registered for the given source, let the
+  // runtime find the file or URL.
+  Status status =
+      GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool);
+  if (status.ok()) {
+    *desc_pool = owned_desc_pool->get();
+  }
+  return status;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/proto/descriptors.h b/tensorflow/core/util/proto/descriptors.h
new file mode 100644
index 0000000000..92ee8997ab
--- /dev/null
+++ b/tensorflow/core/util/proto/descriptors.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
+#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+class Env;
+class Status;
+
+// Get a `DescriptorPool` object from the named `descriptor_source`.
+// `descriptor_source` may be a path to a file accessible to TensorFlow, in
+// which case it is parsed as a `FileDescriptorSet` and used to build the
+// `DescriptorPool`.
+//
+// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if
+// the caller should take ownership.
+extern tensorflow::Status GetDescriptorPool(
+ tensorflow::Env* env, string const& descriptor_source,
+ tensorflow::protobuf::DescriptorPool const** desc_pool,
+ std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_
diff --git a/tensorflow/core/util/proto/local_descriptor_pool_registration.cc b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc
new file mode 100644
index 0000000000..48fe0102d0
--- /dev/null
+++ b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/proto/descriptor_pool_registry.h"
+
+namespace tensorflow {
+namespace {
+
+// Factory that returns the compiled-in protobuf generated_pool(), i.e. the
+// descriptors linked into this binary. No ownership is transferred, so
+// `owned_desc_pool` is intentionally left untouched.
+struct LocalDescriptorPool {
+  static Status Function(
+      tensorflow::protobuf::DescriptorPool const** desc_pool,
+      std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) {
+    *desc_pool = ::tensorflow::protobuf::DescriptorPool::generated_pool();
+    if (*desc_pool == nullptr) {
+      return errors::InvalidArgument("Problem loading protobuf generated_pool");
+    }
+    return Status::OK();
+  }
+};
+
+REGISTER_DESCRIPTOR_POOL("", LocalDescriptorPool::Function);
+REGISTER_DESCRIPTOR_POOL("local://", LocalDescriptorPool::Function);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/BUILD b/tensorflow/core/util/rpc/BUILD
new file mode 100644
index 0000000000..f0f161ecc0
--- /dev/null
+++ b/tensorflow/core/util/rpc/BUILD
@@ -0,0 +1,48 @@
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "call_container",
+ hdrs = ["call_container.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "rpc_factory",
+ srcs = ["rpc_factory.cc"],
+ hdrs = ["rpc_factory.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
+ name = "rpc_factory_registry",
+ srcs = ["rpc_factory_registry.cc"],
+ hdrs = ["rpc_factory_registry.h"],
+ deps = [
+ ":rpc_factory",
+ "//tensorflow/core:framework",
+ ],
+)
+
+tf_cc_test(
+ name = "rpc_factory_registry_test",
+ srcs = ["rpc_factory_registry_test.cc"],
+ deps = [
+ ":rpc_factory_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h
new file mode 100644
index 0000000000..7f36056797
--- /dev/null
+++ b/tensorflow/core/util/rpc/call_container.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
+#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
+
+#include <list>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+namespace tensorflow {
+
+// Tracks the lifetime of a batch of `num_calls` in-flight RPCs for one op
+// invocation. A shared ReffedStatusCallback holds one reference per call;
+// when the last Done() drops the last reference, the callback deregisters
+// cancellation, propagates the final status to the op context, invokes the
+// op's done callback, and deletes this container ("delete this").
+// Instances must therefore be heap-allocated and are self-owning.
+template <typename Call>
+class CallContainer {
+ public:
+  // `token` must already be registered with ctx->cancellation_manager();
+  // it is deregistered in the completion callback.
+  // NOTE(review): `fail_fast` is stored but not read in this class —
+  // presumably consumed by Call/callers; confirm.
+  explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
+                         bool try_rpc, AsyncOpKernel::DoneCallback done,
+                         CancellationToken token)
+      : ctx_(ctx),
+        done_(std::move(done)),
+        token_(token),
+        fail_fast_(fail_fast),
+        try_rpc_(try_rpc) {
+    CHECK_GT(num_calls, 0);
+
+    // This will run when all RPCs are finished.
+    reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
+      ctx_->cancellation_manager()->DeregisterCallback(token_);
+      ctx_->SetStatus(s);
+      done_();
+      delete this;
+    });
+
+    // Subtract reference count from the initial creation.
+    core::ScopedUnref unref(reffed_status_callback_);
+
+    for (int i = 0; i < num_calls; ++i) {
+      // Increase the reference on the callback for each new RPC.
+      reffed_status_callback_->Ref();
+    }
+  }
+
+  // The caller populates this list with one Call per RPC.
+  std::list<Call>* calls() { return &calls_; }
+
+  void StartCancel() {
+    // Once this loop is done, can no longer assume anything is valid
+    // because "delete this" may have been immediately called.
+    // Nothing should run after this loop.
+    for (auto& call : calls_) {
+      call.StartCancel();
+    }
+  }
+
+  // Called once per RPC on completion. In try_rpc mode per-call statuses
+  // are reported elsewhere, so `s` does not affect the aggregate status.
+  // NOTE(review): `index` is unused in the visible code — confirm intent.
+  void Done(const Status& s, int index) {
+    if (!try_rpc_) {
+      reffed_status_callback_->UpdateStatus(s);
+    }
+    reffed_status_callback_->Unref();
+  }
+
+ private:
+  OpKernelContext* ctx_;
+  std::list<Call> calls_;
+  const AsyncOpKernel::DoneCallback done_;
+  const CancellationToken token_;
+  const bool fail_fast_;
+  const bool try_rpc_;
+
+  // Performs its own reference counting.
+  ReffedStatusCallback* reffed_status_callback_;
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory.cc b/tensorflow/core/util/rpc/rpc_factory.cc
new file mode 100644
index 0000000000..8530f02b6e
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/numbers.h"
+
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+// String specialization: an unset or empty variable yields the default.
+// Always returns true because any string value is trivially "parseable".
+template <>
+bool GetEnvVar(const char* key, const string& default_value, string* value) {
+  const char* env_value = std::getenv(key);
+  if (!env_value || env_value[0] == '\0') {
+    *value = default_value;
+  } else {
+    *value = env_value;
+  }
+  return true;
+}
+
+// int64 specialization: unset/empty falls back to the default; otherwise
+// returns false (parse failure) when the value is not a valid int64.
+template <>
+bool GetEnvVar(const char* key, const int64& default_value, int64* value) {
+  const char* env_value = std::getenv(key);
+  if (!env_value || env_value[0] == '\0') {
+    *value = default_value;
+    return true;
+  }
+  return strings::safe_strto64(env_value, value);
+}
+
+// uint64 specialization: unset/empty falls back to the default; otherwise
+// returns false (parse failure) when the value is not a valid uint64.
+template <>
+bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) {
+  const char* env_value = std::getenv(key);
+  if (!env_value || env_value[0] == '\0') {
+    *value = default_value;
+    return true;
+  }
+  return strings::safe_strtou64(env_value, value);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h
new file mode 100644
index 0000000000..9bf078c0f4
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory.h
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
+#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+namespace tensorflow {
+
+// Return the environment variable `key`. If the variable is not set,
+// use the default value. If it is set but could not be parsed,
+// return `false`. Otherwise set `value` and return `true`.
+template <typename T>
+bool GetEnvVar(const char* key, const T& default_value, T* value);
+
+// Abstract interface for issuing a batch of RPC calls on behalf of an op
+// kernel. Concrete implementations are looked up by protocol name through
+// RPCFactoryRegistry (see rpc_factory_registry.h).
+class RPCFactory {
+ public:
+  RPCFactory() {}
+  virtual ~RPCFactory() {}
+
+  // Start a Call() to methods `method_t` at addresses `address_t` with
+  // request strings from `request_t`. Any of these may be scalar
+  // Tensors, in which case the operands are broadcasted.
+  // Upon completion of all requests, `response_t` will be populated.
+  //
+  // If `try_rpc` is `true`, then `status_message_t` and
+  // `status_code_t` will be populated as well.
+  //
+  // If `try_rpc` is `false`, then `status_message_t` and
+  // `status_code_t` are ignored (and may be nullptr). Instead, the
+  // status of any failed call will be propagated to the op.
+  //
+  // REQUIRES:
+  //   - `response_t` is not null, and is a string Tensor with the same shape as
+  //     `request_t`.
+  //
+  //  If `try_rpc` is `true`:
+  //   - `status_code_t` and `status_message_t` are not null.
+  //   - `status_code_t` is an int32 Tensor with the same shape as
+  //     `request_t`.
+  //   - `status_message_t` is a string Tensor with the same shape as
+  //     `request_t`.
+  //
+  // `done` is invoked when all calls have completed (the call is
+  // asynchronous; see AsyncOpKernel::DoneCallback).
+  virtual void Call(OpKernelContext* ctx, int64 num_elements,
+                    const Tensor& address_t, const Tensor& method_t,
+                    const Tensor& request_t, const bool try_rpc,
+                    Tensor* response_t, Tensor* status_code_t,
+                    Tensor* status_message_t,
+                    AsyncOpKernel::DoneCallback done) = 0;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.cc b/tensorflow/core/util/rpc/rpc_factory_registry.cc
new file mode 100644
index 0000000000..a148b5c04d
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry.cc
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+
+namespace tensorflow {
+
+// Returns the process-wide registry. The instance is created on first use
+// and intentionally never destroyed, avoiding destruction-order problems at
+// program shutdown.
+RPCFactoryRegistry* RPCFactoryRegistry::Global() {
+  static RPCFactoryRegistry* registry = new RPCFactoryRegistry;
+  return registry;
+}
+
+// Looks up the factory function registered for `protocol`. Returns nullptr
+// when no factory has been registered under that name. The returned pointer
+// refers into the registry's map and stays valid as long as the entry exists.
+RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get(
+    const string& protocol) {
+  auto found = fns_.find(protocol);
+  if (found == fns_.end()) return nullptr;
+  return &found->second;
+}
+
+// Registers `factory_fn` under `protocol`. Registering the same protocol
+// twice is a fatal error (CHECK); registration normally happens exactly once
+// at static-initialization time via REGISTER_RPC_FACTORY.
+//
+// NOTE(review): the map is not synchronized; this assumes all Register()
+// calls occur during single-threaded static init -- confirm no runtime
+// callers register concurrently.
+void RPCFactoryRegistry::Register(const string& protocol,
+                                  const RPCFactoryFn& factory_fn) {
+  auto existing = Get(protocol);
+  CHECK_EQ(existing, nullptr)
+      << "RPC factory for protocol: " << protocol << " already registered";
+  // Construct the entry in place. The original built a
+  // std::pair<const string&, RPCFactoryFn> -- a pair holding a reference to
+  // the key -- which forced an extra conversion to the map's value_type on
+  // insert; emplace avoids the intermediate pair entirely.
+  fns_.emplace(protocol, factory_fn);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.h b/tensorflow/core/util/rpc/rpc_factory_registry.h
new file mode 100644
index 0000000000..2635a4012e
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry.h
@@ -0,0 +1,72 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
+#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
+
+#include <map>
+#include <string>
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/util/rpc/rpc_factory.h"
+
+namespace tensorflow {
+
+// Process-wide mapping from protocol name to a function that creates an
+// RPCFactory for that protocol.
+class RPCFactoryRegistry {
+ public:
+  // Signature of a factory-creation function. The returned RPCFactory is
+  // owned by the caller (see Register below).
+  typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast,
+                                    int64 timeout_in_ms)>
+      RPCFactoryFn;
+
+  // Returns a pointer to a global RPCFactoryRegistry object.
+  static RPCFactoryRegistry* Global();
+
+  // Returns a pointer to a function that creates an RPC factory for the given
+  // protocol.
+  RPCFactoryFn* Get(const string& protocol);
+
+  // Registers a function that creates an RPC factory for the given protocol.
+  // The function should transfer the ownership of the factory to its caller.
+  void Register(const string& protocol, const RPCFactoryFn& factory_fn);
+
+ private:
+  // Keyed by protocol name; populated via Register().
+  std::map<string, RPCFactoryFn> fns_;
+};
+
+namespace rpc_factory_registration {
+
+// Helper whose constructor performs the registration; instantiated as a
+// static object by the REGISTER_RPC_FACTORY macro so that factories are
+// registered during static initialization.
+class RPCFactoryRegistration {
+ public:
+  RPCFactoryRegistration(const string& protocol,
+                         const RPCFactoryRegistry::RPCFactoryFn& factory_fn) {
+    RPCFactoryRegistry::Global()->Register(protocol, factory_fn);
+  }
+};
+
+} // namespace rpc_factory_registration
+
+#define REGISTER_RPC_FACTORY(protocol, factory_fn) \
+ REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn)
+
+#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \
+ REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn)
+
+#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \
+ static rpc_factory_registration::RPCFactoryRegistration \
+ rpc_factory_registration_fn_##ctr(protocol, factory_fn)
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_
diff --git a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
new file mode 100644
index 0000000000..cfd0f95016
--- /dev/null
+++ b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/util/rpc/rpc_factory_registry.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// Trivial factory function used only to exercise registration/lookup; it
+// never creates a real RPCFactory.
+struct Value {
+  static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast,
+                              int64 timeout_in_ms) {
+    return nullptr;
+  }
+};
+
+REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function);
+REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function);
+} // namespace
+
+// Verifies that lookup misses for an unregistered protocol and hits for both
+// protocols registered above via REGISTER_RPC_FACTORY.
+TEST(RPCFactoryRegistryTest, TestBasic) {
+  EXPECT_EQ(RPCFactoryRegistry::Global()->Get("NON-EXISTENT"), nullptr);
+  auto factory1 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 1");
+  EXPECT_NE(factory1, nullptr);
+  auto factory2 = RPCFactoryRegistry::Global()->Get("TEST FACTORY 2");
+  EXPECT_NE(factory2, nullptr);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 01962fcf44..a22b9f40b1 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3370,6 +3370,7 @@ tf_py_wrap_cc(
"//tensorflow/c:python_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api",
+ "//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/distributed_runtime/rpc:grpc_session",
"//tensorflow/core/grappler:grappler_item",