diff options
Diffstat (limited to 'tensorflow/core/kernels/encode_proto_op.cc')
-rw-r--r-- | tensorflow/core/kernels/encode_proto_op.cc | 284 |
1 files changed, 157 insertions, 127 deletions
diff --git a/tensorflow/core/kernels/encode_proto_op.cc b/tensorflow/core/kernels/encode_proto_op.cc index 3b02ae52a2..4a0c1943e5 100644 --- a/tensorflow/core/kernels/encode_proto_op.cc +++ b/tensorflow/core/kernels/encode_proto_op.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/proto/descriptors.h" +#include "tensorflow/core/util/proto/proto_utils.h" namespace tensorflow { namespace { @@ -42,9 +43,9 @@ using ::tensorflow::protobuf::internal::WireFormatLite; using ::tensorflow::protobuf::io::CodedOutputStream; using ::tensorflow::protobuf::io::StringOutputStream; -// Computes the total serialized size for a packed repeated field. -// For fixed-size types this can just multiply, but for variable-sized -// types it has to iterate through the values in the tensor. +// Computes the total serialized size for a packed repeated field. For +// fixed-size types this can just multiply, but for variable-sized types it has +// to iterate through the values in the tensor. template <WireFormatLite::FieldType FieldType, typename TensorT> size_t TotalPackedSize(const Tensor& input, int message_index, int size); @@ -83,11 +84,11 @@ size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64>(const Tensor& input, } template <> -size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input, - int message_index, - int size) { +size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, uint64>(const Tensor& input, + int message_index, + int size) { size_t data_size = 0; - auto input_t = input.flat_inner_dims<int64>(); + auto input_t = input.flat_inner_dims<uint64>(); for (int64 i = 0; i < size; i++) { data_size += WireFormatLite::UInt64Size( input_t(static_cast<int64>(message_index), i)); @@ -96,6 +97,19 @@ size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, int64>(const Tensor& input, } template <> +size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int64>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int64>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::Int32Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input, int message_index, int size) { @@ -109,23 +123,20 @@ size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input, } template <> -size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, int64>(const Tensor& input, - int message_index, - int size) { +size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, uint64>( + const Tensor& input, int message_index, int size) { return size * WireFormatLite::kFixed64Size; } template <> -size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int64>(const Tensor& input, - int message_index, - int size) { +size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint64>( + const Tensor& input, int message_index, int size) { return size * WireFormatLite::kFixed32Size; } template <> -size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, int32>(const Tensor& input, - int message_index, - int size) { +size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint32>( + const Tensor& input, int message_index, int size) { return size * WireFormatLite::kFixed32Size; } @@ -137,11 +148,11 @@ size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input, } template <> -size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input, - int message_index, - int size) { +size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint64>(const Tensor& input, + int message_index, + int size) { size_t data_size = 0; - auto input_t = input.flat_inner_dims<int64>(); + auto input_t = input.flat_inner_dims<uint64>(); for (int64 i = 0; i < size; i++) { data_size += WireFormatLite::UInt32Size( input_t(static_cast<int64>(message_index), i)); @@ -150,11 +161,11 @@ size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int64>(const Tensor& input, } template <> -size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, int32>(const Tensor& input, - int message_index, - int size) { +size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint32>(const Tensor& input, + int message_index, + int size) { size_t data_size = 0; - auto input_t = input.flat_inner_dims<int32>(); + auto input_t = input.flat_inner_dims<uint32>(); for (int64 i = 0; i < size; i++) { data_size += WireFormatLite::UInt32Size( input_t(static_cast<int64>(message_index), i)); @@ -182,6 +193,12 @@ size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>( } template <> +size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int64>( + const Tensor& input, int message_index, int size) { + return size * WireFormatLite::kSFixed32Size; +} + +template <> size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64>( const Tensor& input, int message_index, int size) { return size * WireFormatLite::kSFixed64Size; @@ -201,6 +218,19 @@ size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input, } template <> +size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int64>(const Tensor& input, + int message_index, + int size) { + size_t data_size = 0; + auto input_t = input.flat_inner_dims<int64>(); + for (int64 i = 0; i < size; i++) { + data_size += WireFormatLite::SInt32Size( + input_t(static_cast<int64>(message_index), i)); + } + return data_size; +} + +template <> size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input, int message_index, int size) { @@ -213,14 +243,13 @@ size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64>(const Tensor& input, return data_size; } -// Writes a possibly repeated primitive field. -// TensorFlow does not have unsigned types, so we decode them to signed and -// encode them back to unsigned. +// Writes a possibly repeated primitive field. TensorFlow does not have unsigned +// types, so we decode them to signed and encode them back to unsigned. template <typename TensorT, typename ProtoT, WireFormatLite::FieldType FieldType, void Writer(ProtoT, CodedOutputStream*)> -void WriteField(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { +Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { auto wire_type = WireFormatLite::WireTypeForFieldType( WireFormatLite::FieldType(field_desc.type())); @@ -250,12 +279,14 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input, Writer(value, output); } } + return Status::OK(); } // Writes a possibly repeated string, bytes, or message field. template <typename T, void Writer(int, const T&, CodedOutputStream*)> -void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { +Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, + CodedOutputStream* output) { auto input_t = input.flat_inner_dims<T>(); for (int64 i = 0; i < size; i++) { const T& value = input_t(static_cast<int64>(message_index), i); @@ -264,14 +295,14 @@ void WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, // small speedup. Writer(field_desc.number(), value, output); } + return Status::OK(); } -// Writes a group field. -// Groups are treated like submessages, but tag-delimited -// instead of length-delimited. WireFormatLite handles this -// differently so we code it ourselves. -void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { +// Writes a group field. Groups are treated like submessages, but tag-delimited +// instead of length-delimited. WireFormatLite handles this differently so we +// code it ourselves. +Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { auto input_t = input.flat_inner_dims<string>(); for (int64 i = 0; i < size; i++) { const string& value = input_t(static_cast<int64>(message_index), i); @@ -282,16 +313,16 @@ void WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, WireFormatLite::WriteTag(field_desc.number(), WireFormatLite::WIRETYPE_END_GROUP, output); } + return Status::OK(); } -// Writes a (possibly repeated) field into an output stream. -// It is the caller's responsibility to ensure that the type of -// the input tensor is compatible with the type of the proto -// field descriptor, and that (message_index, size-1) is within -// bounds. -void WriteField(const FieldDescriptor& field_desc, const Tensor& input, - int message_index, int size, CodedOutputStream* output) { - DataType tf_type = input.dtype(); +// Writes a (possibly repeated) field into an output stream. It is the caller's +// responsibility to ensure that the type of the input tensor is compatible with +// the type of the proto field descriptor, and that (message_index, size-1) is +// within bounds. +Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, + int message_index, int size, CodedOutputStream* output) { + DataType dtype = input.dtype(); switch (field_desc.type()) { case WireFormatLite::TYPE_DOUBLE: @@ -299,7 +330,7 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input, WireFormatLite::WriteDoubleNoTag>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_FLOAT: - switch (tf_type) { + switch (dtype) { case DataType::DT_FLOAT: return WriteField<float, float, WireFormatLite::TYPE_FLOAT, WireFormatLite::WriteFloatNoTag>( @@ -309,36 +340,48 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input, WireFormatLite::WriteFloatNoTag>( field_desc, input, message_index, size, output); default: - return; + return errors::DataLoss("Failed writing TYPE_FLOAT for ", + DataTypeString(dtype)); } case WireFormatLite::TYPE_INT64: return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_INT64, WireFormatLite::WriteInt64NoTag>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_UINT64: - return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_UINT64, + return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_UINT64, WireFormatLite::WriteUInt64NoTag>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_INT32: - return WriteField<int32, int32, WireFormatLite::TYPE_INT32, - WireFormatLite::WriteInt32NoTag>( - field_desc, input, message_index, size, output); + switch (dtype) { + case DataType::DT_INT64: + return WriteField<int64, int32, WireFormatLite::TYPE_INT32, + WireFormatLite::WriteInt32NoTag>( + field_desc, input, message_index, size, output); + case DataType::DT_INT32: + return WriteField<int32, int32, WireFormatLite::TYPE_INT32, + WireFormatLite::WriteInt32NoTag>( + field_desc, input, message_index, size, output); + default: + return errors::DataLoss("Failed writing TYPE_INT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_FIXED64: - return WriteField<int64, protobuf_uint64, WireFormatLite::TYPE_FIXED64, + return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_FIXED64, WireFormatLite::WriteFixed64NoTag>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_FIXED32: - switch (tf_type) { - case DataType::DT_INT64: - return WriteField<int64, uint32, WireFormatLite::TYPE_FIXED32, + switch (dtype) { + case DataType::DT_UINT64: + return WriteField<uint64, uint32, WireFormatLite::TYPE_FIXED32, WireFormatLite::WriteFixed32NoTag>( field_desc, input, message_index, size, output); - case DataType::DT_INT32: - return WriteField<int32, uint32, WireFormatLite::TYPE_FIXED32, + case DataType::DT_UINT32: + return WriteField<uint32, uint32, WireFormatLite::TYPE_FIXED32, WireFormatLite::WriteFixed32NoTag>( field_desc, input, message_index, size, output); default: - return; + return errors::DataLoss("Failed writing TYPE_FIXED32 for ", + DataTypeString(dtype)); } case WireFormatLite::TYPE_BOOL: return WriteField<bool, bool, WireFormatLite::TYPE_BOOL, @@ -356,34 +399,55 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input, return WriteVarLenField<string, WireFormatLite::WriteBytes>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_UINT32: - switch (tf_type) { - case DataType::DT_INT64: - return WriteField<int64, uint32, WireFormatLite::TYPE_UINT32, + switch (dtype) { + case DataType::DT_UINT64: + return WriteField<uint64, uint32, WireFormatLite::TYPE_UINT32, WireFormatLite::WriteUInt32NoTag>( field_desc, input, message_index, size, output); - case DataType::DT_INT32: - return WriteField<int32, uint32, WireFormatLite::TYPE_UINT32, + case DataType::DT_UINT32: + return WriteField<uint32, uint32, WireFormatLite::TYPE_UINT32, WireFormatLite::WriteUInt32NoTag>( field_desc, input, message_index, size, output); default: - return; + return errors::DataLoss("Failed writing TYPE_UINT32 for ", + DataTypeString(dtype)); } case WireFormatLite::TYPE_ENUM: return WriteField<int32, int32, WireFormatLite::TYPE_ENUM, WireFormatLite::WriteEnumNoTag>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_SFIXED32: - return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32, - WireFormatLite::WriteSFixed32NoTag>( - field_desc, input, message_index, size, output); + switch (dtype) { + case DataType::DT_INT64: + return WriteField<int64, int32, WireFormatLite::TYPE_SFIXED32, + WireFormatLite::WriteSFixed32NoTag>( + field_desc, input, message_index, size, output); + case DataType::DT_INT32: + return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32, + WireFormatLite::WriteSFixed32NoTag>( + field_desc, input, message_index, size, output); + default: + return errors::DataLoss("Failed writing TYPE_SFIXED32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_SFIXED64: return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SFIXED64, WireFormatLite::WriteSFixed64NoTag>( field_desc, input, message_index, size, output); case WireFormatLite::TYPE_SINT32: - return WriteField<int32, int32, WireFormatLite::TYPE_SINT32, - WireFormatLite::WriteSInt32NoTag>( - field_desc, input, message_index, size, output); + switch (dtype) { + case DataType::DT_INT64: + return WriteField<int64, int32, WireFormatLite::TYPE_SINT32, + WireFormatLite::WriteSInt32NoTag>( + field_desc, input, message_index, size, output); + case DataType::DT_INT32: + return WriteField<int32, int32, WireFormatLite::TYPE_SINT32, + WireFormatLite::WriteSInt32NoTag>( + field_desc, input, message_index, size, output); + default: + return errors::DataLoss("Failed writing TYPE_SINT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_SINT64: return WriteField<int64, protobuf_int64, WireFormatLite::TYPE_SINT64, WireFormatLite::WriteSInt64NoTag>( @@ -392,42 +456,6 @@ void WriteField(const FieldDescriptor& field_desc, const Tensor& input, } } -// Checks that a Protobuf field is compatible with a TensorFlow datatype. -// This is separated from WriteField to lift it out of the inner loop. -bool IsCompatibleType(const FieldDescriptor& field_desc, DataType tf_type) { - switch (field_desc.type()) { - case WireFormatLite::TYPE_DOUBLE: - return tf_type == DataType::DT_DOUBLE; - case WireFormatLite::TYPE_FLOAT: - return tf_type == DataType::DT_FLOAT || tf_type == DataType::DT_DOUBLE; - case WireFormatLite::TYPE_INT64: - case WireFormatLite::TYPE_SFIXED64: - case WireFormatLite::TYPE_SINT64: - return tf_type == DataType::DT_INT64; - case WireFormatLite::TYPE_UINT64: - return tf_type == DataType::DT_INT64; - case WireFormatLite::TYPE_INT32: - case WireFormatLite::TYPE_ENUM: - case WireFormatLite::TYPE_SFIXED32: - case WireFormatLite::TYPE_SINT32: - return tf_type == DataType::DT_INT32; - case WireFormatLite::TYPE_FIXED64: - return tf_type == DataType::DT_INT64; - case WireFormatLite::TYPE_FIXED32: - case WireFormatLite::TYPE_UINT32: - return tf_type == DataType::DT_INT64 || tf_type == DataType::DT_INT32; - case WireFormatLite::TYPE_BOOL: - return tf_type == DataType::DT_BOOL; - case WireFormatLite::TYPE_STRING: - case WireFormatLite::TYPE_GROUP: - case WireFormatLite::TYPE_MESSAGE: - case WireFormatLite::TYPE_BYTES: - return tf_type == DataType::DT_STRING; - // default: intentionally omitted in order to enable static checking. - } - return false; -} - class EncodeProtoOp : public OpKernel { public: explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) { @@ -475,14 +503,14 @@ class EncodeProtoOp : public OpKernel { }); } - void Compute(OpKernelContext* cx) override { + void Compute(OpKernelContext* ctx) override { const Tensor* sizes_tensor; - OP_REQUIRES_OK(cx, cx->input("sizes", &sizes_tensor)); + OP_REQUIRES_OK(ctx, ctx->input("sizes", &sizes_tensor)); OpInputList values; - OP_REQUIRES_OK(cx, cx->input_list("values", &values)); + OP_REQUIRES_OK(ctx, ctx->input_list("values", &values)); - OP_REQUIRES(cx, field_descs_.size() == values.size(), + OP_REQUIRES(ctx, field_descs_.size() == values.size(), errors::InvalidArgument( "Length of inputs list must match field_names")); @@ -493,12 +521,14 @@ class EncodeProtoOp : public OpKernel { const Tensor& v = values[i]; // The type of each value tensor must match the corresponding field. - OP_REQUIRES(cx, IsCompatibleType(*field_descs_[i], v.dtype()), - errors::InvalidArgument( - "Incompatible type for field " + field_names_[i] + - ". Saw dtype: ", - DataTypeString(v.dtype()), - " but field type is: ", field_descs_[i]->type_name())); + OP_REQUIRES( + ctx, + proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()), + errors::InvalidArgument( + "Incompatible type for field " + field_names_[i] + + ". Saw dtype: ", + DataTypeString(v.dtype()), + " but field type is: ", field_descs_[i]->type_name())); // All value tensors must have the same shape prefix (i.e. batch size). TensorShape shape_prefix = v.shape(); @@ -507,14 +537,14 @@ class EncodeProtoOp : public OpKernel { // Do some initialization on the first input value. The rest will // have to match this one. if (i == 0) { - OP_REQUIRES(cx, v.dims() >= 1, + OP_REQUIRES(ctx, v.dims() >= 1, errors::InvalidArgument( "Expected value to be at least a vector, saw shape: ", v.shape().DebugString())); common_prefix = shape_prefix; message_count = common_prefix.num_elements(); } else { - OP_REQUIRES(cx, shape_prefix == common_prefix, + OP_REQUIRES(ctx, shape_prefix == common_prefix, errors::InvalidArgument( "Values must match up to the last dimension")); } @@ -523,7 +553,7 @@ class EncodeProtoOp : public OpKernel { TensorShape expected_sizes_shape = common_prefix; expected_sizes_shape.AddDim(field_descs_.size()); - OP_REQUIRES(cx, sizes_tensor->shape() == expected_sizes_shape, + OP_REQUIRES(ctx, sizes_tensor->shape() == expected_sizes_shape, errors::InvalidArgument( "sizes should be batch_size + [len(field_names)]. Saw: ", sizes_tensor->shape().DebugString(), @@ -536,12 +566,11 @@ class EncodeProtoOp : public OpKernel { int max_size = v.dim_size(v.dims() - 1); // The last dimension of a value tensor must be greater than the - // corresponding - // size in the sizes tensor. + // corresponding size in the sizes tensor. for (int message_index = 0; message_index < message_count; message_index++) { OP_REQUIRES( - cx, sizes(message_index, i) <= max_size, + ctx, sizes(message_index, i) <= max_size, errors::InvalidArgument( "Size to write must not be larger than value tensor; but saw: ", sizes(message_index, i), " > ", max_size, " at message ", @@ -551,13 +580,13 @@ class EncodeProtoOp : public OpKernel { // This pointer is owned by the context. Tensor* output_tensor; - OP_REQUIRES_OK(cx, cx->allocate_output(0, common_prefix, &output_tensor)); + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor)); auto bufs = output_tensor->flat<string>(); for (int message_index = 0; message_index < message_count; message_index++) { // TODO(nix): possibly optimize allocation here by calling - // bufs(message_index).reserve(DEFAULT_BUF_SIZE); + // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`. StringOutputStream output_string(&bufs(message_index)); CodedOutputStream out(&output_string); // Write fields in ascending field_number order. @@ -566,7 +595,8 @@ class EncodeProtoOp : public OpKernel { const Tensor& v = values[i]; int size = sizes(message_index, i); if (!size) continue; - WriteField(field_desc, v, message_index, size, &out); + OP_REQUIRES_OK(ctx, + WriteField(field_desc, v, message_index, size, &out)); } } } @@ -578,8 +608,8 @@ class EncodeProtoOp : public OpKernel { // Owned_desc_pool_ is null when using descriptor_source=local. std::unique_ptr<DescriptorPool> owned_desc_pool_; - // Contains indices into field_names_, sorted by field number since - // that's the order of writing. + // Contains indices into field_names_, sorted by field number since that's the + // order of writing. std::vector<int> sorted_field_index_; TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp); |