aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant.cc
diff options
context:
space:
mode:
authorGravatar Manjunath Kudlur <keveman@google.com>2017-07-21 10:18:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-21 10:22:30 -0700
commitfce5222dcf4d563a23055efeee89599dc36539c0 (patch)
tree931f01a084880e748eb1feb89fdf1bd2fe6b25d4 /tensorflow/core/framework/variant.cc
parent8e059376b5d9bbce159d6de2af5307210209ff5c (diff)
Added DT_VARIANT type.
A tensor with DT_VARIANT type can store arbitrary C++ data structures. DT_VARIANT is implemented using a type-erased data structure similar to std::any, but with extensions to make it compatible with tensorflow::Tensor. In particular, Encode and Decode methods need to be provided by C++ classes whose objects are stored in Variant. PiperOrigin-RevId: 162754827
Diffstat (limited to 'tensorflow/core/framework/variant.cc')
-rw-r--r--tensorflow/core/framework/variant.cc87
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
new file mode 100644
index 0000000000..9f03a60892
--- /dev/null
+++ b/tensorflow/core/framework/variant.cc
@@ -0,0 +1,87 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+template <>
+void* Variant::get() {
+ if (is_empty()) {
+ return nullptr;
+ }
+ return value_->RawPtr();
+}
+
+template <>
+const void* Variant::get() const {
+ if (is_empty()) {
+ return nullptr;
+ }
+ return value_->RawPtr();
+}
+
+void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
+ proto->set_type_name(type_name);
+ proto->set_metadata(metadata);
+ proto->clear_tensors();
+ for (int i = 0; i < tensors.size(); ++i) {
+ tensors[i].AsProtoField(proto->mutable_tensors()->Add());
+ }
+}
+
+bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+ type_name = proto.type_name();
+ metadata = proto.metadata();
+ tensors.clear();
+ for (int i = 0; i < proto.tensors_size(); ++i) {
+ Tensor tmp;
+ if (!tmp.FromProto(proto.tensors(i))) return false;
+ tensors.push_back(tmp);
+ }
+ return true;
+}
+
+template <>
+string TypeNameVariant(const VariantTensorDataProto& value) {
+ return value.GetTypeName();
+}
+
+template <>
+void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) {
+ data->FromProto(value);
+}
+
+template <>
+bool DecodeVariant(const VariantTensorData& data,
+ VariantTensorDataProto* value) {
+ data.ToProto(value);
+ return true;
+}
+
+template <>
+void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
+ value.SerializeToString(buf);
+}
+
+template <>
+bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(buf);
+}
+
+} // end namespace tensorflow