aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/util/proto/decode.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/util/proto/decode.h')
-rw-r--r--tensorflow/core/util/proto/decode.h298
1 files changed, 199 insertions, 99 deletions
diff --git a/tensorflow/core/util/proto/decode.h b/tensorflow/core/util/proto/decode.h
index 74634a356a..cbcb203ee7 100644
--- a/tensorflow/core/util/proto/decode.h
+++ b/tensorflow/core/util/proto/decode.h
@@ -27,6 +27,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_PROTO_DECODE_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -103,6 +104,16 @@ template <class TensorType, enum WireFormatLite::FieldType DeclaredType>
const uint8* ReadFromArray(const uint8* buf, TensorType* value);
template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT32>(
+ const uint8* buf, int64* value) {
+ uint32 temp;
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
+ *value = static_cast<int64>(temp);
+ return buf;
+}
+
+template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_INT32>(
const uint8* buf, int32* value) {
uint32 temp;
@@ -123,8 +134,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_INT64>(
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
- const uint8* buf, int64* value) {
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, uint64* value) {
uint32 temp;
bool unused_ok; // The Counting pass would have failed if this were corrupt.
buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
@@ -133,22 +144,26 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT32>(
}
template <>
-inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_UINT32>(
- const uint8* buf, int32* value) {
- uint32 temp;
+inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_UINT32>(
+ const uint8* buf, uint32* value) {
bool unused_ok; // The Counting pass would have failed if this were corrupt.
- buf = ReadVarint32FromArray(buf, &unused_ok, &temp);
- *value = WrapUnsignedAsSigned32(temp);
- return buf;
+ return ReadVarint32FromArray(buf, &unused_ok, value);
+}
+
+template <>
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_UINT64>(
+ const uint8* buf, uint64* value) {
+ bool unused_ok; // The Counting pass would have failed if this were corrupt.
+ return ReadVarint64FromArray(buf, &unused_ok, value);
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_UINT64>(
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT32>(
const uint8* buf, int64* value) {
uint64 temp;
bool unused_ok; // The Counting pass would have failed if this were corrupt.
buf = ReadVarint64FromArray(buf, &unused_ok, &temp);
- *value = static_cast<int64>(temp);
+ *value = WireFormatLite::ZigZagDecode32(temp);
return buf;
}
@@ -173,8 +188,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SINT64>(
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
- const uint8* buf, int64* value) {
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, uint64* value) {
uint32 temp;
buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
WireFormatLite::TYPE_FIXED32>(
@@ -184,8 +199,8 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED32>(
}
template <>
-inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
- const uint8* buf, int32* value) {
+inline const uint8* ReadFromArray<uint32, WireFormatLite::TYPE_FIXED32>(
+ const uint8* buf, uint32* value) {
uint32 temp;
buf = WireFormatLite::ReadPrimitiveFromArray<uint32,
WireFormatLite::TYPE_FIXED32>(
@@ -195,8 +210,8 @@ inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_FIXED32>(
}
template <>
-inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
- const uint8* buf, int64* value) {
+inline const uint8* ReadFromArray<uint64, WireFormatLite::TYPE_FIXED64>(
+ const uint8* buf, uint64* value) {
protobuf_uint64 temp;
buf = WireFormatLite::ReadPrimitiveFromArray<protobuf_uint64,
WireFormatLite::TYPE_FIXED64>(
@@ -206,6 +221,17 @@ inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_FIXED64>(
}
template <>
+inline const uint8* ReadFromArray<int64, WireFormatLite::TYPE_SFIXED32>(
+ const uint8* buf, int64* value) {
+ int32 temp;
+ buf = WireFormatLite::ReadPrimitiveFromArray<int32,
+ WireFormatLite::TYPE_SFIXED32>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
inline const uint8* ReadFromArray<int32, WireFormatLite::TYPE_SFIXED32>(
const uint8* buf, int32* value) {
return WireFormatLite::ReadPrimitiveFromArray<int32,
@@ -233,6 +259,17 @@ inline const uint8* ReadFromArray<float, WireFormatLite::TYPE_FLOAT>(
}
template <>
+inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_FLOAT>(
+ const uint8* buf, double* value) {
+ float temp;
+ buf =
+ WireFormatLite::ReadPrimitiveFromArray<float, WireFormatLite::TYPE_FLOAT>(
+ buf, &temp);
+ *value = temp;
+ return buf;
+}
+
+template <>
inline const uint8* ReadFromArray<double, WireFormatLite::TYPE_DOUBLE>(
const uint8* buf, double* value) {
return WireFormatLite::ReadPrimitiveFromArray<double,
@@ -334,48 +371,56 @@ inline Status ReadGroupBytes(CodedInputStream* input, int field_number,
inline Status ReadValue(CodedInputStream* input,
WireFormatLite::FieldType field_type, int field_number,
DataType dtype, int index, void* datap) {
- // Dispatch to the appropriately typed field reader based on the
- // schema type.
+ // Dispatch to the appropriately typed field reader based on the schema type.
switch (field_type) {
case WireFormatLite::TYPE_DOUBLE:
return ReadPrimitive<double, double, WireFormatLite::TYPE_DOUBLE>(
input, index, datap);
case WireFormatLite::TYPE_FLOAT:
- if (dtype == DataType::DT_FLOAT) {
- return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
- input, index, datap);
- }
- if (dtype == DataType::DT_DOUBLE) {
- return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_DOUBLE:
+ return ReadPrimitive<float, double, WireFormatLite::TYPE_FLOAT>(
+ input, index, datap);
+ case DataType::DT_FLOAT:
+ return ReadPrimitive<float, float, WireFormatLite::TYPE_FLOAT>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_FLOAT for ",
+ DataTypeString(dtype));
}
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_FLOAT");
case WireFormatLite::TYPE_INT64:
return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_INT64>(
input, index, datap);
case WireFormatLite::TYPE_UINT64:
- return ReadPrimitive<protobuf_uint64, int64, WireFormatLite::TYPE_UINT64>(
- input, index, datap);
+ return ReadPrimitive<protobuf_uint64, uint64,
+ WireFormatLite::TYPE_UINT64>(input, index, datap);
case WireFormatLite::TYPE_INT32:
- return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return ReadPrimitive<int32, int64, WireFormatLite::TYPE_INT32>(
+ input, index, datap);
+ case DataType::DT_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_INT32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_FIXED64:
- return ReadPrimitive<protobuf_uint64, int64,
+ return ReadPrimitive<protobuf_uint64, uint64,
WireFormatLite::TYPE_FIXED64>(input, index, datap);
case WireFormatLite::TYPE_FIXED32:
- if (dtype == DataType::DT_INT64) {
- return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_FIXED32>(
- input, index, datap);
- }
- if (dtype == DataType::DT_INT32) {
- return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_FIXED32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_FIXED32>(
+ input, index, datap);
+ case DataType::DT_UINT32:
+ return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_FIXED32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
+ DataTypeString(dtype));
}
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_FIXED32");
case WireFormatLite::TYPE_BOOL:
return ReadPrimitive<bool, bool, WireFormatLite::TYPE_BOOL>(input, index,
datap);
@@ -388,29 +433,47 @@ inline Status ReadValue(CodedInputStream* input,
case WireFormatLite::TYPE_BYTES:
return ReadBytes(input, index, datap);
case WireFormatLite::TYPE_UINT32:
- if (dtype == DataType::DT_INT64) {
- return ReadPrimitive<uint32, int64, WireFormatLite::TYPE_UINT32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ return ReadPrimitive<uint32, uint64, WireFormatLite::TYPE_UINT32>(
+ input, index, datap);
+ case DataType::DT_UINT32:
+ return ReadPrimitive<uint32, uint32, WireFormatLite::TYPE_UINT32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_UINT32 for ",
+ DataTypeString(dtype));
}
- if (dtype == DataType::DT_INT32) {
- return ReadPrimitive<uint32, int32, WireFormatLite::TYPE_UINT32>(
- input, index, datap);
- }
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_UINT32");
case WireFormatLite::TYPE_ENUM:
return ReadPrimitive<int32, int32, WireFormatLite::TYPE_ENUM>(
input, index, datap);
case WireFormatLite::TYPE_SFIXED32:
- return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SFIXED32>(
+ input, index, datap);
+ case DataType::DT_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SFIXED32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_SFIXED32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SFIXED64:
return ReadPrimitive<protobuf_int64, int64,
WireFormatLite::TYPE_SFIXED64>(input, index, datap);
case WireFormatLite::TYPE_SINT32:
- return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
- input, index, datap);
+ switch (dtype) {
+ case DataType::DT_INT64:
+ return ReadPrimitive<int32, int64, WireFormatLite::TYPE_SINT32>(
+ input, index, datap);
+ case DataType::DT_INT32:
+ return ReadPrimitive<int32, int32, WireFormatLite::TYPE_SINT32>(
+ input, index, datap);
+ default:
+ return errors::DataLoss("Failed reading TYPE_SINT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SINT64:
return ReadPrimitive<protobuf_int64, int64, WireFormatLite::TYPE_SINT64>(
input, index, datap);
@@ -425,47 +488,66 @@ inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
const WireFormatLite::FieldType field_type,
const int field_number, const DataType dtype,
const int stride, int* index, void* data) {
- // Dispatch to the appropriately typed field reader based on the
- // schema type.
+ // Dispatch to the appropriately typed field reader based on the schema type.
switch (field_type) {
case WireFormatLite::TYPE_DOUBLE:
*index += ReadPackedPrimitives<double, WireFormatLite::TYPE_DOUBLE>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_FLOAT:
- *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_DOUBLE:
+ *index += ReadPackedPrimitives<double, WireFormatLite::TYPE_FLOAT>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_FLOAT:
+ *index += ReadPackedPrimitives<float, WireFormatLite::TYPE_FLOAT>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_FLOAT for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_INT64:
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_UINT64:
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT64>(
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_INT32:
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_INT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_INT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_FIXED64:
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED64>(
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_FIXED32:
- if (dtype == DataType::DT_INT64) {
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_FIXED32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
- }
- if (dtype == DataType::DT_INT32) {
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_FIXED32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_FIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_UINT32:
+ *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_FIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_FIXED32 for ",
+ DataTypeString(dtype));
}
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_FIXED32");
case WireFormatLite::TYPE_BOOL:
*index += ReadPackedPrimitives<bool, WireFormatLite::TYPE_BOOL>(
buf, buf_size, *index, stride, data);
@@ -476,38 +558,56 @@ inline Status ReadPackedFromArray(const void* buf, size_t buf_size,
case WireFormatLite::TYPE_BYTES:
return errors::DataLoss("Non-primitive type encountered as packed");
case WireFormatLite::TYPE_UINT32:
- if (dtype == DataType::DT_INT64) {
- *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_UINT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
+ switch (dtype) {
+ case DataType::DT_UINT64:
+ *index += ReadPackedPrimitives<uint64, WireFormatLite::TYPE_UINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_UINT32:
+ *index += ReadPackedPrimitives<uint32, WireFormatLite::TYPE_UINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_UINT32 for ",
+ DataTypeString(dtype));
}
- if (dtype == DataType::DT_INT32) {
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_UINT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
- }
- // Any case that reaches this point should have triggered an error
- // already.
- return errors::DataLoss("Failed reading TYPE_UINT32");
case WireFormatLite::TYPE_ENUM:
*index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_ENUM>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_SFIXED32:
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
-
+ switch (dtype) {
+ case DataType::DT_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SFIXED32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_INT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SFIXED64:
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SFIXED64>(
buf, buf_size, *index, stride, data);
return Status::OK();
case WireFormatLite::TYPE_SINT32:
- *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
- buf, buf_size, *index, stride, data);
- return Status::OK();
-
+ switch (dtype) {
+ case DataType::DT_INT64:
+ *index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ case DataType::DT_INT32:
+ *index += ReadPackedPrimitives<int32, WireFormatLite::TYPE_SINT32>(
+ buf, buf_size, *index, stride, data);
+ return Status::OK();
+ default:
+ return errors::DataLoss("Failed reading TYPE_SINT32 for ",
+ DataTypeString(dtype));
+ }
case WireFormatLite::TYPE_SINT64:
*index += ReadPackedPrimitives<int64, WireFormatLite::TYPE_SINT64>(
buf, buf_size, *index, stride, data);