diff options
Diffstat (limited to 'tensorflow/core/util/proto/decode.h')
-rw-r--r-- | tensorflow/core/util/proto/decode.h | 298 |
1 files changed, 199 insertions, 99 deletions
diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h index 74634a356a..cbcb203ee7 100644 --- a/tensorflow/core/util/proto/decode.h +++ b/tensorflow/core/util/proto/decode.h @@ -27,6 +27,7 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -103,6 +104,16 @@ template <class TensorType, enum WireFormatLite::FieldType DeclaredType> const uint8* ReadFromArray(const uint8* buf, TensorType* value); template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>( + const uint8* buf, int64* value) { + uint32 temp; + bool unused_ok; // The Counting pass would have failed if this were corrupt. + buf = ReadVarint32FromArray(buf, &unused_ok, &temp); + *value = static_cast<int64>(temp); + return buf; +} + +template <> inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>( const uint8* buf, int32* value) { uint32 temp; @@ -123,8 +134,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>( } template <> -inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>( - const uint8* buf, int64* value) { +inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>( + const uint8* buf, uint64* value) { uint32 temp; bool unused_ok; // The Counting pass would have failed if this were corrupt. buf = ReadVarint32FromArray(buf, &unused_ok, &temp); @@ -133,22 +144,26 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>( } template <> -inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>( - const uint8* buf, int32* value) { - uint32 temp; +inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>( + const uint8* buf, uint32* value) { bool unused_ok; // The Counting pass would have failed if this were corrupt. - buf = ReadVarint32FromArray(buf, &unused_ok, &temp); - *value = WrapUnsignedAsSigned32(temp); - return buf; + return ReadVarint32FromArray(buf, &unused_ok, value); +} + +template <> +inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>( + const uint8* buf, uint64* value) { + bool unused_ok; // The Counting pass would have failed if this were corrupt. + return ReadVarint64FromArray(buf, &unused_ok, value); } template <> -inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>( +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>( const uint8* buf, int64* value) { uint64 temp; bool unused_ok; // The Counting pass would have failed if this were corrupt. buf = ReadVarint64FromArray(buf, &unused_ok, &temp); - *value = static_cast<int64>(temp); + *value = WireFormatLite::ZigZagDecode32(temp); return buf; } @@ -173,8 +188,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>( } template <> -inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>( - const uint8* buf, int64* value) { +inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>( + const uint8* buf, uint64* value) { uint32 temp; buf = WireFormatLite::ReadPrimitiveFromArray<uint32, WireFormatLite::TYPE_FIXED32>( @@ -184,8 +199,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>( } template <> -inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>( - const uint8* buf, int32* value) { +inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>( + const uint8* buf, uint32* value) { uint32 temp; buf = WireFormatLite::ReadPrimitiveFromArray<uint32, WireFormatLite::TYPE_FIXED32>( @@ -195,8 +210,8 @@ inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>( } template <> -inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>( - const uint8* buf, int64* value) { +inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>( + const uint8* buf, uint64* value) { protobuf_uint64 temp; buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64, WireFormatLite::TYPE_FIXED64>( @@ -206,6 +221,17 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>( } template <> +inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>( + const uint8* buf, int64* value) { + int32 temp; + buf = WireFormatLite::ReadPrimitiveFromArray<int32, + WireFormatLite::TYPE_SFIXED32>( + buf, &temp); + *value = temp; + return buf; +} + +template <> inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>( const uint8* buf, int32* value) { return WireFormatLite::ReadPrimitiveFromArray<int32, @@ -233,6 +259,17 @@ inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>( } template <> +inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>( + const uint8* buf, double* value) { + float temp; + buf = + WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>( + buf, &temp); + *value = temp; + return buf; +} + +template <> inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>( const uint8* buf, double* value) { return WireFormatLite::ReadPrimitiveFromArray<double, @@ -334,48 +371,56 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number, inline Status ReadValue(CodedInputStream* input, WireFormatLite::FieldType field_type, int field_number, DataType dtype, int index, void* datap) { - // Dispatch to the appropriately typed field reader based on the - // schema type. + // Dispatch to the appropriately typed field reader based on the schema type. switch (field_type) { case WireFormatLite::TYPE_DOUBLE: return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>( input, index, datap); case WireFormatLite::TYPE_FLOAT: - if (dtype == DataType::DT_FLOAT) { - return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>( - input, index, datap); - } - if (dtype == DataType::DT_DOUBLE) { - return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>( - input, index, datap); + switch (dtype) { + case DataType::DT_DOUBLE: + return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>( + input, index, datap); + case DataType::DT_FLOAT: + return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>( + input, index, datap); + default: + return errors::DataLoss("Failed reading TYPE_FLOAT for ", + DataTypeString(dtype)); } - // Any case that reaches this point should have triggered an error - // already. - return errors::DataLoss("Failed reading TYPE_FLOAT"); case WireFormatLite::TYPE_INT64: return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>( input, index, datap); case WireFormatLite::TYPE_UINT64: - return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>( - input, index, datap); + return ReadPrimitive<protobuf_uint64, uint64, + WireFormatLite::TYPE_UINT64>(input, index, datap); case WireFormatLite::TYPE_INT32: - return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>( - input, index, datap); + switch (dtype) { + case DataType::DT_INT64: + return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>( + input, index, datap); + case DataType::DT_INT32: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>( + input, index, datap); + default: + return errors::DataLoss("Failed reading TYPE_INT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_FIXED64: - return ReadPrimitive<protobuf_uint64, int64, + return ReadPrimitive<protobuf_uint64, uint64, WireFormatLite::TYPE_FIXED64>(input, index, datap); case WireFormatLite::TYPE_FIXED32: - if (dtype == DataType::DT_INT64) { - return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>( - input, index, datap); - } - if (dtype == DataType::DT_INT32) { - return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>( - input, index, datap); + switch (dtype) { + case DataType::DT_UINT64: + return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>( + input, index, datap); + case DataType::DT_UINT32: + return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>( + input, index, datap); + default: + return errors::DataLoss("Failed reading TYPE_FIXED32 for ", + DataTypeString(dtype)); } - // Any case that reaches this point should have triggered an error - // already. - return errors::DataLoss("Failed reading TYPE_FIXED32"); case WireFormatLite::TYPE_BOOL: return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index, datap); @@ -388,29 +433,47 @@ inline Status ReadValue(CodedInputStream* input, case WireFormatLite::TYPE_BYTES: return ReadBytes(input, index, datap); case WireFormatLite::TYPE_UINT32: - if (dtype == DataType::DT_INT64) { - return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>( - input, index, datap); + switch (dtype) { + case DataType::DT_UINT64: + return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>( + input, index, datap); + case DataType::DT_UINT32: + return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>( + input, index, datap); + default: + return errors::DataLoss("Failed reading TYPE_UINT32 for ", + DataTypeString(dtype)); } - if (dtype == DataType::DT_INT32) { - return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>( - input, index, datap); - } - // Any case that reaches this point should have triggered an error - // already. - return errors::DataLoss("Failed reading TYPE_UINT32"); case WireFormatLite::TYPE_ENUM: return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>( input, index, datap); case WireFormatLite::TYPE_SFIXED32: - return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>( - input, index, datap); + switch (dtype) { + case DataType::DT_INT64: + return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>( + input, index, datap); + case DataType::DT_INT32: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>( + input, index, datap); + default: + return errors::DataLoss("Failed reading TYPE_SFIXED32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_SFIXED64: return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SFIXED64>(input, index, datap); case WireFormatLite::TYPE_SINT32: - return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>( - input, index, datap); + switch (dtype) { + case DataType::DT_INT64: + return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>( + input, index, datap); + case DataType::DT_INT32: + return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>( + input, index, datap); + default: + return errors::DataLoss("Failed reading TYPE_SINT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_SINT64: return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>( input, index, datap); @@ -425,47 +488,66 @@ inline Status ReadPackedFromArray(const void* buf, size_t buf_size, const WireFormatLite::FieldType field_type, const int field_number, const DataType dtype, const int stride, int* index, void* data) { - // Dispatch to the appropriately typed field reader based on the - // schema type. + // Dispatch to the appropriately typed field reader based on the schema type. switch (field_type) { case WireFormatLite::TYPE_DOUBLE: *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>( buf, buf_size, *index, stride, data); return Status::OK(); case WireFormatLite::TYPE_FLOAT: - *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>( - buf, buf_size, *index, stride, data); - return Status::OK(); + switch (dtype) { + case DataType::DT_DOUBLE: + *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case DataType::DT_FLOAT: + *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>( + buf, buf_size, *index, stride, data); + return Status::OK(); + default: + return errors::DataLoss("Failed reading TYPE_FLOAT for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_INT64: *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>( buf, buf_size, *index, stride, data); return Status::OK(); case WireFormatLite::TYPE_UINT64: - *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>( + *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>( buf, buf_size, *index, stride, data); return Status::OK(); case WireFormatLite::TYPE_INT32: - *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>( - buf, buf_size, *index, stride, data); - return Status::OK(); + switch (dtype) { + case DataType::DT_INT64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case DataType::DT_INT32: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + default: + return errors::DataLoss("Failed reading TYPE_INT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_FIXED64: - *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>( + *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>( buf, buf_size, *index, stride, data); return Status::OK(); case WireFormatLite::TYPE_FIXED32: - if (dtype == DataType::DT_INT64) { - *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>( - buf, buf_size, *index, stride, data); - return Status::OK(); - } - if (dtype == DataType::DT_INT32) { - *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>( - buf, buf_size, *index, stride, data); - return Status::OK(); + switch (dtype) { + case DataType::DT_UINT64: + *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case DataType::DT_UINT32: + *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + default: + return errors::DataLoss("Failed reading TYPE_FIXED32 for ", + DataTypeString(dtype)); } - // Any case that reaches this point should have triggered an error - // already. - return errors::DataLoss("Failed reading TYPE_FIXED32"); case WireFormatLite::TYPE_BOOL: *index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>( buf, buf_size, *index, stride, data); @@ -476,38 +558,56 @@ inline Status ReadPackedFromArray(const void* buf, size_t buf_size, case WireFormatLite::TYPE_BYTES: return errors::DataLoss("Non-primitive type encountered as packed"); case WireFormatLite::TYPE_UINT32: - if (dtype == DataType::DT_INT64) { - *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>( - buf, buf_size, *index, stride, data); - return Status::OK(); + switch (dtype) { + case DataType::DT_UINT64: + *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case DataType::DT_UINT32: + *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + default: + return errors::DataLoss("Failed reading TYPE_UINT32 for ", + DataTypeString(dtype)); } - if (dtype == DataType::DT_INT32) { - *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>( - buf, buf_size, *index, stride, data); - return Status::OK(); - } - // Any case that reaches this point should have triggered an error - // already. - return errors::DataLoss("Failed reading TYPE_UINT32"); case WireFormatLite::TYPE_ENUM: *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>( buf, buf_size, *index, stride, data); return Status::OK(); case WireFormatLite::TYPE_SFIXED32: - *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>( - buf, buf_size, *index, stride, data); - return Status::OK(); - + switch (dtype) { + case DataType::DT_INT64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case DataType::DT_INT32: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + default: + return errors::DataLoss("Failed reading TYPE_INT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_SFIXED64: *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>( buf, buf_size, *index, stride, data); return Status::OK(); case WireFormatLite::TYPE_SINT32: - *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>( - buf, buf_size, *index, stride, data); - return Status::OK(); - + switch (dtype) { + case DataType::DT_INT64: + *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + case DataType::DT_INT32: + *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>( + buf, buf_size, *index, stride, data); + return Status::OK(); + default: + return errors::DataLoss("Failed reading TYPE_SINT32 for ", + DataTypeString(dtype)); + } case WireFormatLite::TYPE_SINT64: *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>( buf, buf_size, *index, stride, data); |