diff options
Diffstat (limited to 'tensorflow/core/kernels/decode_proto_op.cc')
-rw-r--r-- | tensorflow/core/kernels/decode_proto_op.cc | 367 |
1 files changed, 157 insertions, 210 deletions
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc index 6d3dcc1c59..b54e1ea8ac 100644 --- a/tensorflow/core/kernels/decode_proto_op.cc +++ b/tensorflow/core/kernels/decode_proto_op.cc @@ -13,21 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// DecodeProto is a TensorFlow Op which extracts arbitrary fields -// from protos serialized as strings. +// DecodeProto is a TensorFlow op which extracts arbitrary fields from protos +// serialized as strings. // // See docs in ../ops/decode_proto_op.cc. // -// This implementation reads the serialized format using a handful of -// calls from the WireFormatLite API used by generated proto code. -// WireFormatLite is marked as an "internal" proto API but is widely -// used in practice and highly unlikely to change. -// This will be much faster than the previous implementation based on -// constructing a temporary dynamic message in memory and using the -// proto reflection api to read it. -// It can be used with any proto whose descriptors are available at -// runtime but should be competitive in speed with approaches that -// compile in the proto definitions. +// This implementation reads the serialized format using a handful of calls from +// the WireFormatLite API used by generated proto code. WireFormatLite is marked +// as an "internal" proto API but is widely used in practice and highly unlikely +// to change. This will be much faster than the previous implementation based on +// constructing a temporary dynamic message in memory and using the proto +// reflection api to read it. It can be used with any proto whose descriptors +// are available at runtime but should be competitive in speed with approaches +// that compile in the proto definitions. #include <memory> #include <string> @@ -36,11 +34,13 @@ limitations under the License. 
#include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/proto/decode.h" #include "tensorflow/core/util/proto/descriptors.h" +#include "tensorflow/core/util/proto/proto_utils.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -58,53 +58,6 @@ using ::tensorflow::protobuf::io::CodedInputStream; const bool kFailOnDecodeError = true; -// Returns true if the proto field type can be converted to the -// tensorflow::DataType. -bool CheckOutputType(FieldDescriptor::Type field_type, DataType output_type) { - switch (field_type) { - case WireFormatLite::TYPE_DOUBLE: - return output_type == tensorflow::DT_DOUBLE; - case WireFormatLite::TYPE_FLOAT: - return output_type == tensorflow::DT_FLOAT || - output_type == tensorflow::DT_DOUBLE; - case WireFormatLite::TYPE_INT64: - return output_type == tensorflow::DT_INT64; - case WireFormatLite::TYPE_UINT64: - return output_type == tensorflow::DT_INT64; - case WireFormatLite::TYPE_INT32: - return output_type == tensorflow::DT_INT32; - case WireFormatLite::TYPE_FIXED64: - return output_type == tensorflow::DT_INT64; - case WireFormatLite::TYPE_FIXED32: - return output_type == tensorflow::DT_INT32 || - output_type == tensorflow::DT_INT64; - case WireFormatLite::TYPE_BOOL: - return output_type == tensorflow::DT_BOOL; - case WireFormatLite::TYPE_STRING: - return output_type == tensorflow::DT_STRING; - case WireFormatLite::TYPE_GROUP: - return output_type == tensorflow::DT_STRING; - case WireFormatLite::TYPE_MESSAGE: - return output_type == tensorflow::DT_STRING; - case WireFormatLite::TYPE_BYTES: - return output_type == tensorflow::DT_STRING; - case WireFormatLite::TYPE_UINT32: - return output_type == tensorflow::DT_INT32 || 
- output_type == tensorflow::DT_INT64; - case WireFormatLite::TYPE_ENUM: - return output_type == tensorflow::DT_INT32; - case WireFormatLite::TYPE_SFIXED32: - return output_type == tensorflow::DT_INT32; - case WireFormatLite::TYPE_SFIXED64: - return output_type == tensorflow::DT_INT64; - case WireFormatLite::TYPE_SINT32: - return output_type == tensorflow::DT_INT32; - case WireFormatLite::TYPE_SINT64: - return output_type == tensorflow::DT_INT64; - // default: intentionally omitted in order to enable static checking. - } -} - // Used to store the default value of a protocol message field, casted to the // type of the output tensor. // @@ -113,13 +66,15 @@ struct DefaultValue { DataType dtype = DataType::DT_INVALID; union Value { bool v_bool; // DT_BOOL - uint8 v_uint8; // DT_UINT8 + double v_double; // DT_DOUBLE + float v_float; // DT_FLOAT int8 v_int8; // DT_INT8 int32 v_int32; // DT_INT32 int64 v_int64; // DT_INT64 - float v_float; // DT_FLOAT - double v_double; // DT_DOUBLE const char* v_string; // DT_STRING + uint8 v_uint8; // DT_UINT8 + uint32 v_uint32; // DT_UINT32 + uint64 v_uint64; // DT_UINT64 }; Value value; }; @@ -138,23 +93,29 @@ Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) { case DT_BOOL: result->value.v_bool = static_cast<bool>(value); break; - case DT_INT32: - result->value.v_int32 = static_cast<int32>(value); + case DT_DOUBLE: + result->value.v_double = static_cast<double>(value); + break; + case DT_FLOAT: + result->value.v_float = static_cast<float>(value); break; case DT_INT8: result->value.v_int8 = static_cast<int8>(value); break; - case DT_UINT8: - result->value.v_uint8 = static_cast<uint8>(value); + case DT_INT32: + result->value.v_int32 = static_cast<int32>(value); break; case DT_INT64: result->value.v_int64 = static_cast<int64>(value); break; - case DT_FLOAT: - result->value.v_float = static_cast<float>(value); + case DT_UINT8: + result->value.v_uint8 = static_cast<uint8>(value); break; - case DT_DOUBLE: - 
result->value.v_double = static_cast<double>(value); + case DT_UINT32: + result->value.v_uint32 = static_cast<uint32>(value); + break; + case DT_UINT64: + result->value.v_uint64 = static_cast<uint64>(value); break; default: // We should never get here, given the type checking that occurs earlier. @@ -241,13 +202,11 @@ struct FieldInfo { number = field_desc->number(); // The wire format library defines the same constants used in - // descriptor.proto. This static_cast is safe because they - // are guaranteed to stay in sync. - // We need the field type from the FieldDescriptor here - // because the wire format doesn't tell us anything about - // what happens inside a packed repeated field: there is - // enough information in the wire format to skip the - // whole field but not enough to know how to parse what's + // descriptor.proto. This static_cast is safe because they are guaranteed to + // stay in sync. We need the field type from the FieldDescriptor here + // because the wire format doesn't tell us anything about what happens + // inside a packed repeated field: there is enough information in the wire + // format to skip the whole field but not enough to know how to parse what's // inside. For that we go to the schema. type = static_cast<WireFormatLite::FieldType>(field_desc->type()); is_repeated = field_desc->is_repeated(); @@ -257,16 +216,15 @@ struct FieldInfo { FieldInfo(const FieldInfo&) = delete; FieldInfo& operator=(const FieldInfo&) = delete; - // Internally we sort field descriptors by wire number for - // fast lookup. In general this is different from the order - // given by the user. Output_index gives the index into - // the field_names and output_types attributes and into + // Internally we sort field descriptors by wire number for fast lookup. In + // general this is different from the order given by the user. Output_index + // gives the index into the field_names and output_types attributes and into // the output tensor list. 
int output_index = -1; - // This is a cache of the relevant fields from `FieldDescriptorProto`. - // This was added after noticing that FieldDescriptor->type() was - // using 6% of the cpu profile. + // This is a cache of the relevant fields from `FieldDescriptorProto`. This + // was added after noticing that FieldDescriptor->type() was using 6% of the + // cpu profile. WireFormatLite::FieldType type; int number; bool is_repeated; @@ -275,16 +233,16 @@ struct FieldInfo { // A CountCollector counts sizes of repeated and optional fields in a proto. // -// Each field is tracked by a single CountCollector instance. The -// instance manages a single count, which is stored as a pointer (it -// is intended to be a reference to the `sizes` output which is being -// filled in). The pointer is passed in at initialization. +// Each field is tracked by a single CountCollector instance. The instance +// manages a single count, which is stored as a pointer (it is intended to be a +// reference to the `sizes` output which is being filled in). The pointer is +// passed in at initialization. // -// Counting is done as a separate pass in order to allocate output tensors -// all at once. This allows the TensorFlow runtime to optimize allocation -// for the consumer, while removing the need for copying inside this op. -// After this pass, the DenseCollector class (below) gathers the data: -// It is more complex and provides better motivation for the API here. +// Counting is done as a separate pass in order to allocate output tensors all +// at once. This allows the TensorFlow runtime to optimize allocation for the +// consumer, while removing the need for copying inside this op. After this +// pass, the DenseCollector class (below) gathers the data: it is more complex +// and provides better motivation for the API here. 
class CountCollector { public: CountCollector() = delete; @@ -298,8 +256,8 @@ class CountCollector { if (*count_ptr_ == 0 || field.is_repeated) { (*count_ptr_)++; } - // We expect a wire type based on the schema field_type, to allow - // a little more checking. + // We expect a wire type based on the schema field_type, to allow a little + // more checking. if (!SkipValue(input, field)) { return errors::DataLoss("ReadValue: Failed skipping field when counting"); } @@ -329,8 +287,8 @@ class CountCollector { return errors::DataLoss("ReadPackedValues: Skipping packed field failed"); } - // Dispatch to the appropriately typed field reader based on the - // schema type. + // Dispatch to the appropriately typed field reader based on the schema + // type. Status st; switch (field.type) { case WireFormatLite::TYPE_DOUBLE: @@ -409,18 +367,17 @@ class CountCollector { return input->Skip(length); } - // Counts the number of packed varints in an array. - // The end of a varint is signaled by a value < 0x80, - // so counting them requires parsing the bytestream. - // It is the caller's responsibility to ensure that len > 0. + // Counts the number of packed varints in an array. The end of a varint is + // signaled by a value < 0x80, so counting them requires parsing the + // bytestream. It is the caller's responsibility to ensure that len > 0. Status CountPackedVarint(const uint8* buf, size_t len) { const uint8* bound = buf + len; int count; - // The last byte in a valid encoded varint is guaranteed to have - // the high bit unset. We rely on this property to prevent - // ReadVarint64FromArray from going out of bounds, so validate - // the end of the buf before scanning anything. + // The last byte in a valid encoded varint is guaranteed to have the high + // bit unset. We rely on this property to prevent ReadVarint64FromArray from + // going out of bounds, so validate the end of the buf before scanning + // anything. 
if (bound[-1] & 0x80) { return errors::DataLoss("Corrupt packed varint"); } @@ -439,8 +396,8 @@ class CountCollector { return Status::OK(); } - // Counts the number of fixed-size values in a packed field. - // This can be done without actually parsing anything. + // Counts the number of fixed-size values in a packed field. This can be done + // without actually parsing anything. template <typename T> Status CountPackedFixed(const uint8* unused_buf, size_t len) { int count = len / sizeof(T); @@ -452,10 +409,9 @@ class CountCollector { return Status::OK(); } - // Skips a single value in the input stream. - // Dispatches to the appropriately typed field skipper based on the - // schema type tag. - // This is not as permissive as just handling the wire type. + // Skips a single value in the input stream. Dispatches to the appropriately + // typed field skipper based on the schema type tag. This is not as permissive + // as just handling the wire type. static bool SkipValue(CodedInputStream* input, const FieldInfo& field) { uint32 tmp32; protobuf_uint64 tmp64; @@ -507,13 +463,13 @@ class CountCollector { // A DenseCollector accumulates values from a proto into a tensor. // -// There is an instance of DenseCollector for each field of each -// proto. The DenseCollector deserializes the value from the wire -// directly into the preallocated output Tensor. +// There is an instance of DenseCollector for each field of each proto. The +// DenseCollector deserializes the value from the wire directly into the +// preallocated output Tensor. // -// This class is named DenseCollector because in the future there should -// be a SparseCollector that accumulates field data into sparse tensors if -// the user requests it. +// This class is named DenseCollector because in the future there should be a +// SparseCollector that accumulates field data into sparse tensors if the user +// requests it. 
class DenseCollector { public: DenseCollector() = delete; @@ -578,40 +534,43 @@ class DenseCollector { } } - // Fills in any missing values in the output array with defaults. - // Dispatches to the appropriately typed field default based on the - // runtime type tag. + // Fills in any missing values in the output array with defaults. Dispatches + // to the appropriately typed field default based on the runtime type tag. Status FillWithDefaults() { switch (default_value_.dtype) { + case DataType::DT_BOOL: + return FillDefault<bool>(default_value_.value.v_bool); case DataType::DT_FLOAT: return FillDefault<float>(default_value_.value.v_float); case DataType::DT_DOUBLE: return FillDefault<double>(default_value_.value.v_double); - case DataType::DT_INT32: - return FillDefault<int32>(default_value_.value.v_int32); - case DataType::DT_UINT8: - return FillDefault<uint8>(default_value_.value.v_uint8); case DataType::DT_INT8: return FillDefault<int8>(default_value_.value.v_int8); - case DataType::DT_STRING: - return FillDefault<string>(default_value_.value.v_string); + case DataType::DT_INT32: + return FillDefault<int32>(default_value_.value.v_int32); case DataType::DT_INT64: return FillDefault<int64>(default_value_.value.v_int64); - case DataType::DT_BOOL: - return FillDefault<bool>(default_value_.value.v_bool); + case DataType::DT_STRING: + return FillDefault<string>(default_value_.value.v_string); + case DataType::DT_UINT8: + return FillDefault<uint8>(default_value_.value.v_uint8); + case DataType::DT_UINT32: + return FillDefault<uint32>(default_value_.value.v_uint32); + case DataType::DT_UINT64: + return FillDefault<uint64>(default_value_.value.v_uint64); default: // There are many tensorflow dtypes not handled here, but they // should not come up unless type casting is added to the Op. // Chaining with tf.cast() should do the right thing until then. 
- return errors::DataLoss( - "Failed filling defaults in unknown tf::DataType"); + return errors::DataLoss("Failed filling defaults for ", + DataTypeString(default_value_.dtype)); } } private: - // Fills empty values in the dense representation with a - // default value. This uses next_repeat_index_ which counts the number - // of parsed values for the field. + // Fills empty values in the dense representation with a default value. This + // uses next_repeat_index_ which counts the number of parsed values for the + // field. template <class T> Status FillDefault(const T& default_value) { for (int i = next_repeat_index_; i < max_repeat_count_; i++) { @@ -622,11 +581,10 @@ class DenseCollector { int32 next_repeat_index_ = 0; - // This is a pointer to data_[message_index_]. - // There is no bounds checking at this level: we computed the max - // repeat size for each field in CountCollector and use the same - // code to traverse it here, so we are guaranteed not to be called - // for more items than we have allocated space. + // This is a pointer to data_[message_index_]. There is no bounds checking at + // this level: we computed the max repeat size for each field in + // CountCollector and use the same code to traverse it here, so we are + // guaranteed not to be called for more items than we have allocated space. void* const datap_ = nullptr; const DefaultValue default_value_; @@ -665,7 +623,6 @@ class DecodeProtoOp : public OpKernel { "have the same length")); // Gather the field descriptors and check that requested output types match. 
- int field_index = 0; std::vector<const FieldDescriptor*> field_descs; for (const string& name : field_names) { @@ -673,18 +630,16 @@ class DecodeProtoOp : public OpKernel { OP_REQUIRES(context, fd != nullptr, errors::InvalidArgument("Unknown field: ", name, " in message type ", message_type)); - OP_REQUIRES(context, - CheckOutputType(fd->type(), output_types[field_index]), - // Many TensorFlow types don't have corresponding proto types - // and the user will get an error if they are requested. It - // would be nice to allow conversions here, but tf.cast - // already exists so we don't duplicate the functionality. - // Known unhandled types: - // DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32 - // DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16 - errors::InvalidArgument("Unexpected output type for ", - fd->full_name(), ": ", fd->cpp_type(), - " to ", output_types[field_index])); + OP_REQUIRES( + context, + proto_utils::IsCompatibleType(fd->type(), output_types[field_index]), + // Many TensorFlow types don't have corresponding proto types and the + // user will get an error if they are requested. It would be nice to + // allow conversions here, but tf.cast already exists so we don't + // duplicate the functionality. + errors::InvalidArgument("Unexpected output type for ", + fd->full_name(), ": ", fd->cpp_type(), " to ", + output_types[field_index])); field_index++; field_descs.push_back(fd); @@ -726,10 +681,9 @@ class DecodeProtoOp : public OpKernel { errors::InvalidArgument("format must be one of binary or text")); is_binary_ = format == "binary"; - // Enable the initial protobuf sanitizer, which is much - // more expensive than the decoder. - // TODO(nix): Remove this once the fast decoder - // has passed security review. + // Enable the initial protobuf sanitizer, which is much more expensive than + // the decoder. + // TODO(nix): Remove this once the fast decoder has passed security review. 
OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_)); } @@ -742,9 +696,9 @@ class DecodeProtoOp : public OpKernel { int field_count = fields_.size(); - // Save the argument shape for later, then flatten the input - // Tensor since we are working componentwise. We will restore - // the same shape in the returned Tensor. + // Save the argument shape for later, then flatten the input Tensor since we + // are working componentwise. We will restore the same shape in the returned + // Tensor. const TensorShape& shape_prefix = buf_tensor.shape(); TensorShape sizes_shape = shape_prefix; @@ -752,8 +706,8 @@ class DecodeProtoOp : public OpKernel { Tensor* sizes_tensor = nullptr; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor)); - // This is used to allocate binary bufs if used. It serves only - // to define memory ownership. + // This is used to allocate binary bufs if used. It serves only to define + // memory ownership. std::vector<string> tmp_binary_bufs(message_count); // These are the actual buffers to use, which may be in tmp_binary_bufs @@ -768,8 +722,8 @@ class DecodeProtoOp : public OpKernel { bufs.push_back(buf); } } else { - // We will have to allocate a copy, either to convert from text to - // binary or to sanitize a binary proto. + // We will have to allocate a copy, either to convert from text to binary + // or to sanitize a binary proto. for (int mi = 0; mi < message_count; ++mi) { ReserializeMessage(ctx, buf_tensor.flat<string>()(mi), &tmp_binary_bufs[mi]); @@ -780,16 +734,14 @@ class DecodeProtoOp : public OpKernel { } } - // Walk through all the strings in the input tensor, counting - // the number of fields in each. - // We can't allocate our actual output Tensor until we know the - // maximum repeat count, so we do a first pass through the serialized - // proto just counting fields. 
- // We always allocate at least one value so that optional fields - // are populated with default values - this avoids a TF - // conditional when handling the output data. - // The caller can distinguish between real data and defaults - // using the repeat count matrix that is returned by decode_proto. + // Walk through all the strings in the input tensor, counting the number of + // fields in each. We can't allocate our actual output Tensor until we know + // the maximum repeat count, so we do a first pass through the serialized + // proto just counting fields. We always allocate at least one value so that + // optional fields are populated with default values - this avoids a TF + // conditional when handling the output data. The caller can distinguish + // between real data and defaults using the repeat count matrix that is + // returned by decode_proto. std::vector<int32> max_sizes(field_count, 1); for (int mi = 0; mi < message_count; ++mi) { CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes); @@ -814,14 +766,12 @@ class DecodeProtoOp : public OpKernel { // REGISTER_OP(...) // .Attr("output_types: list(type) >= 0") // .Output("values: output_types") - OP_REQUIRES_OK(ctx, - // ctx->allocate_output(output_indices_[fi] + 1, - ctx->allocate_output(fields_[fi]->output_index + 1, - out_shape, &outputs[fi])); + OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1, + out_shape, &outputs[fi])); } - // Make the second pass through the serialized proto, decoding - // into preallocated tensors. + // Make the second pass through the serialized proto, decoding into + // preallocated tensors. AccumulateFields(ctx, bufs, outputs); } @@ -976,6 +926,7 @@ class DecodeProtoOp : public OpKernel { // Look up the FieldDescriptor for a particular field number. bool LookupField(int field_number, int* field_index) { // Look up the FieldDescriptor using linear search. 
+ // // TODO(nix): this could be sped up with binary search, but we are // already way off the fastpath at this point. If you see a hotspot // here, somebody is sending you very inefficient protos. @@ -1010,6 +961,7 @@ class DecodeProtoOp : public OpKernel { // This takes advantage of the sorted field numbers in most serialized // protos: it tries the next expected field first rather than doing // a lookup by field number. + // // TODO(nix): haberman@ suggests a hybrid approach with a lookup table // for small field numbers and a hash table for larger ones. This would // be a simpler approach that should offer comparable speed in most @@ -1029,9 +981,9 @@ class DecodeProtoOp : public OpKernel { last_good_field_index = field_index; } } else { - // If we see a field that is past the next field we want, - // it was empty. Look for the one after that. - // Repeat until we run out of fields that we care about. + // If we see a field that is past the next field we want, it was + // empty. Look for the one after that. Repeat until we run out of + // fields that we care about. while (field_number >= next_good_field_number) { if (field_number == next_good_field_number) { last_good_field_number = field_number; @@ -1044,10 +996,9 @@ class DecodeProtoOp : public OpKernel { next_good_field_number = fields_[last_good_field_index + 1]->number; } else { - // Saw something past the last field we care about. - // Continue parsing the message just in case there - // are disordered fields later, but any remaining - // ordered fields will have no effect. + // Saw something past the last field we care about. Continue + // parsing the message just in case there are disordered fields + // later, but any remaining ordered fields will have no effect. 
next_good_field_number = INT_MAX; } } @@ -1077,20 +1028,20 @@ class DecodeProtoOp : public OpKernel { WireFormatLite::WireType wire_type, CodedInputStream* input, CollectorClass* collector) { // The wire format library defines the same constants used in - // descriptor.proto. This static_cast is safe because they - // are guaranteed to stay in sync. - // We need the field type from the FieldDescriptor here - // because the wire format doesn't tell us anything about - // what happens inside a packed repeated field: there is - // enough information in the wire format to skip the - // whole field but not enough to know how to parse what's - // inside. For that we go to the schema. + // descriptor.proto. This static_cast is safe because they are guaranteed to + // stay in sync. + // + // We need the field type from the FieldDescriptor here because the wire + // format doesn't tell us anything about what happens inside a packed + // repeated field: there is enough information in the wire format to skip + // the whole field but not enough to know how to parse what's inside. For + // that we go to the schema. WireFormatLite::WireType schema_wire_type = WireFormatLite::WireTypeForFieldType(field.type); - // Handle packed repeated fields. SkipField would skip the - // whole length-delimited blob without letting us count the - // values, so we have to scan them ourselves. + // Handle packed repeated fields. SkipField would skip the whole + // length-delimited blob without letting us count the values, so we have to + // scan them ourselves. if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED && schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { // Handle packed repeated primitives. 
@@ -1098,11 +1049,7 @@ class DecodeProtoOp : public OpKernel { if (!input->ReadVarintSizeAsInt(&length)) { return errors::DataLoss("CollectField: Failed reading packed size"); } - Status st = collector->ReadPackedValues(input, field, length); - if (!st.ok()) { - return st; - } - return Status::OK(); + return collector->ReadPackedValues(input, field, length); } // Read ordinary values, including strings, bytes, and messages. @@ -1118,9 +1065,9 @@ class DecodeProtoOp : public OpKernel { } string message_type_; - // Note that fields are sorted by increasing field number, - // which is not in general the order given by the user-specified - // field_names and output_types Op attributes. + // Note that fields are sorted by increasing field number, which is not in + // general the order given by the user-specified field_names and output_types + // Op attributes. std::vector<std::unique_ptr<const FieldInfo>> fields_; // Owned_desc_pool_ is null when using descriptor_source=local. @@ -1131,12 +1078,12 @@ class DecodeProtoOp : public OpKernel { // True if decoding binary format, false if decoding text format. bool is_binary_; - // True if the protos should be sanitized before parsing. - // Enables the initial protobuf sanitizer, which is much - // more expensive than the decoder. The flag defaults to true - // but can be set to false for trusted sources. - // TODO(nix): flip the default to false when the fast decoder - // has passed security review. + // True if the protos should be sanitized before parsing. Enables the initial + // protobuf sanitizer, which is much more expensive than the decoder. The flag + // defaults to true but can be set to false for trusted sources. + // + // TODO(nix): Flip the default to false when the fast decoder has passed + // security review. bool sanitize_; TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp); |