diff options
author | 2017-07-21 10:18:08 -0700 | |
---|---|---|
committer | 2017-07-21 10:22:30 -0700 | |
commit | fce5222dcf4d563a23055efeee89599dc36539c0 (patch) | |
tree | 931f01a084880e748eb1feb89fdf1bd2fe6b25d4 /tensorflow/core/framework/variant.cc | |
parent | 8e059376b5d9bbce159d6de2af5307210209ff5c (diff) |
Added DT_VARIANT type.
A tensor with DT_VARIANT type can store arbitrary C++ data structures.
DT_VARIANT is implemented using a type-erased data structure similar to
std::any, but with extensions to make it compatible with tensorflow::Tensor.
In particular, Encode and Decode methods need to be provided by C++ classes
whose objects are stored in Variant.
PiperOrigin-RevId: 162754827
Diffstat (limited to 'tensorflow/core/framework/variant.cc')
-rw-r--r-- | tensorflow/core/framework/variant.cc | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc new file mode 100644 index 0000000000..9f03a60892 --- /dev/null +++ b/tensorflow/core/framework/variant.cc @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +template <> +void* Variant::get() { + if (is_empty()) { + return nullptr; + } + return value_->RawPtr(); +} + +template <> +const void* Variant::get() const { + if (is_empty()) { + return nullptr; + } + return value_->RawPtr(); +} + +void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { + proto->set_type_name(type_name); + proto->set_metadata(metadata); + proto->clear_tensors(); + for (int i = 0; i < tensors.size(); ++i) { + tensors[i].AsProtoField(proto->mutable_tensors()->Add()); + } +} + +bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { + type_name = proto.type_name(); + metadata = proto.metadata(); + tensors.clear(); + for (int i = 0; i < proto.tensors_size(); ++i) { + Tensor tmp; + if (!tmp.FromProto(proto.tensors(i))) return false; + tensors.push_back(tmp); + } + return true; +} + +template <> +string TypeNameVariant(const VariantTensorDataProto& value) { + return value.GetTypeName(); +} + +template <> +void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) { + data->FromProto(value); +} + +template <> +bool DecodeVariant(const VariantTensorData& data, + VariantTensorDataProto* value) { + data.ToProto(value); + return true; +} + +template <> +void EncodeVariant(const VariantTensorDataProto& value, string* buf) { + value.SerializeToString(buf); +} + +template <> +bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { + return value->ParseFromString(buf); +} + +} // end namespace tensorflow |