diff options
author | 2018-04-06 17:17:22 -0700 | |
---|---|---|
committer | 2018-04-06 17:19:59 -0700 | |
commit | 5e11bbacaffdf7bc4a9363301de6a0755f95e9c0 (patch) | |
tree | 48f37585cd3b01c71eaced8724be21151374264d /tensorflow | |
parent | ddf54d1c24a2b4dcfd8eb52d21dc1f393785f1e9 (diff) |
Open sourcing proto/rpc ops.
PiperOrigin-RevId: 191962572
Diffstat (limited to 'tensorflow')
45 files changed, 4394 insertions, 0 deletions
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index 02c456c199..8e83b4e176 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -354,6 +354,9 @@ tensorflow/contrib/periodic_resample tensorflow/contrib/periodic_resample/python tensorflow/contrib/periodic_resample/python/ops tensorflow/contrib/predictor +tensorflow/contrib/proto +tensorflow/contrib/proto/python +tensorflow/contrib/proto/python/ops tensorflow/contrib/quantization tensorflow/contrib/quantization/python tensorflow/contrib/quantize @@ -382,6 +385,9 @@ tensorflow/contrib/rnn/ops tensorflow/contrib/rnn/python tensorflow/contrib/rnn/python/kernel_tests tensorflow/contrib/rnn/python/ops +tensorflow/contrib/rpc +tensorflow/contrib/rpc/python +tensorflow/contrib/rpc/python/ops tensorflow/contrib/saved_model tensorflow/contrib/saved_model/python tensorflow/contrib/saved_model/python/saved_model diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 092a48bc6b..e558691de4 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -25,6 +25,8 @@ set(tf_op_lib_names "cudnn_rnn_ops" "data_flow_ops" "dataset_ops" + "decode_proto_ops" + "encode_proto_ops" "functional_ops" "image_ops" "io_ops" @@ -40,6 +42,7 @@ set(tf_op_lib_names "random_ops" "remote_fused_graph_ops" "resource_variable_ops" + "rpc_ops" "script_ops" "sdca_ops" "set_ops" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index fae45ead5c..1a5ec34844 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -330,6 +330,8 @@ GENERATE_PYTHON_OP_LIB("ctc_ops") GENERATE_PYTHON_OP_LIB("cudnn_rnn_ops") GENERATE_PYTHON_OP_LIB("data_flow_ops") GENERATE_PYTHON_OP_LIB("dataset_ops") +GENERATE_PYTHON_OP_LIB("decode_proto_ops") 
+GENERATE_PYTHON_OP_LIB("encode_proto_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") @@ -343,6 +345,7 @@ GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py) GENERATE_PYTHON_OP_LIB("resource_variable_ops") +GENERATE_PYTHON_OP_LIB("rpc_ops") GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") GENERATE_PYTHON_OP_LIB("set_ops") diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index b6acf71b9d..0bc4c5d473 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -301,3 +301,5 @@ tensorflow/core/kernels/warn_about_ints.cc tensorflow/core/kernels/segment_reduction_ops.cc tensorflow/core/kernels/batch_util.cc tensorflow/core/ops/audio_ops.cc +tensorflow/core/kernels/decode_proto_op.cc +tensorflow/core/kernels/encode_proto_op.cc diff --git a/tensorflow/contrib/proto/BUILD b/tensorflow/contrib/proto/BUILD new file mode 100644 index 0000000000..046652cbc5 --- /dev/null +++ b/tensorflow/contrib/proto/BUILD @@ -0,0 +1,16 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "proto", + srcs = [ + "__init__.py", + ], + deps = [ + "//tensorflow/contrib/proto/python/ops:decode_proto_op_py", + "//tensorflow/contrib/proto/python/ops:encode_proto_op_py", + ], +) diff --git a/tensorflow/contrib/proto/__init__.py b/tensorflow/contrib/proto/__init__.py new file mode 100644 index 0000000000..bc5a49de78 --- /dev/null +++ b/tensorflow/contrib/proto/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops and modules related to proto. + +@@decode_proto +@@encode_proto +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.proto.python.ops.decode_proto_op import decode_proto +from tensorflow.contrib.proto.python.ops.encode_proto_op import encode_proto + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/proto/python/ops/BUILD b/tensorflow/contrib/proto/python/ops/BUILD new file mode 100644 index 0000000000..f17065477e --- /dev/null +++ b/tensorflow/contrib/proto/python/ops/BUILD @@ -0,0 +1,44 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_wrapper_py", +) + +py_library( + name = "decode_proto_op_py", + srcs = ["decode_proto_op.py"], + deps = [ + ":gen_decode_proto_op_py", + "//tensorflow/python:framework_ops", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_decode_proto_op_py", + out = "gen_decode_proto_op.py", + deps = [ + "//tensorflow/core:decode_proto_ops_op_lib", + ], +) + +py_library( + name = "encode_proto_op_py", + srcs = ["encode_proto_op.py"], + deps = [ + ":gen_encode_proto_op_py", + "//tensorflow/python:framework_ops", + ], +) + 
+tf_gen_op_wrapper_py( + name = "gen_encode_proto_op_py", + out = "gen_encode_proto_op.py", + deps = [ + "//tensorflow/core:encode_proto_ops_op_lib", + ], +) diff --git a/tensorflow/contrib/proto/python/ops/decode_proto_op.py b/tensorflow/contrib/proto/python/ops/decode_proto_op.py new file mode 100644 index 0000000000..7dc000ebe4 --- /dev/null +++ b/tensorflow/contrib/proto/python/ops/decode_proto_op.py @@ -0,0 +1,25 @@ +# ============================================================================= +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# pylint: disable=wildcard-import,unused-import +"""Protocol Buffer decoding from tensors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.proto.python.ops.gen_decode_proto_op import decode_proto_v2 as decode_proto +from tensorflow.python.framework import ops +ops.NotDifferentiable("DecodeProtoV2") diff --git a/tensorflow/contrib/proto/python/ops/encode_proto_op.py b/tensorflow/contrib/proto/python/ops/encode_proto_op.py new file mode 100644 index 0000000000..ac12198b2e --- /dev/null +++ b/tensorflow/contrib/proto/python/ops/encode_proto_op.py @@ -0,0 +1,25 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# pylint: disable=wildcard-import,unused-import +"""Protocol Buffer encoding from tensors.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.proto.python.ops.gen_encode_proto_op import encode_proto +from tensorflow.python.framework import ops + +ops.NotDifferentiable("EncodeProto") diff --git a/tensorflow/contrib/rpc/BUILD b/tensorflow/contrib/rpc/BUILD new file mode 100644 index 0000000000..597f18c771 --- /dev/null +++ b/tensorflow/contrib/rpc/BUILD @@ -0,0 +1,13 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "rpc", + srcs = [ + "__init__.py", + ], + deps = ["//tensorflow/contrib/rpc/python/ops:rpc_op_py"], +) diff --git a/tensorflow/contrib/rpc/__init__.py b/tensorflow/contrib/rpc/__init__.py new file mode 100644 index 0000000000..c65c1a05de --- /dev/null +++ b/tensorflow/contrib/rpc/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops and modules related to RPC. + +@@rpc +@@try_rpc +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.rpc.python.ops.rpc_op import rpc +from tensorflow.contrib.rpc.python.ops.rpc_op import try_rpc + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/rpc/python/ops/BUILD b/tensorflow/contrib/rpc/python/ops/BUILD new file mode 100644 index 0000000000..84d2a1832f --- /dev/null +++ b/tensorflow/contrib/rpc/python/ops/BUILD @@ -0,0 +1,24 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +py_library( + name = "rpc_op_py", + srcs = ["rpc_op.py"], + deps = [ + ":gen_rpc_op_py", + "//tensorflow/python:framework_ops", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_rpc_op_py", + out = "gen_rpc_op.py", + deps = [ + "//tensorflow/core:rpc_ops_op_lib", + ], +) diff --git a/tensorflow/contrib/rpc/python/ops/rpc_op.py b/tensorflow/contrib/rpc/python/ops/rpc_op.py new file mode 100644 index 0000000000..e1b6c41137 --- /dev/null +++ b/tensorflow/contrib/rpc/python/ops/rpc_op.py @@ -0,0 +1,26 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +# pylint: disable=wildcard-import,unused-import +"""RPC communication.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.rpc.python.ops.gen_rpc_op import rpc +from tensorflow.contrib.rpc.python.ops.gen_rpc_op import try_rpc +from tensorflow.python.framework import ops +ops.NotDifferentiable("Rpc") +ops.NotDifferentiable("TryRpc") diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 7d5ae1c5b5..1eebeb3995 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -637,6 +637,8 @@ tf_gen_op_libs( "ctc_ops", "data_flow_ops", "dataset_ops", + "decode_proto_ops", + "encode_proto_ops", "function_ops", "functional_ops", "image_ops", @@ -653,6 +655,7 @@ tf_gen_op_libs( "random_ops", "remote_fused_graph_ops", "resource_variable_ops", + "rpc_ops", "scoped_allocator_ops", "sdca_ops", "set_ops", @@ -751,6 +754,8 @@ cc_library( ":cudnn_rnn_ops_op_lib", ":data_flow_ops_op_lib", ":dataset_ops_op_lib", + ":decode_proto_ops_op_lib", + ":encode_proto_ops_op_lib", ":function_ops_op_lib", ":functional_ops_op_lib", ":image_ops_op_lib", @@ -767,6 +772,7 @@ cc_library( ":random_ops_op_lib", ":remote_fused_graph_ops_op_lib", ":resource_variable_ops_op_lib", + ":rpc_ops_op_lib", ":scoped_allocator_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", @@ 
-893,6 +899,8 @@ cc_library( "//tensorflow/core/kernels:cudnn_rnn_kernels", "//tensorflow/core/kernels:data_flow", "//tensorflow/core/kernels:dataset_ops", + "//tensorflow/core/kernels:decode_proto_op", + "//tensorflow/core/kernels:encode_proto_op", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", @@ -914,6 +922,7 @@ cc_library( "//tensorflow/core/kernels:remote_fused_graph_ops", "//tensorflow/core/kernels:required", "//tensorflow/core/kernels:resource_variable_ops", + "//tensorflow/core/kernels:rpc_op", "//tensorflow/core/kernels:scoped_allocator_ops", "//tensorflow/core/kernels:sdca_ops", "//tensorflow/core/kernels:set_kernels", diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt new file mode 100644 index 0000000000..c8152f53c4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt @@ -0,0 +1,116 @@ +op { + graph_op_name: "DecodeProtoV2" + in_arg { + name: "bytes" + description: <<END +Tensor of serialized protos with shape `batch_shape`. +END + } + out_arg { + name: "sizes" + description: <<END +Tensor of int32 with shape `[batch_shape, len(field_names)]`. +Each entry is the number of values found for the corresponding field. +Optional fields may have 0 or 1 values. +END + } + out_arg { + name: "values" + description: <<END +List of tensors containing values for the corresponding field. +`values[i]` has datatype `output_types[i]` +and shape `[batch_shape, max(sizes[...,i])]`. +END + } + attr { + name: "message_type" + description: <<END +Name of the proto message type to decode. +END + } + attr { + name: "field_names" + description: <<END +List of strings containing proto field names. +END + } + attr { + name: "output_types" + description: <<END +List of TF types to use for the respective field in field_names. 
+END + } + attr { + name: "descriptor_source" + description: <<END +Either the special value `local://` or a path to a file containing +a serialized `FileDescriptorSet`. +END + } + attr { + name: "message_format" + description: <<END +Either `binary` or `text`. +END + } + attr { + name: "sanitize" + description: <<END +Whether to sanitize the result or not. +END + } + summary: <<END +The op extracts fields from a serialized protocol buffers message into tensors. +END + description: <<END +The `decode_proto` op extracts fields from a serialized protocol buffers +message into tensors. The fields in `field_names` are decoded and converted +to the corresponding `output_types` if possible. + +A `message_type` name must be provided to give context for the field +names. The actual message descriptor can be looked up either in the +linked-in descriptor pool or a filename provided by the caller using +the `descriptor_source` attribute. + +Each output tensor is a dense tensor. This means that it is padded to +hold the largest number of repeated elements seen in the input +minibatch. (The shape is also padded by one to prevent zero-sized +dimensions). The actual repeat counts for each example in the +minibatch can be found in the `sizes` output. In many cases the output +of `decode_proto` is fed immediately into tf.squeeze if missing values +are not a concern. When using tf.squeeze, always pass the squeeze +dimension explicitly to avoid surprises. + +For the most part, the mapping between Proto field types and +TensorFlow dtypes is straightforward. However, there are a few +special cases: + +- A proto field that contains a submessage or group can only be converted +to `DT_STRING` (the serialized submessage). This is to reduce the +complexity of the API. The resulting string can be used as input +to another instance of the decode_proto op. + +- TensorFlow lacks support for unsigned integers. 
The ops represent uint64 +types as a `DT_INT64` with the same twos-complement bit pattern +(the obvious way). Unsigned int32 values can be represented exactly by +specifying type `DT_INT64`, or using twos-complement if the caller +specifies `DT_INT32` in the `output_types` attribute. + +The `descriptor_source` attribute selects a source of protocol +descriptors to consult when looking up `message_type`. This may be a +filename containing a serialized `FileDescriptorSet` message, +or the special value `local://`, in which case only descriptors linked +into the code will be searched; the filename can be on any filesystem +accessible to TensorFlow. + +You can build a `descriptor_source` file using the `--descriptor_set_out` +and `--include_imports` options to the protocol compiler `protoc`. + +The `local://` database only covers descriptors linked into the +code via C++ libraries, not Python imports. You can link in a proto descriptor +by creating a cc_library target with alwayslink=1. + +Both binary and text proto serializations are supported, and can be +chosen using the `message_format` attribute. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt new file mode 100644 index 0000000000..fdbe47f236 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_EncodeProto.pbtxt @@ -0,0 +1,81 @@ +op { + graph_op_name: "EncodeProto" + in_arg { + name: "sizes" + description: <<END +Tensor of int32 with shape `[batch_shape, len(field_names)]`. +END + } + in_arg { + name: "values" + description: <<END +List of tensors containing values for the corresponding field. +END + } + out_arg { + name: "bytes" + description: <<END +Tensor of serialized protos with shape `batch_shape`. +END + } + attr { + name: "message_type" + description: <<END +Name of the proto message type to encode. +END + } + attr { + name: "field_names" + description: <<END +List of strings containing proto field names. 
+END + } + attr { + name: "Tinput_types" + description: <<END +The input types. +END + } + summary: <<END +The op serializes protobuf messages provided in the input tensors. +END + description: <<END +The types of the tensors in `values` must match the schema for the +fields specified in `field_names`. All the tensors in `values` must +have a common shape prefix, *batch_shape*. + +The `sizes` tensor specifies repeat counts for each field. The repeat +count (last dimension) of each tensor in `values` must be greater +than or equal to the corresponding repeat count in `sizes`. + +A `message_type` name must be provided to give context for the field +names. The actual message descriptor can be looked up either in the +linked-in descriptor pool or a filename provided by the caller using +the `descriptor_source` attribute. + +The `descriptor_source` attribute selects a source of protocol +descriptors to consult when looking up `message_type`. This may be a +filename containing a serialized `FileDescriptorSet` message, +or the special value `local://`, in which case only descriptors linked +into the code will be searched; the filename can be on any filesystem +accessible to TensorFlow. + +You can build a `descriptor_source` file using the `--descriptor_set_out` +and `--include_imports` options to the protocol compiler `protoc`. + +The `local://` database only covers descriptors linked into the +code via C++ libraries, not Python imports. You can link in a proto descriptor +by creating a cc_library target with alwayslink=1. + +There are a few special cases in the value mapping: + +Submessage and group fields must be pre-serialized as TensorFlow strings. + +TensorFlow lacks support for unsigned int64s, so they must be +represented as `tf.int64` with the same twos-complement bit pattern +(the obvious way). + +Unsigned int32 values can be represented exactly with `tf.int64`, or +with sign wrapping if the input is of type `tf.int32`. 
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt new file mode 100644 index 0000000000..344ef191fd --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Rpc.pbtxt @@ -0,0 +1,108 @@ +op { + graph_op_name: "Rpc" + in_arg { + name: "address" + description: <<END +`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +If this tensor has more than 1 element, then multiple parallel rpc requests +are sent. This argument broadcasts with `method` and `request`. +END + } + in_arg { + name: "method" + description: <<END +`0-D` or `1-D`. The method address on the RPC server. +If this tensor has more than 1 element, then multiple parallel rpc requests +are sent. This argument broadcasts with `address` and `request`. +END + } + in_arg { + name: "request" + description: <<END +`0-D` or `1-D`. Serialized proto strings: the rpc request argument. +If this tensor has more than 1 element, then multiple parallel rpc requests +are sent. This argument broadcasts with `address` and `method`. +END + } + out_arg { + name: "response" + description: <<END +Same shape as `request`. Serialized proto strings: the rpc responses. +END + } + attr { + name: "protocol" + description: <<END +RPC protocol to use. Empty string means use the default protocol. +Options include 'grpc'. +END + } + attr { + name: "fail_fast" + description: <<END +`boolean`. If `true` (default), then failures to connect +(i.e., the server does not immediately respond) cause an RPC failure. +END + } + attr { + name: "timeout_in_ms" + description: <<END +`int`. If `0` (default), then the kernel will run the RPC +request and only time out if the RPC deadline passes or the session times out. +If this value is greater than `0`, then the op will raise an exception if +the RPC takes longer than `timeout_in_ms`. +END + } + summary: <<END +Perform batches of RPC requests. 
+END + description: <<END +This op asynchronously performs either a single RPC request, or a batch +of requests. RPC requests are defined by three main parameters: + + - `address` (the host+port or BNS address of the request) + - `method` (the RPC method name for the request) + - `request` (the serialized proto string, or vector of strings, + of the RPC request argument). + +For example, if you have an RPC service running on port localhost:2345, +and its interface is configured with the following proto declaration: + +``` +service MyService { + rpc MyMethod(MyRequestProto) returns (MyResponseProto) { + } +}; +``` + +then call this op with arguments: + +``` +address = "localhost:2345" +method = "MyService/MyMethod" +``` + +The `request` tensor is a string tensor representing serialized `MyRequestProto` +strings; and the output string tensor `response` will have the same shape +and contain (upon successful completion) corresponding serialized +`MyResponseProto` strings. + +For example, to send a single, empty, `MyRequestProto`, call +this op with `request = ""`. To send 5 **parallel** empty requests, +call this op with `request = ["", "", "", "", ""]`. + +More generally, one can create a batch of `MyRequestProto` serialized protos +from regular batched tensors using the `encode_proto` op, and convert +the response `MyResponseProto` serialized protos to batched tensors +using the `decode_proto` op. + +**NOTE** Working with serialized proto strings is faster than instantiating +actual proto objects in memory, so no performance degradation is expected +compared to writing custom kernels for this workflow. + +If the connection fails or the remote worker returns an error +status, the op reraises this exception locally. + +See the `TryRpc` op if you prefer to handle RPC failures manually in the graph. 
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt b/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt new file mode 100644 index 0000000000..bded00e83c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_TryRpc.pbtxt @@ -0,0 +1,123 @@ +op { + graph_op_name: "TryRpc" + in_arg { + name: "address" + description: <<END +`0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server. +If this tensor has more than 1 element, then multiple parallel rpc requests +are sent. This argument broadcasts with `method` and `request`. +END + } + in_arg { + name: "method" + description: <<END +`0-D` or `1-D`. The method address on the RPC server. +If this tensor has more than 1 element, then multiple parallel rpc requests +are sent. This argument broadcasts with `address` and `request`. +END + } + in_arg { + name: "request" + description: <<END +`0-D` or `1-D`. Serialized proto strings: the rpc request argument. +If this tensor has more than 1 element, then multiple parallel rpc requests +are sent. This argument broadcasts with `address` and `method`. +END + } + out_arg { + name: "response" + description: <<END +Same shape as `request`. Serialized proto strings: the rpc responses. +END + } + out_arg { + name: "status_code" + description: <<END +Same shape as `request`. Values correspond to tensorflow Status enum codes. +END + } + out_arg { + name: "status_message" + description: <<END +Same shape as `request`. Values correspond to Status messages +returned from the RPC calls. +END + } + attr { + name: "protocol" + description: <<END +RPC protocol to use. Empty string means use the default protocol. +Options include 'grpc'. +END + } + attr { + name: "fail_fast" + description: <<END +`boolean`. If `true` (default), then failures to connect +(i.e., the server does not immediately respond) cause an RPC failure. +END + } + attr { + name: "timeout_in_ms" + description: <<END +`int`. 
If `0` (default), then the kernel will run the RPC +request and only time out if the RPC deadline passes or the session times out. +If this value is greater than `0`, then the op will raise an exception if +the RPC takes longer than `timeout_in_ms`. +END + } + summary: <<END +Perform batches of RPC requests. +END + description: <<END +This op asynchronously performs either a single RPC request, or a batch +of requests. RPC requests are defined by three main parameters: + + - `address` (the host+port or BNS address of the request) + - `method` (the method name for the request) + - `request` (the serialized proto string, or vector of strings, + of the RPC request argument). + +For example, if you have an RPC service running on port localhost:2345, +and its interface is configured with the following proto declaration: + +``` +service MyService { + rpc MyMethod(MyRequestProto) returns (MyResponseProto) { + } +}; +``` + +then call this op with arguments: + +``` +address = "localhost:2345" +method = "MyService/MyMethod" +``` + +The `request` tensor is a string tensor representing serialized `MyRequestProto` +strings; and the output string tensor `response` will have the same shape +and contain (upon successful completion) corresponding serialized +`MyResponseProto` strings. + +For example, to send a single, empty, `MyRequestProto`, call +this op with `request = ""`. To send 5 **parallel** empty requests, +call this op with `request = ["", "", "", "", ""]`. + +More generally, one can create a batch of `MyRequestProto` serialized protos +from regular batched tensors using the `encode_proto` op, and convert +the response `MyResponseProto` serialized protos to batched tensors +using the `decode_proto` op. + +**NOTE** Working with serialized proto strings is faster than instantiating +actual proto objects in memory, so no performance degradation is expected +compared to writing custom kernels for this workflow. 
+ +Unlike the standard `Rpc` op, if the connection fails or the remote worker +returns an error status, this op does **not** reraise the exception. +Instead, the `status_code` and `status_message` entry for the corresponding RPC +call is set with the error returned from the RPC call. The `response` tensor +will contain valid response values for those minibatch entries whose RPCs did +not fail; the rest of the entries will have empty strings. +END +} diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 9c655bfa31..d3478dfc38 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -499,3 +499,33 @@ tf_cuda_cc_test( "//tensorflow/core/kernels:variable_ops", ], ) + +cc_library( + name = "grpc_rpc_factory", + srcs = [ + "grpc_rpc_factory.cc", + ], + hdrs = ["grpc_rpc_factory.h"], + deps = [ + ":grpc_state", + ":grpc_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/util/rpc:call_container", + "//tensorflow/core/util/rpc:rpc_factory", + ], +) + +cc_library( + name = "grpc_rpc_factory_registration", + srcs = [ + "grpc_rpc_factory_registration.cc", + ], + deps = [ + ":grpc_rpc_factory", + "//tensorflow/core/util/rpc:rpc_factory", + "//tensorflow/core/util/rpc:rpc_factory_registry", + ], + alwayslink = 1, +) diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc new file mode 100644 index 0000000000..d004abd1c1 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc @@ -0,0 +1,213 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/util/rpc/call_container.h" +#include "tensorflow/core/util/rpc/rpc_factory.h" + +#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h" + +namespace tensorflow { + +namespace { +class GrpcCall { + public: + explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc, + const string* request_msg, string* response_msg, + int32* status_code, string* status_message) + : container_(container), + index_(index), + try_rpc_(try_rpc), + request_msg_(request_msg), + response_msg_(response_msg), + status_code_(status_code), + status_message_(status_message) {} + + void StartCancel() { call_opts_.StartCancel(); } + + void Done(const Status& s) { + DCHECK(container_ != nullptr); + if (!s.ok() && try_rpc_) { + DCHECK(status_code_ != nullptr); + DCHECK(status_message_ != nullptr); + *status_code_ = s.code(); + *status_message_ = s.error_message(); + } + container_->Done(s, index_); + } + + const string& request() const { return *request_msg_; } + string* response() const { return response_msg_; } + CallOptions* call_opts() { return &call_opts_; } + + private: + 
CallContainer<GrpcCall>* const container_; + const int index_; + bool try_rpc_; + CallOptions call_opts_; + const string* request_msg_; + string* response_msg_; + int* status_code_; + string* status_message_; +}; + +} // namespace + +GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms) + : RPCFactory(), fail_fast_(fail_fast), timeout_in_ms_(timeout_in_ms) { + // TODO(ebrevdo): Investigate possible performance improvements by + // replacing this thread with a threadpool. + polling_thread_ = + ctx->env()->StartThread(ThreadOptions(), "rpc_op_grpc_factory", [this]() { + void* tag; + bool ok; + while (completion_queue_.Next(&tag, &ok)) { + GrpcClientCQTag* callback_tag = static_cast<GrpcClientCQTag*>(tag); + callback_tag->OnCompleted(ok); + } + }); +} + +GrpcRPCFactory::~GrpcRPCFactory() { + // The amount of time we wait depends on several parameters, including: + // - the value of the fail_fast attribute. + // - the timeout option of the rpc call in the proto declaration. + // - the network roundtrip time and service's execution time. + // + // If a connection is made but the service doesn't ever respond, and + // there is no timeout option set for this rpc call, then it is + // possible the RPC request will wait forever. + // + completion_queue_.Shutdown(); + delete polling_thread_; +} + +void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements, + const Tensor& address_t, const Tensor& method_t, + const Tensor& request_t, const bool try_rpc, + Tensor* response_t, Tensor* status_code_t, + Tensor* status_message_t, + AsyncOpKernel::DoneCallback done) { + auto address = address_t.flat<string>(); + auto method = method_t.flat<string>(); + auto request = request_t.flat<string>(); + + // Stubs are maintained by the GrpcRPCFactory class and will be + // deleted when the class is destroyed. 
+ ::grpc::GenericStub* singleton_stub = nullptr; + if (address.size() == 1) { + singleton_stub = GetOrCreateStubForAddress(address(0)); + } + auto get_stub = [&address, this, + singleton_stub](int64 ix) -> ::grpc::GenericStub* { + return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix)) + : singleton_stub; + }; + auto get_method_ptr = [&method](int64 ix) -> const string* { + return (method.size() > 1) ? &(method(ix)) : &(method(0)); + }; + auto get_request_ptr = [&request](int64 ix) -> const string* { + return (request.size() > 1) ? &(request(ix)) : &(request(0)); + }; + + if (try_rpc) { + // In this case status_code will never be set in the response, + // so we just set it to OK. + DCHECK(status_code_t != nullptr); + status_code_t->flat<int32>().setConstant( + static_cast<int>(errors::Code::OK)); + } + + CancellationManager* cm = ctx->cancellation_manager(); + CancellationToken cancellation_token = cm->get_cancellation_token(); + + // This object will delete itself when done. + auto* container = + new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc, + std::move(done), cancellation_token); + + auto response = response_t->flat<string>(); + int32* status_code_ptr = nullptr; + string* status_message_ptr = nullptr; + if (try_rpc) { + status_code_ptr = status_code_t->flat<int32>().data(); + status_message_ptr = status_message_t->flat<string>().data(); + } + for (int i = 0; i < num_elements; ++i) { + container->calls()->emplace_back( + container, i, try_rpc, get_request_ptr(i), &response(i), + (try_rpc) ? &status_code_ptr[i] : nullptr, + (try_rpc) ? &status_message_ptr[i] : nullptr); + } + + int i = 0; + for (GrpcCall& call : *(container->calls())) { + // This object will delete itself when done. 
+ new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i), + call.request(), call.response(), + /*done=*/[&call](const Status& s) { call.Done(s); }, + call.call_opts(), fail_fast_, timeout_in_ms_); + ++i; + } + + // Need to register this callback after all the RPCs are in + // flight; otherwise we may try to cancel an RPC *before* it + // launches, which is a no-op, and then fall into a deadlock. + bool is_cancelled = !cm->RegisterCallback( + cancellation_token, [container]() { container->StartCancel(); }); + + if (is_cancelled) { + ctx->SetStatus(errors::Cancelled("Operation has been cancelled.")); + // container's reference counter will take care of calling done(). + container->StartCancel(); + } +} + +::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress( + const string& address) { + mutex_lock lock(mu_); + + auto stub = stubs_.find(address); + if (stub != stubs_.end()) return stub->second.get(); + + ChannelPtr channel = CreateChannelForAddress(address); + auto* created = new ::grpc::GenericStub(channel); + stubs_[address].reset(created); + return created; +} + +GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress( + const string& address) { + ::grpc::ChannelArguments args; + args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max()); + + // Set a standard backoff timeout of 1s instead of the + // (sometimes default) 20s. + args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000); + return ::grpc::CreateCustomChannel( + /*target=*/address, ::grpc::InsecureChannelCredentials(), args); +} + +} // namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h new file mode 100644 index 0000000000..34ec235aaf --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h @@ -0,0 +1,59 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ +#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_state.h" +#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/util/rpc/rpc_factory.h" + +namespace tensorflow { + +class GrpcRPCFactory : public RPCFactory { + public: + explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms); + + // Explicit destructor to control destruction order. + ~GrpcRPCFactory() override; + + void Call(OpKernelContext* ctx, int64 num_elements, const Tensor& address_t, + const Tensor& method_t, const Tensor& request_t, const bool try_rpc, + Tensor* response_t, Tensor* status_code_t, Tensor* status_message_t, + AsyncOpKernel::DoneCallback done) override; + + protected: + typedef std::shared_ptr<::grpc::Channel> ChannelPtr; + virtual ChannelPtr CreateChannelForAddress(const string& address); + + private: + ::grpc::GenericStub* GetOrCreateStubForAddress(const string& address); + + bool fail_fast_; + int64 timeout_in_ms_; + ::grpc::CompletionQueue completion_queue_; + Thread* polling_thread_; // Owned. 
+ + mutex mu_; + typedef std::unique_ptr<::grpc::GenericStub> StubPtr; + std::unordered_map<string, StubPtr> stubs_ GUARDED_BY(mu_); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RPC_FACTORY_H_ diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc new file mode 100644 index 0000000000..b884489378 --- /dev/null +++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory_registration.cc @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h" +#include "tensorflow/core/util/rpc/rpc_factory.h" +#include "tensorflow/core/util/rpc/rpc_factory_registry.h" + +namespace tensorflow { +namespace { + +// Used for adding the grpc factory to the RPC factory registry. 
+struct Value { + static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms) { + return new GrpcRPCFactory(ctx, fail_fast, timeout_in_ms); + } +}; + +REGISTER_RPC_FACTORY("grpc", Value::Function); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 1857d8d655..783de6af88 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -5121,6 +5121,9 @@ filegroup( "summary_interface.*", "summary_kernels.*", "spectrogram_convert_test_data.cc", + "decode_proto_op.cc", + "encode_proto_op.cc", + "rpc_op.cc", # Excluded due to experimental status: "debug_ops.*", "scatter_nd_op*", @@ -6153,6 +6156,50 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "decode_proto_op", + srcs = [ + "decode_proto_op.cc", + ], + deps = [ + "//tensorflow/core:decode_proto_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/util/proto:decode", + "//tensorflow/core/util/proto:descriptors", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "encode_proto_op", + srcs = ["encode_proto_op.cc"], + deps = [ + "//tensorflow/core:encode_proto_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/util/proto:descriptors", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "rpc_op", + srcs = [ + "rpc_op.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:rpc_ops_op_lib", + "//tensorflow/core/util/rpc:call_container", + "//tensorflow/core/util/rpc:rpc_factory", + "//tensorflow/core/util/rpc:rpc_factory_registry", + "//third_party/eigen3", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. 
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc new file mode 100644 index 0000000000..b4e5b776ed --- /dev/null +++ b/tensorflow/core/kernels/decode_proto_op.cc @@ -0,0 +1,1011 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// DecodeProto is a TensorFlow Op which extracts arbitrary fields +// from protos serialized as strings. +// +// See docs in ../ops/decode_proto_op.cc. +// +// This implementation reads the serialized format using a handful of +// calls from the WireFormatLite API used by generated proto code. +// WireFormatLite is marked as an "internal" proto API but is widely +// used in practice and highly unlikely to change. +// This will be much faster than the previous implementation based on +// constructing a temporary dynamic message in memory and using the +// proto reflection api to read it. +// It can be used with any proto whose descriptors are available at +// runtime but should be competitive in speed with approaches that +// compile in the proto definitions. 
+ +#include <memory> +#include <string> +#include <vector> + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/proto/decode.h" +#include "tensorflow/core/util/proto/descriptors.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace { + +using ::tensorflow::MakeUnique; +using ::tensorflow::protobuf::Descriptor; +using ::tensorflow::protobuf::DescriptorPool; +using ::tensorflow::protobuf::DynamicMessageFactory; +using ::tensorflow::protobuf::FieldDescriptor; +using ::tensorflow::protobuf::Message; +using ::tensorflow::protobuf::TextFormat; +using ::tensorflow::protobuf::internal::WireFormatLite; +using ::tensorflow::protobuf::io::CodedInputStream; + +const bool kFailOnDecodeError = true; + +// Returns true if the proto field type can be converted to the +// tensorflow::DataType. 
+bool CheckOutputType(FieldDescriptor::Type field_type, DataType output_type) { + switch (field_type) { + case WireFormatLite::TYPE_DOUBLE: + return output_type == tensorflow::DT_DOUBLE; + case WireFormatLite::TYPE_FLOAT: + return output_type == tensorflow::DT_FLOAT || + output_type == tensorflow::DT_DOUBLE; + case WireFormatLite::TYPE_INT64: + return output_type == tensorflow::DT_INT64; + case WireFormatLite::TYPE_UINT64: + return output_type == tensorflow::DT_INT64; + case WireFormatLite::TYPE_INT32: + return output_type == tensorflow::DT_INT32; + case WireFormatLite::TYPE_FIXED64: + return output_type == tensorflow::DT_INT64; + case WireFormatLite::TYPE_FIXED32: + return output_type == tensorflow::DT_INT32 || + output_type == tensorflow::DT_INT64; + case WireFormatLite::TYPE_BOOL: + return output_type == tensorflow::DT_BOOL; + case WireFormatLite::TYPE_STRING: + return output_type == tensorflow::DT_STRING; + case WireFormatLite::TYPE_GROUP: + return output_type == tensorflow::DT_STRING; + case WireFormatLite::TYPE_MESSAGE: + return output_type == tensorflow::DT_STRING; + case WireFormatLite::TYPE_BYTES: + return output_type == tensorflow::DT_STRING; + case WireFormatLite::TYPE_UINT32: + return output_type == tensorflow::DT_INT32 || + output_type == tensorflow::DT_INT64; + case WireFormatLite::TYPE_ENUM: + return output_type == tensorflow::DT_INT32; + case WireFormatLite::TYPE_SFIXED32: + return output_type == tensorflow::DT_INT32; + case WireFormatLite::TYPE_SFIXED64: + return output_type == tensorflow::DT_INT64; + case WireFormatLite::TYPE_SINT32: + return output_type == tensorflow::DT_INT32; + case WireFormatLite::TYPE_SINT64: + return output_type == tensorflow::DT_INT64; + // default: intentionally omitted in order to enable static checking. + } +} + +// A FieldInfo holds a handful of information from the FieldDescriptor +// and user attributes. 
+struct FieldInfo { + FieldInfo(const FieldDescriptor* field_desc, int user_index) + : output_index(user_index) { + // Without this intermediate data structure, the profile had hotspots + // calling methods of FieldDescriptor. + number = field_desc->number(); + + // The wire format library defines the same constants used in + // descriptor.proto. This static_cast is safe because they + // are guaranteed to stay in sync. + // We need the field type from the FieldDescriptor here + // because the wire format doesn't tell us anything about + // what happens inside a packed repeated field: there is + // enough information in the wire format to skip the + // whole field but not enough to know how to parse what's + // inside. For that we go to the schema. + type = static_cast<WireFormatLite::FieldType>(field_desc->type()); + is_repeated = field_desc->is_repeated(); + } + + // Disable copy and move. + FieldInfo(const FieldInfo&) = delete; + FieldInfo& operator=(const FieldInfo&) = delete; + + // Internally we sort field descriptors by wire number for + // fast lookup. In general this is different from the order + // given by the user. Output_index gives the index into + // the field_names and output_types attributes and into + // the output tensor list. + int output_index = -1; + + // This is a cache of the relevant fields from `FieldDescriptorProto`. + // This was added after noticing that FieldDescriptor->type() was + // using 6% of the cpu profile. + WireFormatLite::FieldType type; + int number; + bool is_repeated; +}; + +// A CountCollector counts sizes of repeated and optional fields in a proto. +// +// Each field is tracked by a single CountCollector instance. The +// instance manages a single count, which is stored as a pointer (it +// is intended to be a reference to the `sizes` output which is being +// filled in). The pointer is passed in at initialization. +// +// Counting is done as a separate pass in order to allocate output tensors +// all at once. 
This allows the TensorFlow runtime to optimize allocation +// for the consumer, while removing the need for copying inside this op. +// After this pass, the DenseCollector class (below) gathers the data: +// It is more complex and provides better motivation for the API here. +class CountCollector { + public: + // Default constructor allows the collector to be a vector element. + CountCollector() = default; + + // The count may be stored inside an Eigen Tensor to eliminate copying. + explicit CountCollector(int32* count) : count_ptr_(count) {} + + // Reads (in this case counts) a single value. + Status ReadValue(CodedInputStream* input, const FieldInfo& field) { + // Only repeated fields can have count > 1. + if (*count_ptr_ == 0 || field.is_repeated) { + (*count_ptr_)++; + } + // We expect a wire type based on the schema field_type, to allow + // a little more checking. + if (!SkipValue(input, field)) { + return errors::DataLoss("ReadValue: Failed skipping field when counting"); + } + return Status::OK(); + } + + // Reads (in this case counts) a length-delimited list of values. + Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, + size_t buf_size) { + if (buf_size == 0) { + return Status::OK(); + } + + const void* tmpbuf; + int unused_max_buf_size; + + input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size); + // This is safe because the underlying storage for the CodedInputStream is + // owned by the input tensor. If it were a Cord or file-backed stream this + // pointer would go stale after the bytes were skipped. + const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf); + + // Important: we skipped the input->{Push,Pop}Limit() calls for speed, + // so the bounds check on buf_size inside Skip() is critical, and + // must be done before scanning the contents. 
+ if (!input->Skip(buf_size)) { + return errors::DataLoss("ReadPackedValues: Skipping packed field failed"); + } + + // Dispatch to the appropriately typed field reader based on the + // schema type. + Status st; + switch (field.type) { + case WireFormatLite::TYPE_DOUBLE: + st = CountPackedFixed<double>(buf, buf_size); + break; + case WireFormatLite::TYPE_FLOAT: + st = CountPackedFixed<float>(buf, buf_size); + break; + case WireFormatLite::TYPE_INT64: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_UINT64: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_INT32: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_FIXED64: + st = CountPackedFixed<uint64>(buf, buf_size); + break; + case WireFormatLite::TYPE_FIXED32: + st = CountPackedFixed<uint32>(buf, buf_size); + break; + case WireFormatLite::TYPE_BOOL: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_STRING: + st = errors::DataLoss("TYPE_STRING encountered as packed"); + break; + case WireFormatLite::TYPE_GROUP: + st = errors::DataLoss("TYPE_GROUP encountered as packed"); + break; + case WireFormatLite::TYPE_MESSAGE: + st = errors::DataLoss("TYPE_MESSAGE encountered as packed"); + break; + case WireFormatLite::TYPE_BYTES: + st = errors::DataLoss("TYPE_BYTES encountered as packed"); + break; + case WireFormatLite::TYPE_UINT32: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_ENUM: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_SFIXED32: + st = CountPackedFixed<int32>(buf, buf_size); + break; + case WireFormatLite::TYPE_SFIXED64: + st = CountPackedFixed<int64>(buf, buf_size); + break; + case WireFormatLite::TYPE_SINT32: + st = CountPackedVarint(buf, buf_size); + break; + case WireFormatLite::TYPE_SINT64: + st = CountPackedVarint(buf, buf_size); + break; + // default: intentionally omitted in order to enable static checking. 
+ } + if (!st.ok()) { + return st; + } + + if (!field.is_repeated && *count_ptr_ > 1) { + *count_ptr_ = 1; + } + return Status::OK(); + } + + private: + // Skips a length-delimited value. + static bool SkipBytes(CodedInputStream* input) { + uint32 length; + if (!input->ReadVarint32(&length)) { + return false; + } + return input->Skip(length); + } + + // Counts the number of packed varints in an array. + // The end of a varint is signaled by a value < 0x80, + // so counting them requires parsing the bytestream. + // It is the caller's responsibility to ensure that len > 0. + Status CountPackedVarint(const uint8* buf, size_t len) { + const uint8* bound = buf + len; + int count; + + // The last byte in a valid encoded varint is guaranteed to have + // the high bit unset. We rely on this property to prevent + // ReadVarint64FromArray from going out of bounds, so validate + // the end of the buf before scanning anything. + if (bound[-1] & 0x80) { + return errors::DataLoss("Corrupt packed varint"); + } + + // Now we can trust ReadVarint64FromArray to stay in bounds. + for (count = 0; buf < bound; ++count) { + uint64 temp; + bool ok; + buf = internal::ReadVarint64FromArray(buf, &ok, &temp); + if (!ok) { + return errors::DataLoss("Corrupt packed varint"); + } + } + + *count_ptr_ += count; + return Status::OK(); + } + + // Counts the number of fixed-size values in a packed field. + // This can be done without actually parsing anything. + template <typename T> + Status CountPackedFixed(const uint8* unused_buf, size_t len) { + int count = len / sizeof(T); + if (count * sizeof(T) != len) { + return errors::DataLoss( + "Illegal data length for packed fixed-size type: ", len); + } + *count_ptr_ += len / sizeof(T); + return Status::OK(); + } + + // Skips a single value in the input stream. + // Dispatches to the appropriately typed field skipper based on the + // schema type tag. + // This is not as permissive as just handling the wire type. 
+ static bool SkipValue(CodedInputStream* input, const FieldInfo& field) { + uint32 tmp32; + protobuf_uint64 tmp64; + switch (field.type) { + case WireFormatLite::TYPE_DOUBLE: + return input->ReadLittleEndian64(&tmp64); + case WireFormatLite::TYPE_FLOAT: + return input->ReadLittleEndian32(&tmp32); + case WireFormatLite::TYPE_INT64: + return input->ReadVarint64(&tmp64); + case WireFormatLite::TYPE_UINT64: + return input->ReadVarint64(&tmp64); + case WireFormatLite::TYPE_INT32: + return input->ReadVarint32(&tmp32); + case WireFormatLite::TYPE_FIXED64: + return input->ReadLittleEndian64(&tmp64); + case WireFormatLite::TYPE_FIXED32: + return input->ReadLittleEndian32(&tmp32); + case WireFormatLite::TYPE_BOOL: + return input->ReadVarint32(&tmp32); + case WireFormatLite::TYPE_STRING: + return SkipBytes(input); + case WireFormatLite::TYPE_GROUP: + return WireFormatLite::SkipField( + input, WireFormatLite::MakeTag( + field.number, WireFormatLite::WIRETYPE_START_GROUP)); + case WireFormatLite::TYPE_MESSAGE: + return SkipBytes(input); + case WireFormatLite::TYPE_BYTES: + return SkipBytes(input); + case WireFormatLite::TYPE_UINT32: + return input->ReadVarint32(&tmp32); + case WireFormatLite::TYPE_ENUM: + return input->ReadVarint32(&tmp32); + case WireFormatLite::TYPE_SFIXED32: + return input->ReadLittleEndian32(&tmp32); + case WireFormatLite::TYPE_SFIXED64: + return input->ReadLittleEndian64(&tmp64); + case WireFormatLite::TYPE_SINT32: + return input->ReadVarint32(&tmp32); + case WireFormatLite::TYPE_SINT64: + return input->ReadVarint64(&tmp64); + // default: intentionally omitted in order to enable static checking. + } + } + + int32* count_ptr_ = nullptr; +}; + +// A DenseCollector accumulates values from a proto into a tensor. +// +// There is an instance of DenseCollector for each field of each +// proto. The DenseCollector deserializes the value from the wire +// directly into the preallocated output Tensor. 
+// +// This class is named DenseCollector because in the future there should +// be a SparseCollector that accumulates field data into sparse tensors if +// the user requests it. +class DenseCollector { + public: + // Default constructor allows the collector to be a vector element. + DenseCollector() = default; + + // A DenseCollector applies to one field of a serialized message. + DenseCollector(uint8* datap, DataType dtype, int max_repeat_count) + : datap_(datap), dtype_(dtype), max_repeat_count_(max_repeat_count) {} + + // Reads a value from the input stream and stores it. + // + // Always inlining gave a ~50% speedup on microbenchmarks at one point. + // TODO(nix): try removing it to see if that still holds. + // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE + Status ReadValue(CodedInputStream* input, const FieldInfo& field) { + // For required and optional fields, we overwrite values[0] with + // the latest one in the wire stream. + // See https://developers.google.com/protocol-buffers/docs/encoding#optional + // Only for repeated fields do we advance the next_repeat_index_ past 1. + // TODO(nix): to handle oneof we must also zero out any previous values + // seen on the wire. + int32 index = 0; + if (field.is_repeated) { + index = next_repeat_index_; + } + next_repeat_index_ = index + 1; + + return internal::ReadValue(input, field.type, field.number, dtype_, index, + datap_); + } + + // Reads and stores a length-delimited list of values. + Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field, + const size_t buf_size) { + const void* buf; + int unused_max_buf_size; + input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size); + // This is safe because the underlying storage for the CodedInputStream is + // owned by the input tensor. If it were a Cord or file-backed stream this + // pointer would go stale after the bytes were skipped. 
+ if (!input->Skip(buf_size)) { + return errors::DataLoss( + "ReadPackedValues: Skipping packed field failed. Field tag: ", + field.number); + } + + // Setting stride=0 causes new values to overwrite old ones for + // non-repeated fields. + const int stride = field.is_repeated ? 1 : 0; + + if (next_repeat_index_ >= max_repeat_count_) { + return errors::DataLoss( + "ReadPackedValues: Tried to write more entries than allowed. " + "Field tag: ", + field.number, ", Max entries allowed: ", max_repeat_count_); + } else { + return internal::ReadPackedFromArray(buf, buf_size, field.type, + field.number, dtype_, stride, + &next_repeat_index_, datap_); + } + } + + // Fills in any missing values in the output array with defaults. + // Dispatches to the appropriately typed field default based on the + // runtime type tag. + Status FillWithDefaults() { + switch (dtype_) { + case DataType::DT_FLOAT: + return FillDefault<float>(); + case DataType::DT_DOUBLE: + return FillDefault<double>(); + case DataType::DT_INT32: + return FillDefault<int32>(); + case DataType::DT_UINT8: + return FillDefault<uint8>(); + case DataType::DT_INT8: + return FillDefault<int8>(); + case DataType::DT_STRING: + return FillDefault<string>(); + case DataType::DT_INT64: + return FillDefault<int64>(); + case DataType::DT_BOOL: + return FillDefault<bool>(); + default: + // There are many tensorflow dtypes not handled here, but they + // should not come up unless type casting is added to the Op. + // Chaining with tf.cast() should do the right thing until then. + return errors::DataLoss( + "Failed filling defaults in unknown tf::DataType"); + } + } + + private: + // Fills empty values in the dense representation with a + // default value. This uses next_repeat_index_ which counts the number + // of parsed values for the field. 
+ template <class T> + Status FillDefault() { + for (int i = next_repeat_index_; i < max_repeat_count_; i++) { + reinterpret_cast<T*>(datap_)[i] = T(); + } + return Status::OK(); + } + + int32 next_repeat_index_ = 0; + + // This is a pointer to data_[message_index_]. + // There is no bounds checking at this level: we computed the max + // repeat size for each field in CountCollector and use the same + // code to traverse it here, so we are guaranteed not to be called + // for more items than we have allocated space. + void* const datap_ = nullptr; + + const DataType dtype_ = DataType::DT_INVALID; + const int max_repeat_count_ = 0; +}; + +class DecodeProtoOp : public OpKernel { + public: + explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) { + string descriptor_source; + OP_REQUIRES_OK(context, + context->GetAttr("descriptor_source", &descriptor_source)); + + // We always get back a desc_pool, but we may not own it. If we own it, + // owned_desc_pool_ will be filled in. + DescriptorPool const* desc_pool; + OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source, + &desc_pool, &owned_desc_pool_)); + + string message_type; + OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type)); + + const Descriptor* message_desc = + desc_pool->FindMessageTypeByName(message_type); + OP_REQUIRES(context, message_desc != nullptr, + errors::InvalidArgument("No descriptor found for message type ", + message_type)); + + std::vector<string> field_names; + OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names)); + std::vector<DataType> output_types; + OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types)); + OP_REQUIRES( + context, field_names.size() == output_types.size(), + errors::InvalidArgument("field_names and output_types attributes must " + "have the same length")); + + // Gather the field descriptors and check that requested output types match. 
+ + int field_index = 0; + std::vector<const FieldDescriptor*> field_descs; + for (const string& name : field_names) { + auto fd = message_desc->FindFieldByName(name); + OP_REQUIRES(context, fd != nullptr, + errors::InvalidArgument("Unknown field: ", name, + " in message type ", message_type)); + OP_REQUIRES(context, + CheckOutputType(fd->type(), output_types[field_index]), + // Many TensorFlow types don't have corresponding proto types + // and the user will get an error if they are requested. It + // would be nice to allow conversions here, but tf.cast + // already exists so we don't duplicate the functionality. + // Known unhandled types: + // DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 + // DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 + errors::InvalidArgument("Unexpected output type for ", + fd->full_name(), ": ", fd->cpp_type(), + " to ", output_types[field_index])); + + field_index++; + field_descs.push_back(fd); + } + + // Internally we want the field_descs sorted by their number on the wire. + // But the output tensors are allocated in the order given by the caller. + // Build a mapping i->j, where field_descs[i] corresponds to outputs[j]. + std::vector<int> output_indices; + output_indices.reserve(field_names.size()); + for (int i = 0; i < field_names.size(); i++) { + output_indices.push_back(i); + } + std::sort(output_indices.begin(), output_indices.end(), + [field_descs](int a, int b) { + return field_descs[a]->number() < field_descs[b]->number(); + }); + + // Now store the fields in sorted order. 
+ for (int i = 0; i < field_names.size(); i++) { + fields_.push_back(MakeUnique<FieldInfo>(field_descs[output_indices[i]], + output_indices[i])); + } + + message_prototype_ = message_factory_.GetPrototype(message_desc); + OP_REQUIRES(context, message_prototype_ != nullptr, + errors::InvalidArgument("Couldn't get prototype message: ", + message_desc->full_name())); + string format; + OP_REQUIRES_OK(context, context->GetAttr("message_format", &format)); + OP_REQUIRES( + context, format == "binary" || format == "text", + errors::InvalidArgument("format must be one of binary or text")); + is_binary_ = format == "binary"; + + // Enable the initial protobuf sanitizer, which is much + // more expensive than the decoder. + // TODO(nix): Remove this once the fast decoder + // has passed security review. + OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& buf_tensor = ctx->input(0); + int message_count = buf_tensor.NumElements(); + OP_REQUIRES(ctx, message_count >= 1, + errors::InvalidArgument( + "Bufs argument must contain at least one value")); + + int field_count = fields_.size(); + + // Save the argument shape for later, then flatten the input + // Tensor since we are working componentwise. We will restore + // the same shape in the returned Tensor. + const TensorShape& shape_prefix = buf_tensor.shape(); + + TensorShape sizes_shape = shape_prefix; + sizes_shape.AddDim(field_count); + Tensor* sizes_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor)); + + // This is used to allocate binary bufs if used. It serves only + // to define memory ownership. + std::vector<string> tmp_binary_bufs(message_count); + + // These are the actual buffers to use, which may be in tmp_binary_bufs + // or may be pointers into the buf_tensor. Either way they are not owned + // here. + std::vector<const string*> bufs; + + if (is_binary_ && !sanitize_) { + // Fast path. 
+ for (int mi = 0; mi < message_count; ++mi) { + const string* buf = &buf_tensor.flat<string>()(mi); + bufs.push_back(buf); + } + } else { + // We will have to allocate a copy, either to convert from text to + // binary or to sanitize a binary proto. + for (int mi = 0; mi < message_count; ++mi) { + ReserializeMessage(ctx, buf_tensor.flat<string>()(mi), + &tmp_binary_bufs[mi]); + if (!ctx->status().ok()) { + return; + } + bufs.push_back(&tmp_binary_bufs[mi]); + } + } + + // Walk through all the strings in the input tensor, counting + // the number of fields in each. + // We can't allocate our actual output Tensor until we know the + // maximum repeat count, so we do a first pass through the serialized + // proto just counting fields. + // We always allocate at least one value so that optional fields + // are populated with default values - this avoids a TF + // conditional when handling the output data. + // The caller can distinguish between real data and defaults + // using the repeat count matrix that is returned by decode_proto. + std::vector<int32> max_sizes(field_count, 1); + for (int mi = 0; mi < message_count; ++mi) { + CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes); + if (!ctx->status().ok()) { + return; + } + } + + // Allocate the output tensors now that we've seen the max size. + // TODO(nix): Use allocate_output_or_forward_input for the largest + // output tensor. This can avoid one large allocation by re-using + // the memory of the input tensor. + std::vector<Tensor*> outputs(field_count); + for (int fi = 0; fi < field_count; ++fi) { + TensorShape flat_shape = {static_cast<int64>(message_count), + max_sizes[fi]}; + TensorShape out_shape = shape_prefix; + out_shape.AddDim(max_sizes[fi]); + + // Surprisingly we don't specify the types from the output_types + // attribute: that is done for us based on the Op declaration: + // REGISTER_OP(...) 
+ // .Attr("output_types: list(type) >= 0") + // .Output("values: output_types") + OP_REQUIRES_OK(ctx, + // ctx->allocate_output(output_indices_[fi] + 1, + ctx->allocate_output(fields_[fi]->output_index + 1, + out_shape, &outputs[fi])); + } + + // Make the second pass through the serialized proto, decoding + // into preallocated tensors. + AccumulateFields(ctx, bufs, outputs); + } + + private: + // Copy a serialized message to binary, e.g. to handle text proto inputs. + void ReserializeMessage(OpKernelContext* ctx, const string& buf, + string* binary_buf) { + // Handle text protos by translating them to binary. + std::unique_ptr<Message> message(message_prototype_->New()); + OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed")); + + if (is_binary_) { + // If we get here we are sanitizing the input protobuf by parsing + // and reserializing it with a trusted (but very slow) library. + OP_REQUIRES(ctx, message->ParseFromString(buf), + errors::DataLoss("Unable to parse binary protobuf")); + } else { + OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()), + errors::DataLoss("Unable to parse text protobuf")); + } + + OP_REQUIRES(ctx, message->SerializeToString(binary_buf), + errors::DataLoss("Unable to reserialize text proto as binary")); + } + + // Count the number of occurrences of each requested field in a message batch. 
+ void CountFields(OpKernelContext* ctx, int message_index, const string& buf, + Tensor* sizes_tensor, std::vector<int32>* max_sizes) { + int field_count = fields_.size(); + + CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()), + buf.size()); + + std::vector<int32> field_sizes(field_count, 0); + std::vector<CountCollector> counters; + counters.reserve(field_count); + for (int i = 0; i < field_count; i++) { + counters.emplace_back(&field_sizes[i]); + } + + Status st = Collect(&input, &counters); + if (st.ok() && !input.ConsumedEntireMessage()) { + st = errors::DataLoss("CountFields: Failed to consume entire buffer"); + } + if (kFailOnDecodeError) { + OP_REQUIRES_OK(ctx, st); // NOLINT + } + if (!st.ok()) { + // This code suppresses the corrupt proto, treating it as empty + // to avoid crashing the process. + LOG(WARNING) << "Proto counting error for message type " << message_type_ + << ": " << st; + + for (int fi = 0; fi < field_count; fi++) { + field_sizes[fi] = 0; + } + // Finished decoding this message. + return; + } + + // Update the size tensor and max repeat size for each field. + auto sizes = sizes_tensor->flat_inner_dims<int32>(); + for (int fi = 0; fi < field_count; fi++) { + int32 size = field_sizes[fi]; + sizes(message_index, fields_[fi]->output_index) = size; + if ((*max_sizes)[fi] < size) { + (*max_sizes)[fi] = size; + } + } + } + + // Parse fields from a serialized message into preallocated tensors. + void AccumulateFields(OpKernelContext* ctx, + const std::vector<const string*>& bufs, + std::vector<Tensor*> outputs) { + struct TensorInfo { + explicit TensorInfo(Tensor* tensor) { + // Note that we can decode only max_repeat_count values before overflow. + // No other bounds checking is done for repeated fields. For + // optional fields there is a check to make sure that only the last + // value on the wire appears in the output tensor. 
+ dtype = tensor->dtype(); + last_dim_size = tensor->dim_size(tensor->dims() - 1); + + if (dtype != DT_STRING) { + const int element_size = DataTypeSize(dtype); + CHECK_GT(element_size, 0); + stride = last_dim_size * element_size; + + const int64 flatshape[1] = {tensor->NumElements() * element_size}; + data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data(); + } else { + // DataTypeSize() returns 0 for string types. + stride = last_dim_size * sizeof(string); + data = reinterpret_cast<uint8*>(tensor->flat<string>().data()); + } + } + + DataType dtype; + int last_dim_size; + int stride; + uint8* data; + }; + + int field_count = fields_.size(); + + std::vector<TensorInfo> tensors; + tensors.reserve(field_count); + for (int fi = 0; fi < field_count; fi++) { + tensors.emplace_back(outputs[fi]); + } + + for (int message_index = 0; message_index < bufs.size(); ++message_index) { + const string& buf = *bufs[message_index]; + + std::vector<DenseCollector> collectors; + collectors.reserve(field_count); + for (const TensorInfo& info : tensors) { + collectors.emplace_back(info.data + message_index * info.stride, + info.dtype, info.last_dim_size); + } + + // Fill in output tensors from the wire. + CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()), + buf.size()); + Status st = Collect(&input, &collectors); + if (st.ok() && !input.ConsumedEntireMessage()) { + st = errors::DataLoss( + "AccumulateFields: Failed to consume entire buffer"); + } + if (kFailOnDecodeError) { + OP_REQUIRES_OK(ctx, st); // NOLINT + } + if (!st.ok()) { + // This code suppresses the corrupt proto, treating it as empty + // to avoid crashing training. + LOG(WARNING) << "Proto counting error for message type " + << message_type_ << ": " << st; + } + + // Fill the remainder of the dense outputs with default values. + for (auto& collector : collectors) { + OP_REQUIRES_OK(ctx, collector.FillWithDefaults()); + } + } + } + + // Look up the FieldDescriptor for a particular field number. 
+ bool LookupField(int field_number, int* field_index) { + // Look up the FieldDescriptor using linear search. + // TODO(nix): this could be sped up with binary search, but we are + // already way off the fastpath at this point. If you see a hotspot + // here, somebody is sending you very inefficient protos. + for (int fi = fields_.size() - 1; fi >= 0; fi--) { + if (field_number == fields_[fi]->number) { + *field_index = fi; + return true; + } + } + return false; + } + + // Traverses a serialized protobuf, dispatching values to the collectors. + template <class CollectorClass> + Status Collect(CodedInputStream* input, + std::vector<CollectorClass>* collectors) { + int last_good_field_index = -1; + bool fields_disordered = false; + int prev_field_number = -1; + int field_number = -1; + int last_good_field_number = -1; + int next_good_field_number = fields_[0]->number; + + // The 'tag' variable should always be treated as tainted. + for (uint32 tag = input->ReadTag(); + tag != 0 && WireFormatLite::GetTagWireType(tag) != + WireFormatLite::WIRETYPE_END_GROUP; + tag = input->ReadTag(), prev_field_number = field_number) { + field_number = WireFormatLite::GetTagFieldNumber(tag); + const FieldInfo* field = nullptr; + + // This takes advantage of the sorted field numbers in most serialized + // protos: it tries the next expected field first rather than doing + // a lookup by field number. + // TODO(nix): haberman@ suggests a hybrid approach with a lookup table + // for small field numbers and a hash table for larger ones. This would + // be a simpler approach that should offer comparable speed in most + // cases. + if (field_number == last_good_field_number) { + field = fields_[last_good_field_index].get(); + } else { + if (field_number < prev_field_number) { + fields_disordered = true; + } + + // If fields are out of order, fall back to slow lookup. 
+ if (fields_disordered) { + int field_index; + if (LookupField(field_number, &field_index)) { + field = fields_[field_index].get(); + last_good_field_index = field_index; + } + } else { + // If we see a field that is past the next field we want, + // it was empty. Look for the one after that. + // Repeat until we run out of fields that we care about. + while (field_number >= next_good_field_number) { + if (field_number == next_good_field_number) { + last_good_field_number = field_number; + field = fields_[last_good_field_index + 1].get(); + } + + // Start looking for the field after the current one. + ++last_good_field_index; + if (last_good_field_index < fields_.size() - 1) { + next_good_field_number = + fields_[last_good_field_index + 1]->number; + } else { + // Saw something past the last field we care about. + // Continue parsing the message just in case there + // are disordered fields later, but any remaining + // ordered fields will have no effect. + next_good_field_number = INT_MAX; + } + } + } + } + + if (!field) { + // Unknown and unrequested fields are skipped. + if (!WireFormatLite::SkipField(input, tag)) { + return errors::DataLoss("Failed skipping unrequested field"); + } + continue; + } + + Status st = CollectField(*field, WireFormatLite::GetTagWireType(tag), + input, &(*collectors)[last_good_field_index]); + if (!st.ok()) { + return st; + } + } + return Status::OK(); + } + + // Collects values for a single field. + template <class CollectorClass> + Status CollectField(const FieldInfo& field, + WireFormatLite::WireType wire_type, + CodedInputStream* input, CollectorClass* collector) { + // The wire format library defines the same constants used in + // descriptor.proto. This static_cast is safe because they + // are guaranteed to stay in sync. 
+ // We need the field type from the FieldDescriptor here + // because the wire format doesn't tell us anything about + // what happens inside a packed repeated field: there is + // enough information in the wire format to skip the + // whole field but not enough to know how to parse what's + // inside. For that we go to the schema. + WireFormatLite::WireType schema_wire_type = + WireFormatLite::WireTypeForFieldType(field.type); + + // Handle packed repeated fields. SkipField would skip the + // whole length-delimited blob without letting us count the + // values, so we have to scan them ourselves. + if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED && + schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { + // Handle packed repeated primitives. + int length; + if (!input->ReadVarintSizeAsInt(&length)) { + return errors::DataLoss("CollectField: Failed reading packed size"); + } + Status st = collector->ReadPackedValues(input, field, length); + if (!st.ok()) { + return st; + } + return Status::OK(); + } + + // Read ordinary values, including strings, bytes, and messages. + if (wire_type != schema_wire_type) { + if (!WireFormatLite::SkipField( + input, WireFormatLite::MakeTag(field.number, wire_type))) { + return errors::DataLoss( + "CollectField: Failed skipping malformed field"); + } + return Status::OK(); + } + return collector->ReadValue(input, field); + } + + string message_type_; + // Note that fields are sorted by increasing field number, + // which is not in general the order given by the user-specified + // field_names and output_types Op attributes. + std::vector<std::unique_ptr<const FieldInfo>> fields_; + + // Owned_desc_pool_ is null when using descriptor_source=local. + std::unique_ptr<DescriptorPool> owned_desc_pool_; + DynamicMessageFactory message_factory_; + const Message* message_prototype_; + + // True if decoding binary format, false if decoding text format. 
+ bool is_binary_; + + // True if the protos should be sanitized before parsing. + // Enables the initial protobuf sanitizer, which is much + // more expensive than the decoder. The flag defaults to true + // but can be set to false for trusted sources. + // TODO(nix): flip the default to false when the fast decoder + // has passed security review. + bool sanitize_; + + TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp); +}; + +REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU), + DecodeProtoOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc new file mode 100644 index 0000000000..3b02ae52a2 --- /dev/null +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -0,0 +1,591 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// EncodeProto is a TensorFlow Op which serializes tensors into +// arbitrary protobufs. +// +// See the docstring in ../ops/encode_proto_op.cc for usage of the op. +// +// This implementation writes the serialized format using a handful of +// calls from the WireFormatLite API. 
+ +#include <memory> +#include <vector> + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/proto/descriptors.h" + +namespace tensorflow { +namespace { + +using ::tensorflow::protobuf::Descriptor; +using ::tensorflow::protobuf::DescriptorPool; +using ::tensorflow::protobuf::FieldDescriptor; +using ::tensorflow::protobuf::internal::WireFormatLite; +using ::tensorflow::protobuf::io::CodedOutputStream; +using ::tensorflow::protobuf::io::StringOutputStream; + +// Computes the total serialized size for a packed repeated field. +// For fixed-size types this can just multiply, but for variable-sized +// types it has to iterate through the values in the tensor. +template <WireFormatLite::FieldType FieldType, typename TensorT> +size_t TotalPackedSize(const Tensor& input, int message_index, int size); + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kDoubleSize; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kFloatSize; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kFloatSize; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int64>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::Int64Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t 
TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int64>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::UInt64Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int32>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::Int32Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kFixed64Size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kFixed32Size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kFixed32Size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input, + int message_index, + int size) { + return size * WireFormatLite::kBoolSize; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int64>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::UInt32Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int32>(); + for (int64 i = 0; 
i < size; i++) { + data_size += WireFormatLite::UInt32Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int32>(); + for (int64 i = 0; i < size; i++) { + data_size += + WireFormatLite::EnumSize(input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>( + const Tensor& input, int message_index, int size) { + return size * WireFormatLite::kSFixed32Size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>( + const Tensor& input, int message_index, int size) { + return size * WireFormatLite::kSFixed64Size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int32>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::SInt32Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> +size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int64>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::SInt64Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +// Writes a possibly repeated primitive field. +// TensorFlow does not have unsigned types, so we decode them to signed and +// encode them back to unsigned. 
+template <typename TensorT, typename ProtoT, + WireFormatLite::FieldType FieldType, + void Writer(ProtoT, CodedOutputStream*)> +void WriteField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { + auto wire_type = WireFormatLite::WireTypeForFieldType( + WireFormatLite::FieldType(field_desc.type())); + + auto input_t = input.flat_inner_dims<TensorT>(); + if (field_desc.options().packed()) { + // Write the tag for the packed field. + WireFormatLite::WriteTag(field_desc.number(), + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output); + + // Write the total packed length. + size_t data_size = + TotalPackedSize<FieldType, TensorT>(input, message_index, size); + output->WriteVarint32(data_size); + + // Write individual values. + for (int64 i = 0; i < size; i++) { + // Note implicit cast from signed to unsigned. + const ProtoT& value = input_t(static_cast<int64>(message_index), i); + Writer(value, output); + } + } else { + for (int64 i = 0; i < size; i++) { + WireFormatLite::WriteTag(field_desc.number(), wire_type, output); + + // Note implicit cast from signed to unsigned. + const ProtoT& value = input_t(static_cast<int64>(message_index), i); + Writer(value, output); + } + } +} + +// Writes a possibly repeated string, bytes, or message field. +template <typename T, void Writer(int, const T&, CodedOutputStream*)> +void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { + auto input_t = input.flat_inner_dims<T>(); + for (int64 i = 0; i < size; i++) { + const T& value = input_t(static_cast<int64>(message_index), i); + // TODO(nix): there doesn't seem to be an inlined version of + // WireFormatLite::WriteString or its relatives, which might allow a + // small speedup. + Writer(field_desc.number(), value, output); + } +} + +// Writes a group field. 
+// Groups are treated like submessages, but tag-delimited +// instead of length-delimited. WireFormatLite handles this +// differently so we code it ourselves. +void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { + auto input_t = input.flat_inner_dims<string>(); + for (int64 i = 0; i < size; i++) { + const string& value = input_t(static_cast<int64>(message_index), i); + WireFormatLite::WriteTag(field_desc.number(), + WireFormatLite::WIRETYPE_START_GROUP, output); + // Note the use of WriteRaw instead of WriteString to skip the length. + output->WriteRaw(value.data(), value.size()); + WireFormatLite::WriteTag(field_desc.number(), + WireFormatLite::WIRETYPE_END_GROUP, output); + } +} + +// Writes a (possibly repeated) field into an output stream. +// It is the caller's responsibility to ensure that the type of +// the input tensor is compatible with the type of the proto +// field descriptor, and that (message_index, size-1) is within +// bounds. 
+void WriteField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { + DataType tf_type = input.dtype(); + + switch (field_desc.type()) { + case WireFormatLite::TYPE_DOUBLE: + return WriteField<double, double, WireFormatLite::TYPE_DOUBLE, + WireFormatLite::WriteDoubleNoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_FLOAT: + switch (tf_type) { + case DataType::DT_FLOAT: + return WriteField<float, float, WireFormatLite::TYPE_FLOAT, + WireFormatLite::WriteFloatNoTag>( + field_desc, input, message_index, size, output); + case DataType::DT_DOUBLE: + return WriteField<double, float, WireFormatLite::TYPE_FLOAT, + WireFormatLite::WriteFloatNoTag>( + field_desc, input, message_index, size, output); + default: + return; + } + case WireFormatLite::TYPE_INT64: + return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64, + WireFormatLite::WriteInt64NoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_UINT64: + return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64, + WireFormatLite::WriteUInt64NoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_INT32: + return WriteField<int32, int32, WireFormatLite::TYPE_INT32, + WireFormatLite::WriteInt32NoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_FIXED64: + return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64, + WireFormatLite::WriteFixed64NoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_FIXED32: + switch (tf_type) { + case DataType::DT_INT64: + return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32, + WireFormatLite::WriteFixed32NoTag>( + field_desc, input, message_index, size, output); + case DataType::DT_INT32: + return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32, + WireFormatLite::WriteFixed32NoTag>( + field_desc, input, 
message_index, size, output); + default: + return; + } + case WireFormatLite::TYPE_BOOL: + return WriteField<bool, bool, WireFormatLite::TYPE_BOOL, + WireFormatLite::WriteBoolNoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_STRING: + return WriteVarLenField<string, WireFormatLite::WriteString>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_GROUP: + return WriteGroup(field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_MESSAGE: + return WriteVarLenField<string, WireFormatLite::WriteBytes>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_BYTES: + return WriteVarLenField<string, WireFormatLite::WriteBytes>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_UINT32: + switch (tf_type) { + case DataType::DT_INT64: + return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32, + WireFormatLite::WriteUInt32NoTag>( + field_desc, input, message_index, size, output); + case DataType::DT_INT32: + return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32, + WireFormatLite::WriteUInt32NoTag>( + field_desc, input, message_index, size, output); + default: + return; + } + case WireFormatLite::TYPE_ENUM: + return WriteField<int32, int32, WireFormatLite::TYPE_ENUM, + WireFormatLite::WriteEnumNoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_SFIXED32: + return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32, + WireFormatLite::WriteSFixed32NoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_SFIXED64: + return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64, + WireFormatLite::WriteSFixed64NoTag>( + field_desc, input, message_index, size, output); + case WireFormatLite::TYPE_SINT32: + return WriteField<int32, int32, WireFormatLite::TYPE_SINT32, + WireFormatLite::WriteSInt32NoTag>( + field_desc, input, message_index, size, output); + case 
WireFormatLite::TYPE_SINT64: + return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64, + WireFormatLite::WriteSInt64NoTag>( + field_desc, input, message_index, size, output); + // default: intentionally omitted in order to enable static checking. + } +} + +// Checks that a Protobuf field is compatible with a TensorFlow datatype. +// This is separated from WriteField to lift it out of the inner loop. +bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) { + switch (field_desc.type()) { + case WireFormatLite::TYPE_DOUBLE: + return tf_type == DataType::DT_DOUBLE; + case WireFormatLite::TYPE_FLOAT: + return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE; + case WireFormatLite::TYPE_INT64: + case WireFormatLite::TYPE_SFIXED64: + case WireFormatLite::TYPE_SINT64: + return tf_type == DataType::DT_INT64; + case WireFormatLite::TYPE_UINT64: + return tf_type == DataType::DT_INT64; + case WireFormatLite::TYPE_INT32: + case WireFormatLite::TYPE_ENUM: + case WireFormatLite::TYPE_SFIXED32: + case WireFormatLite::TYPE_SINT32: + return tf_type == DataType::DT_INT32; + case WireFormatLite::TYPE_FIXED64: + return tf_type == DataType::DT_INT64; + case WireFormatLite::TYPE_FIXED32: + case WireFormatLite::TYPE_UINT32: + return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32; + case WireFormatLite::TYPE_BOOL: + return tf_type == DataType::DT_BOOL; + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + case WireFormatLite::TYPE_BYTES: + return tf_type == DataType::DT_STRING; + // default: intentionally omitted in order to enable static checking. 
+ } + return false; +} + +class EncodeProtoOp : public OpKernel { + public: + explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) { + string descriptor_source; + OP_REQUIRES_OK(context, + context->GetAttr("descriptor_source", &descriptor_source)); + // We always get back a desc_pool, but we may not own it. If we own it, + // owned_desc_pool_ will be filled in. + DescriptorPool const* desc_pool; + OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source, + &desc_pool, &owned_desc_pool_)); + + string message_type; + OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type)); + const Descriptor* message_desc = + desc_pool->FindMessageTypeByName(message_type); + OP_REQUIRES(context, message_desc != nullptr, + errors::InvalidArgument("No descriptor found for message type ", + message_type)); + + OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_)); + + // Gather the field descriptors for the given field_names. + field_descs_.resize(field_names_.size()); + for (int i = 0; i < field_names_.size(); i++) { + const string& name = field_names_[i]; + auto field_desc = message_desc->FindFieldByName(name); + OP_REQUIRES(context, field_desc != nullptr, + errors::InvalidArgument("Unknown field: ", name, + " in message type ", message_type)); + + field_descs_[i] = field_desc; + } + + // Build a list of indices into field_descs sorted by increasing + // field_number. This will be used to output fields in sorted order, + // which is strongly encouraged when serializing protobufs. + sorted_field_index_.resize(field_names_.size()); + // Start with the fields sorted by current index. + for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i; + // Then sort the field indices by their proto field number. 
+ std::sort(sorted_field_index_.begin(), sorted_field_index_.end(), + [this](int a, int b) -> bool { + return field_descs_[a]->number() < field_descs_[b]->number(); + }); + } + + void Compute(OpKernelContext* cx) override { + const Tensor* sizes_tensor; + OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor)); + + OpInputList values; + OP_REQUIRES_OK(cx, cx->input_list("values", &values)); + + OP_REQUIRES(cx, field_descs_.size() == values.size(), + errors::InvalidArgument( + "Length of inputs list must match field_names")); + + // Check the arguments for consistency. + TensorShape common_prefix; + int message_count; + for (int i = 0; i < field_descs_.size(); i++) { + const Tensor& v = values[i]; + + // The type of each value tensor must match the corresponding field. + OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()), + errors::InvalidArgument( + "Incompatible type for field " + field_names_[i] + + ". Saw dtype: ", + DataTypeString(v.dtype()), + " but field type is: ", field_descs_[i]->type_name())); + + // All value tensors must have the same shape prefix (i.e. batch size). + TensorShape shape_prefix = v.shape(); + shape_prefix.RemoveDim(shape_prefix.dims() - 1); + + // Do some initialization on the first input value. The rest will + // have to match this one. + if (i == 0) { + OP_REQUIRES(cx, v.dims() >= 1, + errors::InvalidArgument( + "Expected value to be at least a vector, saw shape: ", + v.shape().DebugString())); + common_prefix = shape_prefix; + message_count = common_prefix.num_elements(); + } else { + OP_REQUIRES(cx, shape_prefix == common_prefix, + errors::InvalidArgument( + "Values must match up to the last dimension")); + } + } + + TensorShape expected_sizes_shape = common_prefix; + expected_sizes_shape.AddDim(field_descs_.size()); + + OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape, + errors::InvalidArgument( + "sizes should be batch_size + [len(field_names)]. 
Saw: ", + sizes_tensor->shape().DebugString(), + " but expected: ", expected_sizes_shape.DebugString())); + + auto sizes = sizes_tensor->flat_inner_dims<int32>(); + + for (int i = 0; i < field_descs_.size(); ++i) { + const Tensor& v = values[i]; + int max_size = v.dim_size(v.dims() - 1); + + // The last dimension of a value tensor must be greater than the + // corresponding + // size in the sizes tensor. + for (int message_index = 0; message_index < message_count; + message_index++) { + OP_REQUIRES( + cx, sizes(message_index, i) <= max_size, + errors::InvalidArgument( + "Size to write must not be larger than value tensor; but saw: ", + sizes(message_index, i), " > ", max_size, " at message ", + message_index, " field ", i)); + } + } + + // This pointer is owned by the context. + Tensor* output_tensor; + OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor)); + + auto bufs = output_tensor->flat<string>(); + for (int message_index = 0; message_index < message_count; + message_index++) { + // TODO(nix): possibly optimize allocation here by calling + // bufs(message_index).reserve(DEFAULT_BUF_SIZE); + StringOutputStream output_string(&bufs(message_index)); + CodedOutputStream out(&output_string); + // Write fields in ascending field_number order. + for (int i : sorted_field_index_) { + auto& field_desc = *field_descs_[i]; + const Tensor& v = values[i]; + int size = sizes(message_index, i); + if (!size) continue; + WriteField(field_desc, v, message_index, size, &out); + } + } + } + + private: + std::vector<string> field_names_; + std::vector<const FieldDescriptor*> field_descs_; + + // Owned_desc_pool_ is null when using descriptor_source=local. + std::unique_ptr<DescriptorPool> owned_desc_pool_; + + // Contains indices into field_names_, sorted by field number since + // that's the order of writing. 
+ std::vector<int> sorted_field_index_; + + TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp); +}; + +REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/rpc_op.cc b/tensorflow/core/kernels/rpc_op.cc new file mode 100644 index 0000000000..2447ef5040 --- /dev/null +++ b/tensorflow/core/kernels/rpc_op.cc @@ -0,0 +1,129 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// RpcOp is a TensorFlow op that sends and receives arbitrary messages. +// +// See docs in ../ops/rpc_op.cc. 
+ +#include <memory> +#include <string> +#include <vector> + +#include "third_party/eigen3/Eigen/Core" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/rpc/call_container.h" +#include "tensorflow/core/util/rpc/rpc_factory.h" +#include "tensorflow/core/util/rpc/rpc_factory_registry.h" + +namespace tensorflow { + +class RpcOp : public AsyncOpKernel { + public: + explicit RpcOp(OpKernelConstruction* context) : AsyncOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("protocol", &protocol_)); + OP_REQUIRES(context, !protocol_.empty(), + errors::InvalidArgument("protocol must be non-empty.")); + bool fail_fast; + OP_REQUIRES_OK(context, context->GetAttr("fail_fast", &fail_fast)); + int64 timeout_in_ms; + OP_REQUIRES_OK(context, context->GetAttr("timeout_in_ms", &timeout_in_ms)); + + RPCFactoryRegistry::RPCFactoryFn* rpc_factory_fn = + RPCFactoryRegistry::Global()->Get(protocol_); + OP_REQUIRES(context, rpc_factory_fn != nullptr, + errors::InvalidArgument("The protocol ", protocol_, + " was not recognized.")); + + rpc_factory_.reset((*rpc_factory_fn)(context, fail_fast, timeout_in_ms)); + } + + ~RpcOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& address_t = ctx->input(0); + const Tensor& method_t = ctx->input(1); + const Tensor& request_t = ctx->input(2); + + OP_REQUIRES_ASYNC( + ctx, address_t.dims() == 0 || address_t.dims() == 1, + errors::InvalidArgument("address must be a scalar or vector."), done); + OP_REQUIRES_ASYNC( + ctx, method_t.dims() == 0 || method_t.dims() == 1, + errors::InvalidArgument("method must be a scalar or 
vector."), done); + OP_REQUIRES_ASYNC( + ctx, request_t.dims() == 0 || request_t.dims() == 1, + errors::InvalidArgument("request must be a scalar or vector."), done); + + TensorShape output_shape({}); + for (const Tensor& t : {address_t, method_t, request_t}) { + if (t.dims() == 1) { + OP_REQUIRES_ASYNC( + ctx, + output_shape.dims() == 0 || + output_shape.dim_size(0) == t.dim_size(0), + errors::InvalidArgument( + "Input vector shapes don't match: ", output_shape.DebugString(), + " vs. ", t.shape().DebugString()), + done); + output_shape = t.shape(); + } + } + + Tensor* response_t; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->allocate_output(0, output_shape, &response_t), done); + + const bool try_rpc = (ctx->num_outputs() > 1); + + Tensor* status_code_t = nullptr; + Tensor* status_message_t = nullptr; + if (try_rpc) { + OP_REQUIRES_OK_ASYNC( + ctx, ctx->allocate_output(1, output_shape, &status_code_t), done); + OP_REQUIRES_OK_ASYNC( + ctx, ctx->allocate_output(2, output_shape, &status_message_t), done); + } + + if (request_t.NumElements() == 0) { + // Special case, we finished early! + done(); + return; + } + + int64 num_elements = output_shape.num_elements(); + + rpc_factory_->Call(ctx, num_elements, address_t, method_t, request_t, + try_rpc, response_t, status_code_t, status_message_t, + std::move(done)); + } + + private: + string protocol_; + std::unique_ptr<RPCFactory> rpc_factory_; + + TF_DISALLOW_COPY_AND_ASSIGN(RpcOp); +}; + +REGISTER_KERNEL_BUILDER(Name("Rpc").Device(DEVICE_CPU), RpcOp); +REGISTER_KERNEL_BUILDER(Name("TryRpc").Device(DEVICE_CPU), RpcOp); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/decode_proto_ops.cc b/tensorflow/core/ops/decode_proto_ops.cc new file mode 100644 index 0000000000..3f6fb2f582 --- /dev/null +++ b/tensorflow/core/ops/decode_proto_ops.cc @@ -0,0 +1,67 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; + +REGISTER_OP("DecodeProtoV2") + .Input("bytes: string") + .Attr("message_type: string") + .Attr("field_names: list(string)") + .Attr("output_types: list(type) >= 0") + .Attr("descriptor_source: string = 'local://'") + .Attr("message_format: string = 'binary'") + .Attr("sanitize: bool = false") + .Output("sizes: int32") + .Output("values: output_types") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input = c->input(0); + + std::vector<tensorflow::DataType> output_types; + TF_RETURN_IF_ERROR(c->GetAttr("output_types", &output_types)); + + ShapeHandle sizes; + TF_RETURN_IF_ERROR( + c->Concatenate(input, c->Vector(output_types.size()), &sizes)); + c->set_output(0, sizes); + + // TODO(nix): to do the best possible job of shape inference, we + // should examine the proto descriptors here in order to set shape + // indices to 1 instead of unknown for optional or required fields. + // Any general-purpose code will have to handle the unknown case, + // but there might be XLA code that could be sped up with the additional + // knowledge. 
+ for (int i = 0; i < output_types.size(); ++i) { + ShapeHandle values; + TF_RETURN_IF_ERROR( + c->Concatenate(input, c->Vector(c->UnknownDim()), &values)); + c->set_output(i + 1, values); + } + + return Status::OK(); + }); + +// TODO(nix): Consider adding an additional input argument that truncates +// repeated fields to a maximum count. For now this could be done by passing +// the output through tf.slice. + +// TODO(nix): define missing value behavior. + +} // namespace tensorflow diff --git a/tensorflow/core/ops/encode_proto_ops.cc b/tensorflow/core/ops/encode_proto_ops.cc new file mode 100644 index 0000000000..f5ec3056e3 --- /dev/null +++ b/tensorflow/core/ops/encode_proto_ops.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; + +REGISTER_OP("EncodeProto") + .Input("sizes: int32") + .Input("values: Tinput_types") + .Attr("field_names: list(string)") + .Attr("message_type: string") + .Attr("descriptor_source: string = 'local://'") + .Attr("Tinput_types: list(type)") + .Output("bytes: string") + .SetShapeFn([](InferenceContext* c) { + int first_field_index = 1; + int num_fields = c->num_inputs() - 1; + + ShapeHandle output; + for (int i = num_fields - 1; i >= 0; --i) { + ShapeHandle input = c->input(first_field_index + i); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(input, 2, &input)); + ShapeHandle inner; + TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &inner)); + TF_RETURN_IF_ERROR(c->Merge(inner, output, &output)); + } + + c->set_output(0, output); + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/rpc_ops.cc b/tensorflow/core/ops/rpc_ops.cc new file mode 100644 index 0000000000..72fda5e6eb --- /dev/null +++ b/tensorflow/core/ops/rpc_ops.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using tensorflow::shape_inference::DimensionHandle; +using tensorflow::shape_inference::InferenceContext; +using tensorflow::shape_inference::ShapeHandle; + +Status RpcShapeOp(InferenceContext* c, bool try_rpc) { + ShapeHandle address; + ShapeHandle method; + ShapeHandle request; + ShapeHandle output; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &address)); + if (c->Rank(address) == 1) { + TF_RETURN_IF_ERROR(c->Merge(output, address, &output)); + } + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 1, &method)); + if (c->Rank(method) == 1) { + TF_RETURN_IF_ERROR(c->Merge(output, method, &output)); + } + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &request)); + if (c->Rank(request) == 1) { + TF_RETURN_IF_ERROR(c->Merge(output, request, &output)); + } + if (!c->RankKnown(output)) { + output = request; + } + c->set_output(0, output); // response + if (try_rpc) { + c->set_output(1, output); // status_code + c->set_output(2, output); // status_message + } + return Status::OK(); +} + +REGISTER_OP("Rpc") + .Input("address: string") + .Input("method: string") + .Input("request: string") + .Attr("protocol: string = ''") + .Attr("fail_fast: bool = true") + .Attr("timeout_in_ms: int = 0") + .Output("response: string") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + return RpcShapeOp(c, /*try_rpc=*/false); + }); + +REGISTER_OP("TryRpc") + .Input("address: string") + .Input("method: string") + .Input("request: string") + .Attr("protocol: string = ''") + .Attr("fail_fast: bool = true") + .Attr("timeout_in_ms: int = 0") + .Output("response: string") + .Output("status_code: int32") + .Output("status_message: string") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + return RpcShapeOp(c, /*try_rpc=*/true); + }); + +} // namespace 
tensorflow diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD new file mode 100644 index 0000000000..ade14ed162 --- /dev/null +++ b/tensorflow/core/util/proto/BUILD @@ -0,0 +1,62 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "decode", + hdrs = ["decode.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "descriptors", + srcs = ["descriptors.cc"], + hdrs = ["descriptors.h"], + deps = [ + ":descriptor_pool_registry", + ":local_descriptor_pool_registration", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "descriptor_pool_registry", + srcs = ["descriptor_pool_registry.cc"], + hdrs = ["descriptor_pool_registry.h"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "descriptor_pool_registry_test", + srcs = ["descriptor_pool_registry_test.cc"], + deps = [ + ":descriptor_pool_registry", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# Depending on this target adds support for using the special +# value "local://" (or "") for descriptor source, in which case +# descriptors linked into the code will be searched. +cc_library( + name = "local_descriptor_pool_registration", + srcs = ["local_descriptor_pool_registration.cc"], + deps = [ + ":descriptor_pool_registry", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h new file mode 100644 index 0000000000..74634a356a --- /dev/null +++ b/tensorflow/core/util/proto/decode.h @@ -0,0 +1,592 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Inline functions for parsing the protocol buffers wire format. +// +// These functions have been optimized at the expense of safety. +// They are broken out into a separate file for readability but are +// not intended for use by clients other than the decode_proto op. +// +// The calling code in the decode_proto op does some fairly +// complicated things to ensure that this code is called +// safely. Changes to this code should be thoroughly fuzz tested. + +#ifndef TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_ +#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace internal { + +using tensorflow::protobuf::internal::WireFormatLite; +using tensorflow::protobuf::io::CodedInputStream; +using tensorflow::protobuf::io::CodedOutputStream; +using tensorflow::protobuf::io::StringOutputStream; + +// Converts an uint64 to an int64 without loss of information. +// Unsigned values greater than INT64_MAX are represented as +// negative numbers by wrapping (same as twos-complement bit equivalence). +inline int64 WrapUnsignedAsSigned64(uint64 unsigned_value) { + // For a detailed explanation of why this works to wrap unsigned ints, see + // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior + // Both if tests should be optimized out. 
+ if (unsigned_value <= INT64_MAX) { + return static_cast<int64>(unsigned_value); + } + // The C++ spec allows an architecture where this test is required. + if (unsigned_value >= INT64_MIN) { + return static_cast<int64>(unsigned_value - INT64_MIN) + INT64_MIN; + } + return 0; // This should never occur. +} + +// Converts an uint32 to an int32 without loss of information. +// Unsigned values greater than INT_MAX are represented as +// negative numbers by wrapping (same as twos-complement bit equivalence). +inline int32 WrapUnsignedAsSigned32(uint32 unsigned_value) { + // For a detailed explanation of why this works to wrap unsigned ints, see + // http://stackoverflow.com/questions/13150449/efficient-unsigned-to-signed-cast-avoiding-implementation-defined-behavior + // Both if tests should be optimized out. + if (unsigned_value <= INT_MAX) { + return static_cast<int32>(unsigned_value); + } + // The C++ spec allows an architecture where this test is required. + if (unsigned_value >= INT_MIN) { + return static_cast<int32>(unsigned_value - INT_MIN) + INT_MIN; + } + return 0; // This should never occur. +} + +// Reads a single varint32 from a byte array. +// It is the caller's responsibility to ensure that there is enough +// space in the buffer. +// The ok value will be set to false if the buffer does not contain +// a valid varint. +inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok, + uint64* value); + +// Reads a single varint32 from a byte array. +// It is the caller's responsibility to ensure that there is enough +// space in the buffer. +// The ok value will be set to false if the buffer does not contain +// a valid varint. +// This is slightly less efficient than the private version in +// coded_stream.cc but we duplicate less code by calling +// the 64 bit version instead of copying the code. 
+inline const uint8* ReadVarint32FromArray(const uint8* buffer, bool* ok, + uint32* value) { + uint64 tmp; + const uint8* buf = ReadVarint64FromArray(buffer, ok, &tmp); + *value = tmp & 0xffffffff; + return buf; +} + +// Reads a single proto field value from a byte array into an array. +// The array is part of a Tensor that was allocated by the caller +// with type TensorType, while DeclaredType is the proto field type. +template <class TensorType, enum WireFormatLite::FieldType DeclaredType> +const uint8* ReadFromArray(const uint8* buf, TensorType* value); + +template <> +inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>( + const uint8* buf, int32* value) { + uint32 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint32FromArray(buf, &unused_ok, &temp); + *value = static_cast<int32>(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>( + const uint8* buf, int64* value) { + uint64 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint64FromArray(buf, &unused_ok, &temp); + *value = WrapUnsignedAsSigned64(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>( + const uint8* buf, int64* value) { + uint32 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint32FromArray(buf, &unused_ok, &temp); + *value = temp; + return buf; +} + +template <> +inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>( + const uint8* buf, int32* value) { + uint32 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. 
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp); + *value = WrapUnsignedAsSigned32(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>( + const uint8* buf, int64* value) { + uint64 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint64FromArray(buf, &unused_ok, &temp); + *value = static_cast<int64>(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SINT32>( + const uint8* buf, int32* value) { + uint32 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint32FromArray(buf, &unused_ok, &temp); + *value = WireFormatLite::ZigZagDecode32(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>( + const uint8* buf, int64* value) { + uint64 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint64FromArray(buf, &unused_ok, &temp); + *value = WireFormatLite::ZigZagDecode64(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>( + const uint8* buf, int64* value) { + uint32 temp; + buf = WireFormatLite::ReadPrimitiveFromArray<uint32, + WireFormatLite::TYPE_FIXED32>( + buf, &temp); + *value = temp; + return buf; +} + +template <> +inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>( + const uint8* buf, int32* value) { + uint32 temp; + buf = WireFormatLite::ReadPrimitiveFromArray<uint32, + WireFormatLite::TYPE_FIXED32>( + buf, &temp); + *value = WrapUnsignedAsSigned32(temp); + return buf; +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>( + const uint8* buf, int64* value) { + protobuf_uint64 temp; + buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64, + WireFormatLite::TYPE_FIXED64>( + buf, &temp); + *value = WrapUnsignedAsSigned64(temp); + 
return buf; +} + +template <> +inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>( + const uint8* buf, int32* value) { + return WireFormatLite::ReadPrimitiveFromArray<int32, + WireFormatLite::TYPE_SFIXED32>( + buf, value); +} + +template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED64>( + const uint8* buf, int64* value) { + protobuf_int64 temp; + buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_int64, + WireFormatLite::TYPE_SFIXED64>( + buf, &temp); + *value = temp; + return buf; +} + +template <> +inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>( + const uint8* buf, float* value) { + return WireFormatLite::ReadPrimitiveFromArray<float, + WireFormatLite::TYPE_FLOAT>( + buf, value); +} + +template <> +inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>( + const uint8* buf, double* value) { + return WireFormatLite::ReadPrimitiveFromArray<double, + WireFormatLite::TYPE_DOUBLE>( + buf, value); +} + +template <> +inline const uint8* ReadFromArray<bool, WireFormatLite::TYPE_BOOL>( + const uint8* buf, bool* value) { + uint64 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint64FromArray(buf, &unused_ok, &temp); + *value = temp != 0; + return buf; +} + +template <> +inline const uint8* ReadFromArray<int, WireFormatLite::TYPE_ENUM>( + const uint8* buf, int* value) { + uint32 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint32FromArray(buf, &unused_ok, &temp); + *value = static_cast<int>(temp); + return buf; +} + +// Reads packed values from an array. +// Stride is set to 1 for repeated fields, and 0 for non-repeated fields +// (where any value overwrites previous values). 
+template <class TensorType, enum WireFormatLite::FieldType DeclaredType> +inline int ReadPackedPrimitives(const void* bufp, const size_t len, + const int index, const int stride, + void* datap) { + const uint8* buf = reinterpret_cast<const uint8*>(bufp); + const uint8* bound = buf + len; + TensorType* data = reinterpret_cast<TensorType*>(datap) + index; + int count; + + // This could overrun the bound by stride-1. This is defended + // against in the caller, where it ensures that the input buffer + // contains complete values. + for (count = 0; buf < bound; count += stride) { + buf = ReadFromArray<TensorType, DeclaredType>(buf, data + count); + } + return count; +} + +// Reads a primitive value field from a serialized proto. +// The value is parsed from the serialized format, then static_cast +// to the desired type for TensorFlow and stored. +template <class ValueType, class TensorType, + enum WireFormatLite::FieldType DeclaredType> +inline Status ReadPrimitive(CodedInputStream* input, int index, void* data) { + ValueType v; + if (!WireFormatLite::ReadPrimitive<ValueType, DeclaredType>(input, &v)) { + return errors::DataLoss("Failed reading primitive"); + } + + reinterpret_cast<TensorType*>(data)[index] = v; + return Status::OK(); +} + +// Reads a string, submessage, or other variable-length field from a +// serialized proto. +// May read all or part of a repeated field. +inline Status ReadBytes(CodedInputStream* input, int index, void* datap) { + string* data = reinterpret_cast<string*>(datap) + index; + if (!WireFormatLite::ReadBytes(input, data)) { + return errors::DataLoss("Failed reading bytes"); + } + return Status::OK(); +} + +// Reads a tag-delimited field (TYPE_GROUP) from a serialized proto, +// as a bytestring. +inline Status ReadGroupBytes(CodedInputStream* input, int field_number, + int index, void* datap) { + // WireFormatLite::SkipField has an option to emit the + // skipped bytes to an output stream. 
We could do better by implementing our + // own scanner but this is simpler for now. + // TODO(nix): there is a faster way to grab TYPE_GROUP bytes by relying + // on input->IsFlat() == true and using input->GetDirectBufferPointer() + // with input->CurrentPosition(). + string* data = reinterpret_cast<string*>(datap) + index; + StringOutputStream string_stream(data); + CodedOutputStream out(&string_stream); + if (!WireFormatLite::SkipField( + input, + WireFormatLite::MakeTag(field_number, + WireFormatLite::WIRETYPE_START_GROUP), + &out)) { + return errors::DataLoss("Failed reading group"); + } + return Status::OK(); +} + +// Reads a single field value from a CodedInputStream into a tensor. +inline Status ReadValue(CodedInputStream* input, + WireFormatLite::FieldType field_type, int field_number, + DataType dtype, int index, void* datap) { + // Dispatch to the appropriately typed field reader based on the + // schema type. + switch (field_type) { + case WireFormatLite::TYPE_DOUBLE: + return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>( + input, index, datap); + case WireFormatLite::TYPE_FLOAT: + if (dtype == DataType::DT_FLOAT) { + return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>( + input, index, datap); + } + if (dtype == DataType::DT_DOUBLE) { + return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>( + input, index, datap); + } + // Any case that reaches this point should have triggered an error + // already. 
+ return errors::DataLoss("Failed reading TYPE_FLOAT"); + case WireFormatLite::TYPE_INT64: + return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>( + input, index, datap); + case WireFormatLite::TYPE_UINT64: + return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>( + input, index, datap); + case WireFormatLite::TYPE_INT32: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>( + input, index, datap); + case WireFormatLite::TYPE_FIXED64: + return ReadPrimitive<protobuf_uint64, int64, + WireFormatLite::TYPE_FIXED64>(input, index, datap); + case WireFormatLite::TYPE_FIXED32: + if (dtype == DataType::DT_INT64) { + return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>( + input, index, datap); + } + if (dtype == DataType::DT_INT32) { + return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>( + input, index, datap); + } + // Any case that reaches this point should have triggered an error + // already. + return errors::DataLoss("Failed reading TYPE_FIXED32"); + case WireFormatLite::TYPE_BOOL: + return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index, + datap); + case WireFormatLite::TYPE_STRING: + return ReadBytes(input, index, datap); + case WireFormatLite::TYPE_GROUP: + return ReadGroupBytes(input, field_number, index, datap); + case WireFormatLite::TYPE_MESSAGE: + return ReadBytes(input, index, datap); + case WireFormatLite::TYPE_BYTES: + return ReadBytes(input, index, datap); + case WireFormatLite::TYPE_UINT32: + if (dtype == DataType::DT_INT64) { + return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>( + input, index, datap); + } + if (dtype == DataType::DT_INT32) { + return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>( + input, index, datap); + } + // Any case that reaches this point should have triggered an error + // already. 
+ return errors::DataLoss("Failed reading TYPE_UINT32"); + case WireFormatLite::TYPE_ENUM: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>( + input, index, datap); + case WireFormatLite::TYPE_SFIXED32: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>( + input, index, datap); + case WireFormatLite::TYPE_SFIXED64: + return ReadPrimitive<protobuf_int64, int64, + WireFormatLite::TYPE_SFIXED64>(input, index, datap); + case WireFormatLite::TYPE_SINT32: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>( + input, index, datap); + case WireFormatLite::TYPE_SINT64: + return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>( + input, index, datap); + // default: intentionally omitted in order to enable static checking. + } + // Unreachable. + return errors::DataLoss("Failed reading unknown wire type"); +} + +// Reads and stores a length-delimited list of values. +inline Status ReadPackedFromArray(const void* buf, size_t buf_size, + const WireFormatLite::FieldType field_type, + const int field_number, const DataType dtype, + const int stride, int* index, void* data) { + // Dispatch to the appropriately typed field reader based on the + // schema type. 
+ switch (field_type) { + case WireFormatLite::TYPE_DOUBLE: + *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_FLOAT: + *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_INT64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_UINT64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_INT32: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_FIXED64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_FIXED32: + if (dtype == DataType::DT_INT64) { + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + } + if (dtype == DataType::DT_INT32) { + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + } + // Any case that reaches this point should have triggered an error + // already. 
+ return errors::DataLoss("Failed reading TYPE_FIXED32"); + case WireFormatLite::TYPE_BOOL: + *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + case WireFormatLite::TYPE_BYTES: + return errors::DataLoss("Non-primitive type encountered as packed"); + case WireFormatLite::TYPE_UINT32: + if (dtype == DataType::DT_INT64) { + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + } + if (dtype == DataType::DT_INT32) { + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + } + // Any case that reaches this point should have triggered an error + // already. + return errors::DataLoss("Failed reading TYPE_UINT32"); + case WireFormatLite::TYPE_ENUM: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case WireFormatLite::TYPE_SFIXED32: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + + case WireFormatLite::TYPE_SFIXED64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>( + buf, buf_size, *index, stride, data); + return Status::OK(); + + case WireFormatLite::TYPE_SINT32: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + + case WireFormatLite::TYPE_SINT64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>( + buf, buf_size, *index, stride, data); + return Status::OK(); + // default: intentionally omitted in order to enable static checking. + } + // Unreachable. 
+ return errors::DataLoss("Failed reading unknown wire type"); +} + +// Reads a varint from the given buffer, write it to *value, and return the +// new buffer pointer. +// This was copied from coded_stream.cc where it is private. +// Important: This routine may read as much as kMaxVarintBytes from +// the buffer. It is the caller's responsibility to make sure that there is +// enough space in the buffer. +inline const uint8* ReadVarint64FromArray(const uint8* buffer, bool* ok, + uint64* value) { + const uint8* ptr = buffer; + uint32 b; + + // Splitting into 32-bit pieces gives better performance on 32-bit + // processors. + uint32 part0 = 0, part1 = 0, part2 = 0; + + b = *(ptr++); + part0 = b; + if (!(b & 0x80)) goto done; + part0 -= 0x80; + b = *(ptr++); + part0 += b << 7; + if (!(b & 0x80)) goto done; + part0 -= 0x80 << 7; + b = *(ptr++); + part0 += b << 14; + if (!(b & 0x80)) goto done; + part0 -= 0x80 << 14; + b = *(ptr++); + part0 += b << 21; + if (!(b & 0x80)) goto done; + part0 -= 0x80 << 21; + b = *(ptr++); + part1 = b; + if (!(b & 0x80)) goto done; + part1 -= 0x80; + b = *(ptr++); + part1 += b << 7; + if (!(b & 0x80)) goto done; + part1 -= 0x80 << 7; + b = *(ptr++); + part1 += b << 14; + if (!(b & 0x80)) goto done; + part1 -= 0x80 << 14; + b = *(ptr++); + part1 += b << 21; + if (!(b & 0x80)) goto done; + part1 -= 0x80 << 21; + b = *(ptr++); + part2 = b; + if (!(b & 0x80)) goto done; + part2 -= 0x80; + b = *(ptr++); + part2 += b << 7; + if (!(b & 0x80)) goto done; + // "part2 -= 0x80 << 7" is irrelevant because (0x80 << 7) << 56 is 0. + + // We have overrun the maximum size of a varint (10 bytes). Assume + // the data is corrupt. 
+ *ok = false; + return ptr; + +done: + *ok = true; + *value = (static_cast<uint64>(part0)) | (static_cast<uint64>(part1) << 28) | + (static_cast<uint64>(part2) << 56); + return ptr; +} + +} // namespace internal +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_ diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.cc b/tensorflow/core/util/proto/descriptor_pool_registry.cc new file mode 100644 index 0000000000..5f0423f76b --- /dev/null +++ b/tensorflow/core/util/proto/descriptor_pool_registry.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <string> + +#include "tensorflow/core/platform/logging.h" + +#include "tensorflow/core/util/proto/descriptor_pool_registry.h" + +namespace tensorflow { + +DescriptorPoolRegistry* DescriptorPoolRegistry::Global() { + static DescriptorPoolRegistry* registry = new DescriptorPoolRegistry; + return registry; +} + +DescriptorPoolRegistry::DescriptorPoolFn* DescriptorPoolRegistry::Get( + const string& source) { + auto found = fns_.find(source); + if (found == fns_.end()) return nullptr; + return &found->second; +} + +void DescriptorPoolRegistry::Register( + const string& source, + const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) { + auto existing = Get(source); + CHECK_EQ(existing, nullptr) + << "descriptor pool for source: " << source << " already registered"; + fns_.insert(std::pair<const string&, DescriptorPoolFn>(source, pool_fn)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/proto/descriptor_pool_registry.h b/tensorflow/core/util/proto/descriptor_pool_registry.h new file mode 100644 index 0000000000..66c20e9e41 --- /dev/null +++ b/tensorflow/core/util/proto/descriptor_pool_registry.h @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_ +#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_ + +#include <functional> +#include <map> +#include <memory> +#include <string> +#include <unordered_map> +#include <utility> + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// Registry that maps a descriptor source string to the function supplying +// the DescriptorPool for that source. +class DescriptorPoolRegistry { + public: + typedef std::function<Status( + tensorflow::protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool)> + DescriptorPoolFn; + + // Returns a pointer to a global DescriptorPoolRegistry object. + static DescriptorPoolRegistry* Global(); + + // Returns a pointer to a descriptor pool function for the given source, or + // nullptr if none is registered. + DescriptorPoolFn* Get(const string& source); + + // Registers a descriptor pool factory. CHECK-fails if `source` already has + // one registered. + void Register(const string& source, const DescriptorPoolFn& pool_fn); + + private: + std::map<string, DescriptorPoolFn> fns_; +}; + +namespace descriptor_pool_registration { + +class DescriptorPoolRegistration { + public: + DescriptorPoolRegistration( + const string& source, + const DescriptorPoolRegistry::DescriptorPoolFn& pool_fn) { + DescriptorPoolRegistry::Global()->Register(source, pool_fn); + } +}; + +} // namespace descriptor_pool_registration + +#define REGISTER_DESCRIPTOR_POOL(source, pool_fn) \ + REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(__COUNTER__, source, pool_fn) + +#define REGISTER_DESCRIPTOR_POOL_UNIQ_HELPER(ctr, source, pool_fn) \ + REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn) + +#define REGISTER_DESCRIPTOR_POOL_UNIQ(ctr, source, pool_fn) \ + static descriptor_pool_registration::DescriptorPoolRegistration \ + descriptor_pool_registration_fn_##ctr(source, pool_fn) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTOR_POOL_REGISTRY_H_ diff --git
a/tensorflow/core/util/proto/descriptor_pool_registry_test.cc b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc new file mode 100644 index 0000000000..a6899998ab --- /dev/null +++ b/tensorflow/core/util/proto/descriptor_pool_registry_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/proto/descriptor_pool_registry.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +struct Value { + static Status Function( + tensorflow::protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) { + return Status::OK(); + } +}; + +REGISTER_DESCRIPTOR_POOL("TEST POOL 1", Value::Function); +REGISTER_DESCRIPTOR_POOL("TEST POOL 2", Value::Function); +} // namespace + +TEST(DescriptorPoolRegistryTest, TestBasic) { + EXPECT_EQ(DescriptorPoolRegistry::Global()->Get("NON-EXISTENT"), nullptr); + auto pool1 = DescriptorPoolRegistry::Global()->Get("TEST POOL 1"); + EXPECT_NE(pool1, nullptr); + auto pool2 = DescriptorPoolRegistry::Global()->Get("TEST POOL 2"); + EXPECT_NE(pool2, nullptr); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/proto/descriptors.cc b/tensorflow/core/util/proto/descriptors.cc new file mode 100644 index 0000000000..271c85efd8 
--- /dev/null +++ b/tensorflow/core/util/proto/descriptors.cc @@ -0,0 +1,87 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/reader_op_kernel.h" +#include "tensorflow/core/util/proto/descriptor_pool_registry.h" + +#include "tensorflow/core/util/proto/descriptors.h" + +namespace tensorflow { +namespace { + +// Build a `DescriptorPool` from the named file or URI. The file or URI +// must be available to the current TensorFlow environment. +// +// The file must contain a serialized `FileDescriptorSet`. See +// `GetDescriptorPool()` for more information. +Status GetDescriptorPoolFromFile( + tensorflow::Env* env, const string& filename, + std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) { + Status st = env->FileExists(filename); + if (!st.ok()) { + return st; + } + + // Read and parse the FileDescriptorSet. + tensorflow::protobuf::FileDescriptorSet descs; + std::unique_ptr<tensorflow::ReadOnlyMemoryRegion> buf; + st = env->NewReadOnlyMemoryRegionFromFile(filename, &buf); + if (!st.ok()) { + return st; + } + if (!descs.ParseFromArray(buf->data(), buf->length())) { + return errors::InvalidArgument( + "descriptor_source contains invalid FileDescriptorSet: ", filename); + } + + // Build a DescriptorPool from the FileDescriptorSet.
+ owned_desc_pool->reset(new tensorflow::protobuf::DescriptorPool()); + for (const auto& filedesc : descs.file()) { + if ((*owned_desc_pool)->BuildFile(filedesc) == nullptr) { + return errors::InvalidArgument( + "Problem loading FileDescriptorProto (missing dependencies?): ", + filename); + } + } + return Status::OK(); +} + +} // namespace + +// Resolves `descriptor_source` to a DescriptorPool: a pool function registered +// for that source wins; otherwise the source is loaded as a file. +Status GetDescriptorPool( + tensorflow::Env* env, string const& descriptor_source, + tensorflow::protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) { + // Attempt to lookup the pool in the registry. + auto pool_fn = DescriptorPoolRegistry::Global()->Get(descriptor_source); + if (pool_fn != nullptr) { + return (*pool_fn)(desc_pool, owned_desc_pool); + } + + // If there is no pool function registered for the given source, let the + // runtime find the file or URL. + Status status = + GetDescriptorPoolFromFile(env, descriptor_source, owned_desc_pool); + // Mirror the owned pointer into *desc_pool regardless of the load status. + // On failure it may be unset or only partially built, so callers must + // check the returned status before using the pool. + *desc_pool = owned_desc_pool->get(); + return status; +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/proto/descriptors.h b/tensorflow/core/util/proto/descriptors.h new file mode 100644 index 0000000000..92ee8997ab --- /dev/null +++ b/tensorflow/core/util/proto/descriptors.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_ +#define TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_ + +#include <memory> +#include <string> + +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +class Env; +class Status; + +// Get a `DescriptorPool` object from the named `descriptor_source`. +// `descriptor_source` may be a path to a file accessible to TensorFlow, in +// which case it is parsed as a `FileDescriptorSet` and used to build the +// `DescriptorPool`. +// +// `owned_desc_pool` will be filled in with the same pointer as `desc_pool` if +// the caller should take ownership. +extern tensorflow::Status GetDescriptorPool( + tensorflow::Env* env, string const& descriptor_source, + tensorflow::protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_PROTO_DESCRIPTORS_H_ diff --git a/tensorflow/core/util/proto/local_descriptor_pool_registration.cc b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc new file mode 100644 index 0000000000..48fe0102d0 --- /dev/null +++ b/tensorflow/core/util/proto/local_descriptor_pool_registration.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/util/proto/descriptor_pool_registry.h" + +namespace tensorflow { +namespace { + +struct LocalDescriptorPool { + static Status Function( + tensorflow::protobuf::DescriptorPool const** desc_pool, + std::unique_ptr<tensorflow::protobuf::DescriptorPool>* owned_desc_pool) { + *desc_pool = ::tensorflow::protobuf::DescriptorPool::generated_pool(); + if (*desc_pool == nullptr) { + return errors::InvalidArgument("Problem loading protobuf generated_pool"); + } + return Status::OK(); + } +}; + +REGISTER_DESCRIPTOR_POOL("", LocalDescriptorPool::Function); +REGISTER_DESCRIPTOR_POOL("local://", LocalDescriptorPool::Function); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/BUILD b/tensorflow/core/util/rpc/BUILD new file mode 100644 index 0000000000..f0f161ecc0 --- /dev/null +++ b/tensorflow/core/util/rpc/BUILD @@ -0,0 +1,48 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "call_container", + hdrs = ["call_container.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "rpc_factory", + srcs = ["rpc_factory.cc"], + hdrs = ["rpc_factory.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "rpc_factory_registry", + srcs = ["rpc_factory_registry.cc"], + hdrs = ["rpc_factory_registry.h"], + deps = [ + ":rpc_factory", + "//tensorflow/core:framework", + ], +) + +tf_cc_test( + name = "rpc_factory_registry_test", + srcs = ["rpc_factory_registry_test.cc"], + deps = [ + ":rpc_factory_registry", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + 
], +) diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h new file mode 100644 index 0000000000..7f36056797 --- /dev/null +++ b/tensorflow/core/util/rpc/call_container.h @@ -0,0 +1,90 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ +#define TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ + +#include <list> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/util/reffed_status_callback.h" + +namespace tensorflow { + +template <typename Call> +class CallContainer { + public: + explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast, + bool try_rpc, AsyncOpKernel::DoneCallback done, + CancellationToken token) + : ctx_(ctx), + done_(std::move(done)), + token_(token), + fail_fast_(fail_fast), + try_rpc_(try_rpc) { + CHECK_GT(num_calls, 0); + + // This will run when all RPCs are finished. + reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) { + ctx_->cancellation_manager()->DeregisterCallback(token_); + ctx_->SetStatus(s); + done_(); + delete this; + }); + + // Subtract reference count from the initial creation. 
+ core::ScopedUnref unref(reffed_status_callback_); + + for (int i = 0; i < num_calls; ++i) { + // Increase the reference on the callback for each new RPC. + reffed_status_callback_->Ref(); + } + } + + std::list<Call>* calls() { return &calls_; } + + void StartCancel() { + // Once this loop is done, can no longer assume anything is valid + // because "delete this" may have been immediately called. + // Nothing should run after this loop. + for (auto& call : calls_) { + call.StartCancel(); + } + } + + void Done(const Status& s, int index) { + if (!try_rpc_) { + reffed_status_callback_->UpdateStatus(s); + } + reffed_status_callback_->Unref(); + } + + private: + OpKernelContext* ctx_; + std::list<Call> calls_; + const AsyncOpKernel::DoneCallback done_; + const CancellationToken token_; + const bool fail_fast_; + const bool try_rpc_; + + // Performs its own reference counting. + ReffedStatusCallback* reffed_status_callback_; +}; + +} // namespace tensorflow +#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory.cc b/tensorflow/core/util/rpc/rpc_factory.cc new file mode 100644 index 0000000000..8530f02b6e --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/lib/strings/numbers.h" + +#include "tensorflow/core/util/rpc/rpc_factory.h" + +namespace tensorflow { + +template <> +bool GetEnvVar(const char* key, const string& default_value, string* value) { + const char* env_value = std::getenv(key); + if (!env_value || env_value[0] == '\0') { + *value = default_value; + } else { + *value = env_value; + } + return true; +} + +template <> +bool GetEnvVar(const char* key, const int64& default_value, int64* value) { + const char* env_value = std::getenv(key); + if (!env_value || env_value[0] == '\0') { + *value = default_value; + return true; + } + return strings::safe_strto64(env_value, value); +} + +template <> +bool GetEnvVar(const char* key, const uint64& default_value, uint64* value) { + const char* env_value = std::getenv(key); + if (!env_value || env_value[0] == '\0') { + *value = default_value; + return true; + } + return strings::safe_strtou64(env_value, value); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h new file mode 100644 index 0000000000..9bf078c0f4 --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ +#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { + +// Return the environment variable `key`. If the variable is not set, +// use the default value. If it is set but could not be parsed, +// return `false`. Otherwise set `value` and return `true`. +template <typename T> +bool GetEnvVar(const char* key, const T& default_value, T* value); + +class RPCFactory { + public: + RPCFactory() {} + virtual ~RPCFactory() {} + + // Start a Call() to methods `method_t` at addresses `address_t` with + // request strings from `request_t`. Any of these may be scalar + // Tensors, in which case the operands are broadcasted. + // Upon completion of all requests, `response_t` will be populated. + // + // If `try_rpc` is `true`, then `status_message_t` and + // `status_code_t` will be populated as well. + // + // If `try_rpc` is `false`, then `status_message_t` and + // `status_code_t` are ignored (and may be nullptr). Instead, the + // status of any failed call will be propagated to the op. + // + // REQUIRES: + // - `response_t` is not null, and is a string Tensor with the same shape as + // `request_t`. + // + // If `try_rpc` is `true`: + // - `status_code_t` and `status_message_t` are not null. + // - `status_code_t` is an int32 Tensor with the same shape as + // `request_t`. + // - `status_message_t` is a string Tensor with the same shape as + // `request_t`. 
+ virtual void Call(OpKernelContext* ctx, int64 num_elements, + const Tensor& address_t, const Tensor& method_t, + const Tensor& request_t, const bool try_rpc, + Tensor* response_t, Tensor* status_code_t, + Tensor* status_message_t, + AsyncOpKernel::DoneCallback done) = 0; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(RPCFactory); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.cc b/tensorflow/core/util/rpc/rpc_factory_registry.cc new file mode 100644 index 0000000000..a148b5c04d --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory_registry.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <string> + +#include "tensorflow/core/util/rpc/rpc_factory.h" + +#include "tensorflow/core/util/rpc/rpc_factory_registry.h" + +namespace tensorflow { + +RPCFactoryRegistry* RPCFactoryRegistry::Global() { + static RPCFactoryRegistry* registry = new RPCFactoryRegistry; + return registry; +} + +RPCFactoryRegistry::RPCFactoryFn* RPCFactoryRegistry::Get( + const string& protocol) { + auto found = fns_.find(protocol); + if (found == fns_.end()) return nullptr; + return &found->second; +} + +void RPCFactoryRegistry::Register(const string& protocol, + const RPCFactoryFn& factory_fn) { + auto existing = Get(protocol); + CHECK_EQ(existing, nullptr) + << "RPC factory for protocol: " << protocol << " already registered"; + fns_.insert(std::pair<const string&, RPCFactoryFn>(protocol, factory_fn)); +} + +} // namespace tensorflow diff --git a/tensorflow/core/util/rpc/rpc_factory_registry.h b/tensorflow/core/util/rpc/rpc_factory_registry.h new file mode 100644 index 0000000000..2635a4012e --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory_registry.h @@ -0,0 +1,76 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ +#define TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ + +#include <functional> +#include <map> +#include <string> + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/rpc/rpc_factory.h" + +namespace tensorflow { + +// Registry that maps a protocol name to the function that creates the +// RPCFactory for that protocol. +class RPCFactoryRegistry { + public: + typedef std::function<RPCFactory*(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms)> + RPCFactoryFn; + + // Returns a pointer to a global RPCFactoryRegistry object. + static RPCFactoryRegistry* Global(); + + // Returns a pointer to a function that creates an RPC factory for the given + // protocol, or nullptr if none is registered. + RPCFactoryFn* Get(const string& protocol); + + // Registers a function that creates an RPC factory for the given protocol. + // The function should transfer the ownership of the factory to its caller. + // CHECK-fails if `protocol` already has a function registered. + void Register(const string& protocol, const RPCFactoryFn& factory_fn); + + private: + std::map<string, RPCFactoryFn> fns_; +}; + +namespace rpc_factory_registration { + +class RPCFactoryRegistration { + public: + RPCFactoryRegistration(const string& protocol, + const RPCFactoryRegistry::RPCFactoryFn& factory_fn) { + RPCFactoryRegistry::Global()->Register(protocol, factory_fn); + } +}; + +} // namespace rpc_factory_registration + +#define REGISTER_RPC_FACTORY(protocol, factory_fn) \ + REGISTER_RPC_FACTORY_UNIQ_HELPER(__COUNTER__, protocol, factory_fn) + +#define REGISTER_RPC_FACTORY_UNIQ_HELPER(ctr, protocol, factory_fn) \ + REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) + +#define REGISTER_RPC_FACTORY_UNIQ(ctr, protocol, factory_fn) \ + static rpc_factory_registration::RPCFactoryRegistration \ + rpc_factory_registration_fn_##ctr(protocol, factory_fn) + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_UTIL_RPC_RPC_FACTORY_REGISTRY_H_ diff --git a/tensorflow/core/util/rpc/rpc_factory_registry_test.cc
b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc new file mode 100644 index 0000000000..cfd0f95016 --- /dev/null +++ b/tensorflow/core/util/rpc/rpc_factory_registry_test.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/util/rpc/rpc_factory_registry.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +struct Value { + static RPCFactory* Function(OpKernelConstruction* ctx, bool fail_fast, + int64 timeout_in_ms) { + return nullptr; // Stub; the test below only looks it up, never calls it. + } +}; + +REGISTER_RPC_FACTORY("TEST FACTORY 1", Value::Function); +REGISTER_RPC_FACTORY("TEST FACTORY 2", Value::Function); +} // namespace + +TEST(RPCFactoryRegistryTest, TestBasic) { + // Only the two protocols registered above should resolve to a function. + RPCFactoryRegistry* registry = RPCFactoryRegistry::Global(); + EXPECT_EQ(registry->Get("NON-EXISTENT"), nullptr); + EXPECT_NE(registry->Get("TEST FACTORY 1"), nullptr); + EXPECT_NE(registry->Get("TEST FACTORY 2"), nullptr); +} + +} // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 01962fcf44..a22b9f40b1 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3370,6 +3370,7 @@ tf_py_wrap_cc( "//tensorflow/c:python_api", "//tensorflow/c:tf_status_helper", "//tensorflow/c/eager:c_api", + 
"//tensorflow/core/distributed_runtime/rpc:grpc_rpc_factory_registration", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_session", "//tensorflow/core/grappler:grappler_item", |