/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ #include #include #include #include #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { // Type used for tag-dispatch of the Encode/Decode Variant implementations. This // template can determine whether the first type parameter `T` is one of the // following: // // * A POD type (TypeResolver) // * A tensorflow::Tensor (TypeResolver) // * A protocol buffer (TypeResolver) // * None of the above (TypeResolver) // template ::type>::value, bool = std::is_same::type, ::tensorflow::Tensor>::value, bool = std::is_base_of::type>::value> struct TypeResolver {}; // Specialization for POD type template void EncodeVariantImpl(const T& value, TypeResolver, VariantTensorData* data) { data->set_metadata(value); } // Specialization for tensorflow::Tensor template void EncodeVariantImpl(const T& value, TypeResolver, VariantTensorData* data) { data->tensors_.clear(); data->tensors_.push_back(value); } // Specialization for protobuf template void EncodeVariantImpl(const T& value, TypeResolver, VariantTensorData* data) { value.SerializeToString(&data->metadata_); } // Specialization for other types template void EncodeVariantImpl(const T& value, TypeResolver, VariantTensorData* data) { value.Encode(data); } // Specialization for POD type template bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { return data.get_metadata(value); } // Specialization for tensorflow::Tensor template bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { *value = data.tensors(0); return true; } // Specialization for protobuf template bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { string metadata; data.get_metadata(&metadata); return value->ParseFromString(std::move(metadata)); } // Specialization for other types template bool DecodeVariantImpl(VariantTensorData data, TypeResolver, T* value) { return value->Decode(std::move(data)); } template struct has_type_name : std::false_type {}; template struct has_type_name< C, typename std::enable_if().TypeName()), string>::value>::type> : std::true_type {}; template ::type>::value, bool = std::is_same::type, ::tensorflow::Tensor>::value, bool = std::is_base_of::type>::value> struct TypeNameResolver {}; template string TypeNameVariantImpl(const T& value, TypeNameResolver) { return value.TypeName(); } template string TypeNameVariantImpl( const T& value, TypeNameResolver) { return "tensorflow::Tensor"; } template string TypeNameVariantImpl( const T& value, TypeNameResolver) { return value.GetTypeName(); } template string TypeNameVariantImpl( const T& value, TypeNameResolver) { return port::MaybeAbiDemangle(MakeTypeIndex().name()); } template string TypeNameVariant(const T& value) { return TypeNameVariantImpl(value, TypeNameResolver()); } template struct has_debug_string : std::false_type {}; template struct has_debug_string< C, typename std::enable_if().DebugString()), string>::value>::type> : std::true_type {}; template struct can_strcat : std::false_type {}; template struct can_strcat< C, typename std::enable_if())), string>::value>::type> : std::true_type {}; template ::type>::value, bool = can_strcat::type>::value> struct DebugStringResolver {}; // TODO(ebrevdo): Expand DebugStringResolver to return TypeString if // there is no StrCat() constructor. template string DebugStringVariantImpl( const T& value, DebugStringResolver) { return value.DebugString(); } template string DebugStringVariantImpl( const T& value, DebugStringResolver) { return strings::StrCat(value); } template string DebugStringVariantImpl( const T& value, DebugStringResolver) { return "?"; } template string DebugStringVariant(const T& value) { return DebugStringVariantImpl(value, DebugStringResolver()); } template void EncodeVariant(const T& value, VariantTensorData* data) { EncodeVariantImpl(value, TypeResolver(), data); data->set_type_name(TypeNameVariant(value)); } template bool DecodeVariant(VariantTensorData* data, T* value) { return DecodeVariantImpl(std::move(*data), TypeResolver(), value); } template void EncodeVariant(const T& value, string* buf) { VariantTensorData data; EncodeVariantImpl(value, TypeResolver(), &data); data.set_type_name(TypeNameVariant(value)); DCHECK(buf != nullptr); data.SerializeToString(buf); } template bool DecodeVariant(string* buf, T* value) { VariantTensorData data; if (!data.ParseFromString(*buf)) return false; if (!DecodeVariantImpl(std::move(data), TypeResolver(), value)) { return false; } return true; } // Specializations for VariantTensorDataProto template <> string TypeNameVariant(const VariantTensorDataProto& value); template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data); template <> bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); template <> void EncodeVariant(const VariantTensorDataProto& value, string* buf); template <> bool DecodeVariant(string* buf, VariantTensorDataProto* value); // Encodes an array of Variant objects in to the given StringListEncoder. // `variant_array` is assumed to point to an array of `n` Variant objects. void EncodeVariantList(const Variant* variant_array, int64 n, std::unique_ptr e); // Decodes an array of Variant objects from the given StringListDecoder. // `variant_array` is assumed to point to an array of `n` Variant objects. bool DecodeVariantList(std::unique_ptr d, Variant* variant_array, int64 n); } // end namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_