diff options
Diffstat (limited to 'tensorflow/core/framework/variant_encode_decode.h')
-rw-r--r-- | tensorflow/core/framework/variant_encode_decode.h | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h index f155aa4892..5e08e5a7a6 100644 --- a/tensorflow/core/framework/variant_encode_decode.h +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -22,6 +22,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" @@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value, // Specialization for POD type template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, true /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { @@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for tensorflow::Tensor template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, true /* Tensor */, false /* protobuf */>, T* value) { @@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for protobuf template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, true /* protobuf */>, T* value) { @@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for other types template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { - return value->Decode(data); + return value->Decode(std::move(data)); } template <typename C, typename = void> @@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) { } template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value) { - return DecodeVariantImpl(data, TypeResolver<T>(), value); +bool DecodeVariant(VariantTensorData* data, T* value) { + return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value); } template <typename T> @@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) { } template <typename T> -bool DecodeVariant(const string& buf, T* value) { +bool DecodeVariant(string* buf, T* value) { VariantTensorData data; - if (!data.ParseFromString(buf)) return false; - if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false; + if (!data.ParseFromString(*buf)) return false; + if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) { + return false; + } return true; } // Specializations for VariantTensorDataProto template <> string TypeNameVariant(const VariantTensorDataProto& value); + template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data); + template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); + template <> void EncodeVariant(const VariantTensorDataProto& value, string* buf); + template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value); +bool DecodeVariant(string* buf, VariantTensorDataProto* value); // Encodes an array of Variant objects in to the given StringListEncoder. // `variant_array` is assumed to point to an array of `n` Variant objects. |