aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/decode_proto_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/decode_proto_op.cc')
-rw-r--r--tensorflow/core/kernels/decode_proto_op.cc367
1 files changed, 157 insertions, 210 deletions
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc
index 6d3dcc1c59..b54e1ea8ac 100644
--- a/tensorflow/core/kernels/decode_proto_op.cc
+++ b/tensorflow/core/kernels/decode_proto_op.cc
@@ -13,21 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// DecodeProto is a TensorFlow Op which extracts arbitrary fields
-// from protos serialized as strings.
+// DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
+// serialized as strings.
//
// See docs in ../ops/decode_proto_op.cc.
//
-// This implementation reads the serialized format using a handful of
-// calls from the WireFormatLite API used by generated proto code.
-// WireFormatLite is marked as an "internal" proto API but is widely
-// used in practice and highly unlikely to change.
-// This will be much faster than the previous implementation based on
-// constructing a temporary dynamic message in memory and using the
-// proto reflection api to read it.
-// It can be used with any proto whose descriptors are available at
-// runtime but should be competitive in speed with approaches that
-// compile in the proto definitions.
+// This implementation reads the serialized format using a handful of calls from
+// the WireFormatLite API used by generated proto code. WireFormatLite is marked
+// as an "internal" proto API but is widely used in practice and highly unlikely
+// to change. This will be much faster than the previous implementation based on
+// constructing a temporary dynamic message in memory and using the proto
+// reflection API to read it. It can be used with any proto whose descriptors
+// are available at runtime but should be competitive in speed with approaches
+// that compile in the proto definitions.
#include <memory>
#include <string>
@@ -36,11 +34,13 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/decode.h"
#include "tensorflow/core/util/proto/descriptors.h"
+#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -58,53 +58,6 @@ using ::tensorflow::protobuf::io::CodedInputStream;
const bool kFailOnDecodeError = true;
-// Returns true if the proto field type can be converted to the
-// tensorflow::DataType.
-bool CheckOutputType(FieldDescriptor::Type field_type, DataType output_type) {
- switch (field_type) {
- case WireFormatLite::TYPE_DOUBLE:
- return output_type == tensorflow::DT_DOUBLE;
- case WireFormatLite::TYPE_FLOAT:
- return output_type == tensorflow::DT_FLOAT ||
- output_type == tensorflow::DT_DOUBLE;
- case WireFormatLite::TYPE_INT64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_UINT64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_INT32:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_FIXED64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_FIXED32:
- return output_type == tensorflow::DT_INT32 ||
- output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_BOOL:
- return output_type == tensorflow::DT_BOOL;
- case WireFormatLite::TYPE_STRING:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_GROUP:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_MESSAGE:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_BYTES:
- return output_type == tensorflow::DT_STRING;
- case WireFormatLite::TYPE_UINT32:
- return output_type == tensorflow::DT_INT32 ||
- output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_ENUM:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_SFIXED32:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_SFIXED64:
- return output_type == tensorflow::DT_INT64;
- case WireFormatLite::TYPE_SINT32:
- return output_type == tensorflow::DT_INT32;
- case WireFormatLite::TYPE_SINT64:
- return output_type == tensorflow::DT_INT64;
- // default: intentionally omitted in order to enable static checking.
- }
-}
-
// Used to store the default value of a protocol message field, casted to the
// type of the output tensor.
//
@@ -113,13 +66,15 @@ struct DefaultValue {
DataType dtype = DataType::DT_INVALID;
union Value {
bool v_bool; // DT_BOOL
- uint8 v_uint8; // DT_UINT8
+ double v_double; // DT_DOUBLE
+ float v_float; // DT_FLOAT
int8 v_int8; // DT_INT8
int32 v_int32; // DT_INT32
int64 v_int64; // DT_INT64
- float v_float; // DT_FLOAT
- double v_double; // DT_DOUBLE
const char* v_string; // DT_STRING
+ uint8 v_uint8; // DT_UINT8
+ uint32 v_uint32; // DT_UINT32
+ uint64 v_uint64; // DT_UINT64
};
Value value;
};
@@ -138,23 +93,29 @@ Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
case DT_BOOL:
result->value.v_bool = static_cast<bool>(value);
break;
- case DT_INT32:
- result->value.v_int32 = static_cast<int32>(value);
+ case DT_DOUBLE:
+ result->value.v_double = static_cast<double>(value);
+ break;
+ case DT_FLOAT:
+ result->value.v_float = static_cast<float>(value);
break;
case DT_INT8:
result->value.v_int8 = static_cast<int8>(value);
break;
- case DT_UINT8:
- result->value.v_uint8 = static_cast<uint8>(value);
+ case DT_INT32:
+ result->value.v_int32 = static_cast<int32>(value);
break;
case DT_INT64:
result->value.v_int64 = static_cast<int64>(value);
break;
- case DT_FLOAT:
- result->value.v_float = static_cast<float>(value);
+ case DT_UINT8:
+ result->value.v_uint8 = static_cast<uint8>(value);
break;
- case DT_DOUBLE:
- result->value.v_double = static_cast<double>(value);
+ case DT_UINT32:
+ result->value.v_uint32 = static_cast<uint32>(value);
+ break;
+ case DT_UINT64:
+ result->value.v_uint64 = static_cast<uint64>(value);
break;
default:
// We should never get here, given the type checking that occurs earlier.
@@ -241,13 +202,11 @@ struct FieldInfo {
number = field_desc->number();
// The wire format library defines the same constants used in
- // descriptor.proto. This static_cast is safe because they
- // are guaranteed to stay in sync.
- // We need the field type from the FieldDescriptor here
- // because the wire format doesn't tell us anything about
- // what happens inside a packed repeated field: there is
- // enough information in the wire format to skip the
- // whole field but not enough to know how to parse what's
+ // descriptor.proto. This static_cast is safe because they are guaranteed to
+ // stay in sync. We need the field type from the FieldDescriptor here
+ // because the wire format doesn't tell us anything about what happens
+ // inside a packed repeated field: there is enough information in the wire
+ // format to skip the whole field but not enough to know how to parse what's
// inside. For that we go to the schema.
type = static_cast<WireFormatLite::FieldType>(field_desc->type());
is_repeated = field_desc->is_repeated();
@@ -257,16 +216,15 @@ struct FieldInfo {
FieldInfo(const FieldInfo&) = delete;
FieldInfo& operator=(const FieldInfo&) = delete;
- // Internally we sort field descriptors by wire number for
- // fast lookup. In general this is different from the order
- // given by the user. Output_index gives the index into
- // the field_names and output_types attributes and into
+ // Internally we sort field descriptors by wire number for fast lookup. In
+ // general this is different from the order given by the user. Output_index
+ // gives the index into the field_names and output_types attributes and into
// the output tensor list.
int output_index = -1;
- // This is a cache of the relevant fields from `FieldDescriptorProto`.
- // This was added after noticing that FieldDescriptor->type() was
- // using 6% of the cpu profile.
+ // This is a cache of the relevant fields from `FieldDescriptorProto`. This
+ // was added after noticing that FieldDescriptor->type() was using 6% of the
+ // CPU profile.
WireFormatLite::FieldType type;
int number;
bool is_repeated;
@@ -275,16 +233,16 @@ struct FieldInfo {
// A CountCollector counts sizes of repeated and optional fields in a proto.
//
-// Each field is tracked by a single CountCollector instance. The
-// instance manages a single count, which is stored as a pointer (it
-// is intended to be a reference to the `sizes` output which is being
-// filled in). The pointer is passed in at initialization.
+// Each field is tracked by a single CountCollector instance. The instance
+// manages a single count, which is stored as a pointer (it is intended to be a
+// reference to the `sizes` output which is being filled in). The pointer is
+// passed in at initialization.
//
-// Counting is done as a separate pass in order to allocate output tensors
-// all at once. This allows the TensorFlow runtime to optimize allocation
-// for the consumer, while removing the need for copying inside this op.
-// After this pass, the DenseCollector class (below) gathers the data:
-// It is more complex and provides better motivation for the API here.
+// Counting is done as a separate pass in order to allocate output tensors all
+// at once. This allows the TensorFlow runtime to optimize allocation for the
+// consumer, while removing the need for copying inside this op. After this
+// pass, the DenseCollector class (below) gathers the data: it is more complex
+// and provides better motivation for the API here.
class CountCollector {
public:
CountCollector() = delete;
@@ -298,8 +256,8 @@ class CountCollector {
if (*count_ptr_ == 0 || field.is_repeated) {
(*count_ptr_)++;
}
- // We expect a wire type based on the schema field_type, to allow
- // a little more checking.
+ // We expect a wire type based on the schema field_type, to allow a little
+ // more checking.
if (!SkipValue(input, field)) {
return errors::DataLoss("ReadValue: Failed skipping field when counting");
}
@@ -329,8 +287,8 @@ class CountCollector {
return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
}
- // Dispatch to the appropriately typed field reader based on the
- // schema type.
+ // Dispatch to the appropriately typed field reader based on the schema
+ // type.
Status st;
switch (field.type) {
case WireFormatLite::TYPE_DOUBLE:
@@ -409,18 +367,17 @@ class CountCollector {
return input->Skip(length);
}
- // Counts the number of packed varints in an array.
- // The end of a varint is signaled by a value < 0x80,
- // so counting them requires parsing the bytestream.
- // It is the caller's responsibility to ensure that len > 0.
+ // Counts the number of packed varints in an array. The end of a varint is
+ // signaled by a value < 0x80, so counting them requires parsing the
+ // bytestream. It is the caller's responsibility to ensure that len > 0.
Status CountPackedVarint(const uint8* buf, size_t len) {
const uint8* bound = buf + len;
int count;
- // The last byte in a valid encoded varint is guaranteed to have
- // the high bit unset. We rely on this property to prevent
- // ReadVarint64FromArray from going out of bounds, so validate
- // the end of the buf before scanning anything.
+ // The last byte in a valid encoded varint is guaranteed to have the high
+ // bit unset. We rely on this property to prevent ReadVarint64FromArray from
+ // going out of bounds, so validate the end of the buf before scanning
+ // anything.
if (bound[-1] & 0x80) {
return errors::DataLoss("Corrupt packed varint");
}
@@ -439,8 +396,8 @@ class CountCollector {
return Status::OK();
}
- // Counts the number of fixed-size values in a packed field.
- // This can be done without actually parsing anything.
+ // Counts the number of fixed-size values in a packed field. This can be done
+ // without actually parsing anything.
template <typename T>
Status CountPackedFixed(const uint8* unused_buf, size_t len) {
int count = len / sizeof(T);
@@ -452,10 +409,9 @@ class CountCollector {
return Status::OK();
}
- // Skips a single value in the input stream.
- // Dispatches to the appropriately typed field skipper based on the
- // schema type tag.
- // This is not as permissive as just handling the wire type.
+ // Skips a single value in the input stream. Dispatches to the appropriately
+ // typed field skipper based on the schema type tag. This is not as permissive
+ // as just handling the wire type.
static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
uint32 tmp32;
protobuf_uint64 tmp64;
@@ -507,13 +463,13 @@ class CountCollector {
// A DenseCollector accumulates values from a proto into a tensor.
//
-// There is an instance of DenseCollector for each field of each
-// proto. The DenseCollector deserializes the value from the wire
-// directly into the preallocated output Tensor.
+// There is an instance of DenseCollector for each field of each proto. The
+// DenseCollector deserializes the value from the wire directly into the
+// preallocated output Tensor.
//
-// This class is named DenseCollector because in the future there should
-// be a SparseCollector that accumulates field data into sparse tensors if
-// the user requests it.
+// This class is named DenseCollector because in the future there should be a
+// SparseCollector that accumulates field data into sparse tensors if the user
+// requests it.
class DenseCollector {
public:
DenseCollector() = delete;
@@ -578,40 +534,43 @@ class DenseCollector {
}
}
- // Fills in any missing values in the output array with defaults.
- // Dispatches to the appropriately typed field default based on the
- // runtime type tag.
+ // Fills in any missing values in the output array with defaults. Dispatches
+ // to the appropriately typed field default based on the runtime type tag.
Status FillWithDefaults() {
switch (default_value_.dtype) {
+ case DataType::DT_BOOL:
+ return FillDefault<bool>(default_value_.value.v_bool);
case DataType::DT_FLOAT:
return FillDefault<float>(default_value_.value.v_float);
case DataType::DT_DOUBLE:
return FillDefault<double>(default_value_.value.v_double);
- case DataType::DT_INT32:
- return FillDefault<int32>(default_value_.value.v_int32);
- case DataType::DT_UINT8:
- return FillDefault<uint8>(default_value_.value.v_uint8);
case DataType::DT_INT8:
return FillDefault<int8>(default_value_.value.v_int8);
- case DataType::DT_STRING:
- return FillDefault<string>(default_value_.value.v_string);
+ case DataType::DT_INT32:
+ return FillDefault<int32>(default_value_.value.v_int32);
case DataType::DT_INT64:
return FillDefault<int64>(default_value_.value.v_int64);
- case DataType::DT_BOOL:
- return FillDefault<bool>(default_value_.value.v_bool);
+ case DataType::DT_STRING:
+ return FillDefault<string>(default_value_.value.v_string);
+ case DataType::DT_UINT8:
+ return FillDefault<uint8>(default_value_.value.v_uint8);
+ case DataType::DT_UINT32:
+ return FillDefault<uint32>(default_value_.value.v_uint32);
+ case DataType::DT_UINT64:
+ return FillDefault<uint64>(default_value_.value.v_uint64);
default:
// There are many tensorflow dtypes not handled here, but they
// should not come up unless type casting is added to the Op.
// Chaining with tf.cast() should do the right thing until then.
- return errors::DataLoss(
- "Failed filling defaults in unknown tf::DataType");
+ return errors::DataLoss("Failed filling defaults for ",
+ DataTypeString(default_value_.dtype));
}
}
private:
- // Fills empty values in the dense representation with a
- // default value. This uses next_repeat_index_ which counts the number
- // of parsed values for the field.
+ // Fills empty values in the dense representation with a default value. This
+ // uses next_repeat_index_ which counts the number of parsed values for the
+ // field.
template <class T>
Status FillDefault(const T& default_value) {
for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
@@ -622,11 +581,10 @@ class DenseCollector {
int32 next_repeat_index_ = 0;
- // This is a pointer to data_[message_index_].
- // There is no bounds checking at this level: we computed the max
- // repeat size for each field in CountCollector and use the same
- // code to traverse it here, so we are guaranteed not to be called
- // for more items than we have allocated space.
+ // This is a pointer to data_[message_index_]. There is no bounds checking at
+ // this level: we computed the max repeat size for each field in
+ // CountCollector and use the same code to traverse it here, so we are
+ // guaranteed not to be called for more items than we have allocated space.
void* const datap_ = nullptr;
const DefaultValue default_value_;
@@ -665,7 +623,6 @@ class DecodeProtoOp : public OpKernel {
"have the same length"));
// Gather the field descriptors and check that requested output types match.
-
int field_index = 0;
std::vector<const FieldDescriptor*> field_descs;
for (const string& name : field_names) {
@@ -673,18 +630,16 @@ class DecodeProtoOp : public OpKernel {
OP_REQUIRES(context, fd != nullptr,
errors::InvalidArgument("Unknown field: ", name,
" in message type ", message_type));
- OP_REQUIRES(context,
- CheckOutputType(fd->type(), output_types[field_index]),
- // Many TensorFlow types don't have corresponding proto types
- // and the user will get an error if they are requested. It
- // would be nice to allow conversions here, but tf.cast
- // already exists so we don't duplicate the functionality.
- // Known unhandled types:
- // DT_INT16 DT_COMPLEX64 DT_QINT8 DT_QUINT8 DT_QINT32
- // DT_BFLOAT16 DT_QINT16 DT_QUINT16 DT_UINT16
- errors::InvalidArgument("Unexpected output type for ",
- fd->full_name(), ": ", fd->cpp_type(),
- " to ", output_types[field_index]));
+ OP_REQUIRES(
+ context,
+ proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
+ // Many TensorFlow types don't have corresponding proto types and the
+ // user will get an error if they are requested. It would be nice to
+ // allow conversions here, but tf.cast already exists so we don't
+ // duplicate the functionality.
+ errors::InvalidArgument("Unexpected output type for ",
+ fd->full_name(), ": ", fd->cpp_type(), " to ",
+ output_types[field_index]));
field_index++;
field_descs.push_back(fd);
@@ -726,10 +681,9 @@ class DecodeProtoOp : public OpKernel {
errors::InvalidArgument("format must be one of binary or text"));
is_binary_ = format == "binary";
- // Enable the initial protobuf sanitizer, which is much
- // more expensive than the decoder.
- // TODO(nix): Remove this once the fast decoder
- // has passed security review.
+ // Enable the initial protobuf sanitizer, which is much more expensive than
+ // the decoder.
+ // TODO(nix): Remove this once the fast decoder has passed security review.
OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
}
@@ -742,9 +696,9 @@ class DecodeProtoOp : public OpKernel {
int field_count = fields_.size();
- // Save the argument shape for later, then flatten the input
- // Tensor since we are working componentwise. We will restore
- // the same shape in the returned Tensor.
+ // Save the argument shape for later, then flatten the input Tensor since we
+ // are working componentwise. We will restore the same shape in the returned
+ // Tensor.
const TensorShape& shape_prefix = buf_tensor.shape();
TensorShape sizes_shape = shape_prefix;
@@ -752,8 +706,8 @@ class DecodeProtoOp : public OpKernel {
Tensor* sizes_tensor = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
- // This is used to allocate binary bufs if used. It serves only
- // to define memory ownership.
+ // This is used to allocate binary bufs if used. It serves only to define
+ // memory ownership.
std::vector<string> tmp_binary_bufs(message_count);
// These are the actual buffers to use, which may be in tmp_binary_bufs
@@ -768,8 +722,8 @@ class DecodeProtoOp : public OpKernel {
bufs.push_back(buf);
}
} else {
- // We will have to allocate a copy, either to convert from text to
- // binary or to sanitize a binary proto.
+ // We will have to allocate a copy, either to convert from text to binary
+ // or to sanitize a binary proto.
for (int mi = 0; mi < message_count; ++mi) {
ReserializeMessage(ctx, buf_tensor.flat<string>()(mi),
&tmp_binary_bufs[mi]);
@@ -780,16 +734,14 @@ class DecodeProtoOp : public OpKernel {
}
}
- // Walk through all the strings in the input tensor, counting
- // the number of fields in each.
- // We can't allocate our actual output Tensor until we know the
- // maximum repeat count, so we do a first pass through the serialized
- // proto just counting fields.
- // We always allocate at least one value so that optional fields
- // are populated with default values - this avoids a TF
- // conditional when handling the output data.
- // The caller can distinguish between real data and defaults
- // using the repeat count matrix that is returned by decode_proto.
+ // Walk through all the strings in the input tensor, counting the number of
+ // fields in each. We can't allocate our actual output Tensor until we know
+ // the maximum repeat count, so we do a first pass through the serialized
+ // proto just counting fields. We always allocate at least one value so that
+ // optional fields are populated with default values - this avoids a TF
+ // conditional when handling the output data. The caller can distinguish
+ // between real data and defaults using the repeat count matrix that is
+ // returned by decode_proto.
std::vector<int32> max_sizes(field_count, 1);
for (int mi = 0; mi < message_count; ++mi) {
CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
@@ -814,14 +766,12 @@ class DecodeProtoOp : public OpKernel {
// REGISTER_OP(...)
// .Attr("output_types: list(type) >= 0")
// .Output("values: output_types")
- OP_REQUIRES_OK(ctx,
- // ctx->allocate_output(output_indices_[fi] + 1,
- ctx->allocate_output(fields_[fi]->output_index + 1,
- out_shape, &outputs[fi]));
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
+ out_shape, &outputs[fi]));
}
- // Make the second pass through the serialized proto, decoding
- // into preallocated tensors.
+ // Make the second pass through the serialized proto, decoding into
+ // preallocated tensors.
AccumulateFields(ctx, bufs, outputs);
}
@@ -976,6 +926,7 @@ class DecodeProtoOp : public OpKernel {
// Look up the FieldDescriptor for a particular field number.
bool LookupField(int field_number, int* field_index) {
// Look up the FieldDescriptor using linear search.
+ //
// TODO(nix): this could be sped up with binary search, but we are
// already way off the fastpath at this point. If you see a hotspot
// here, somebody is sending you very inefficient protos.
@@ -1010,6 +961,7 @@ class DecodeProtoOp : public OpKernel {
// This takes advantage of the sorted field numbers in most serialized
// protos: it tries the next expected field first rather than doing
// a lookup by field number.
+ //
// TODO(nix): haberman@ suggests a hybrid approach with a lookup table
// for small field numbers and a hash table for larger ones. This would
// be a simpler approach that should offer comparable speed in most
@@ -1029,9 +981,9 @@ class DecodeProtoOp : public OpKernel {
last_good_field_index = field_index;
}
} else {
- // If we see a field that is past the next field we want,
- // it was empty. Look for the one after that.
- // Repeat until we run out of fields that we care about.
+ // If we see a field that is past the next field we want, it was
+ // empty. Look for the one after that. Repeat until we run out of
+ // fields that we care about.
while (field_number >= next_good_field_number) {
if (field_number == next_good_field_number) {
last_good_field_number = field_number;
@@ -1044,10 +996,9 @@ class DecodeProtoOp : public OpKernel {
next_good_field_number =
fields_[last_good_field_index + 1]->number;
} else {
- // Saw something past the last field we care about.
- // Continue parsing the message just in case there
- // are disordered fields later, but any remaining
- // ordered fields will have no effect.
+ // Saw something past the last field we care about. Continue
+ // parsing the message just in case there are disordered fields
+ // later, but any remaining ordered fields will have no effect.
next_good_field_number = INT_MAX;
}
}
@@ -1077,20 +1028,20 @@ class DecodeProtoOp : public OpKernel {
WireFormatLite::WireType wire_type,
CodedInputStream* input, CollectorClass* collector) {
// The wire format library defines the same constants used in
- // descriptor.proto. This static_cast is safe because they
- // are guaranteed to stay in sync.
- // We need the field type from the FieldDescriptor here
- // because the wire format doesn't tell us anything about
- // what happens inside a packed repeated field: there is
- // enough information in the wire format to skip the
- // whole field but not enough to know how to parse what's
- // inside. For that we go to the schema.
+ // descriptor.proto. This static_cast is safe because they are guaranteed to
+ // stay in sync.
+ //
+ // We need the field type from the FieldDescriptor here because the wire
+ // format doesn't tell us anything about what happens inside a packed
+ // repeated field: there is enough information in the wire format to skip
+ // the whole field but not enough to know how to parse what's inside. For
+ // that we go to the schema.
WireFormatLite::WireType schema_wire_type =
WireFormatLite::WireTypeForFieldType(field.type);
- // Handle packed repeated fields. SkipField would skip the
- // whole length-delimited blob without letting us count the
- // values, so we have to scan them ourselves.
+ // Handle packed repeated fields. SkipField would skip the whole
+ // length-delimited blob without letting us count the values, so we have to
+ // scan them ourselves.
if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
// Handle packed repeated primitives.
@@ -1098,11 +1049,7 @@ class DecodeProtoOp : public OpKernel {
if (!input->ReadVarintSizeAsInt(&length)) {
return errors::DataLoss("CollectField: Failed reading packed size");
}
- Status st = collector->ReadPackedValues(input, field, length);
- if (!st.ok()) {
- return st;
- }
- return Status::OK();
+ return collector->ReadPackedValues(input, field, length);
}
// Read ordinary values, including strings, bytes, and messages.
@@ -1118,9 +1065,9 @@ class DecodeProtoOp : public OpKernel {
}
string message_type_;
- // Note that fields are sorted by increasing field number,
- // which is not in general the order given by the user-specified
- // field_names and output_types Op attributes.
+ // Note that fields are sorted by increasing field number, which is not in
+ // general the order given by the user-specified field_names and output_types
+ // Op attributes.
std::vector<std::unique_ptr<const FieldInfo>> fields_;
// Owned_desc_pool_ is null when using descriptor_source=local.
@@ -1131,12 +1078,12 @@ class DecodeProtoOp : public OpKernel {
// True if decoding binary format, false if decoding text format.
bool is_binary_;
- // True if the protos should be sanitized before parsing.
- // Enables the initial protobuf sanitizer, which is much
- // more expensive than the decoder. The flag defaults to true
- // but can be set to false for trusted sources.
- // TODO(nix): flip the default to false when the fast decoder
- // has passed security review.
+ // True if the protos should be sanitized before parsing. Enables the initial
+ // protobuf sanitizer, which is much more expensive than the decoder. The flag
+ // defaults to true but can be set to false for trusted sources.
+ //
+ // TODO(nix): Flip the default to false when the fast decoder has passed
+ // security review.
bool sanitize_;
TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);