aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/c/c_api.h1
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake7
-rw-r--r--tensorflow/core/BUILD22
-rw-r--r--tensorflow/core/framework/allocator.h19
-rw-r--r--tensorflow/core/framework/tensor.cc78
-rw-r--r--tensorflow/core/framework/tensor.proto13
-rw-r--r--tensorflow/core/framework/tensor_test.cc98
-rw-r--r--tensorflow/core/framework/types.cc7
-rw-r--r--tensorflow/core/framework/types.h2
-rw-r--r--tensorflow/core/framework/types.proto2
-rw-r--r--tensorflow/core/framework/variant.cc87
-rw-r--r--tensorflow/core/framework/variant.h287
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h220
-rw-r--r--tensorflow/core/framework/variant_test.cc249
-rw-r--r--tensorflow/core/platform/default/build_config.bzl8
-rw-r--r--tensorflow/core/platform/variant_coding.cc63
-rw-r--r--tensorflow/core/platform/variant_coding.h39
-rw-r--r--tensorflow/go/tensor.go1
-rw-r--r--tensorflow/python/__init__.py1
-rw-r--r--tensorflow/python/framework/dtypes.py17
-rw-r--r--tensorflow/python/framework/dtypes_test.py9
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt4
22 files changed, 1225 insertions, 9 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 3aeafb4685..46758408c4 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -117,6 +117,7 @@ typedef enum TF_DataType {
TF_COMPLEX128 = 18, // Double-precision complex
TF_HALF = 19,
TF_RESOURCE = 20,
+ TF_VARIANT = 21,
} TF_DataType;
// TF_DataTypeSize returns the sizeof() for the underlying type corresponding
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index a15621ba6d..b66eb35339 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -156,6 +156,11 @@ if (NOT tensorflow_ENABLE_GPU)
"${tensorflow_source_dir}/tensorflow/core/platform/default/cuda_libdevice_path.*")
list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs})
endif()
+
+file(GLOB tf_core_platform_exclude_srcs
+ "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc")
+list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs})
+
list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs})
if(UNIX)
@@ -225,6 +230,8 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c
file(GLOB_RECURSE tf_core_framework_srcs
"${tensorflow_source_dir}/tensorflow/core/framework/*.h"
"${tensorflow_source_dir}/tensorflow/core/framework/*.cc"
+ "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h"
+ "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h"
"${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc"
"${tensorflow_source_dir}/tensorflow/core/graph/graph.h"
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 5885d4ed52..271b2492fe 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -94,6 +94,8 @@ load(
"tf_additional_lib_deps",
"tf_additional_lib_hdrs",
"tf_additional_lib_srcs",
+ "tf_additional_framework_hdrs",
+ "tf_additional_framework_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
@@ -343,6 +345,8 @@ tf_cuda_library(
hdrs = [
"example/feature_util.h",
"framework/allocator.h",
+ "framework/variant.h",
+ "framework/variant_encode_decode.h",
"framework/allocator_registry.h",
"framework/attr_value_util.h",
"framework/bfloat16.h",
@@ -1265,6 +1269,8 @@ LIB_INTERNAL_WINDOWS_DEPS = glob(
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
"platform/load_library.cc",
+ "platform/variant_coding.cc",
+ "platform/**/variant_cord_coding.cc",
],
)
@@ -1286,6 +1292,8 @@ cc_library(
],
exclude = [
"**/*test*",
+ "framework/variant.cc",
+ "platform/variant_coding.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
@@ -1296,6 +1304,8 @@ cc_library(
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
"platform/**/gpu_tracer.cc",
+ "platform/variant_coding.cc",
+ "platform/**/variant_cord_coding.cc",
],
),
}) + tf_additional_lib_srcs(
@@ -1306,6 +1316,8 @@ cc_library(
"platform/**/stream_executor.h",
"platform/**/env_time.cc",
"platform/**/gpu_tracer.cc",
+ "platform/variant_coding.cc",
+ "platform/**/variant_cord_coding.cc",
] +
# Protobuf deps already included through the ":lib_proto_parsing"
# dependency.
@@ -1462,6 +1474,8 @@ tf_cuda_library(
"util/**/*.h",
"util/**/*.cc",
] + [
+ "platform/variant_coding.cc",
+ "platform/variant_coding.h",
"graph/edgeset.h",
"graph/edgeset.cc",
"graph/graph.h",
@@ -1490,20 +1504,22 @@ tf_cuda_library(
"util/memmapped_file_system_writer.h",
"util/memmapped_file_system_writer.cc",
],
- }),
+ }) + tf_additional_framework_srcs(),
hdrs = [
+ "framework/variant.h",
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
+ "platform/variant_coding.h",
"util/command_line_flags.h",
"util/env_var.h",
"util/equal_graph_def.h",
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
"util/tensor_slice_util.h",
- ],
+ ] + tf_additional_framework_hdrs(),
copts = tf_copts(),
linkopts = select({
"//tensorflow:freebsd": [],
@@ -1517,6 +1533,7 @@ tf_cuda_library(
":proto_text",
":protos_all_cc",
":version_lib",
+ "//tensorflow/core/platform/default/build_config:platformlib",
"//tensorflow/core/kernels:bounds_check",
"//third_party/eigen3",
] + if_mkl(["//third_party/mkl:intel_binary_blob"]),
@@ -2183,6 +2200,7 @@ tf_cc_tests(
"common_runtime/simple_placer_test.cc",
"example/feature_util_test.cc",
"framework/allocator_test.cc",
+ "framework/variant_test.cc",
"framework/attr_value_util_test.cc",
"framework/bfloat16_test.cc",
"framework/cancellation_test.cc",
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 1cb183d81e..868335b072 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -229,6 +230,14 @@ class Allocator {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
+ virtual void RunVariantCtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+ }
+
+ virtual void RunVariantDtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+ }
+
// TODO(jeff): Maybe provide some interface to give info about
// current allocation state (total number of bytes available for
// allocation, number of bytes free on device, etc.)
@@ -256,6 +265,16 @@ inline void Allocator::RunDtor(ResourceHandle* p, size_t n) {
RunResourceDtor(p, n);
}
+template <>
+inline void Allocator::RunCtor(Variant* p, size_t n) {
+ RunVariantCtor(p, n);
+}
+
+template <>
+inline void Allocator::RunDtor(Variant* p, size_t n) {
+ RunVariantDtor(p, n);
+}
+
// An implementation of Allocator that delegates all calls to another Allocator.
//
// Useful to clients who want to override part of the functionality of another
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 80b98fb9c6..0a85894071 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -36,6 +36,8 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -47,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/platform/variant_coding.h"
namespace tensorflow {
namespace {
@@ -220,6 +223,36 @@ struct Helper<ResourceHandle> {
}
};
+template <>
+struct Helper<Variant> {
+ // Encodes "n" elements of type Variant stored in "in" into destination
+ // "out", which is usually the TensorProto::tensor_content.
+ template <typename Destination>
+ static void Encode(TensorBuffer* in, int64 n, Destination* out) {
+ port::EncodeVariantList(in->base<const Variant>(), n, out);
+ }
+
+ // Decodes "n" elements of type Variant from "in" and constructs a
+ // buffer out of it. Returns nullptr if the decoding fails. "in" is
+ // usually the TensorProto::tensor_content.
+ template <typename Source>
+ static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) {
+ auto* buf = new Buffer<Variant>(a, n);
+ Variant* ps = buf->template base<Variant>();
+ if (ps == nullptr || !port::DecodeVariantList(in, ps, n)) {
+ buf->Unref();
+ return nullptr;
+ }
+ return buf;
+ }
+
+ // Returns the estimated memory usage of "n" elements of type T
+ // stored in buffer "in".
+ static int64 TotalBytes(TensorBuffer* in, int n) {
+ return n * sizeof(Variant);
+ }
+};
+
template <typename T>
struct ProtoHelper {};
@@ -290,6 +323,26 @@ struct ProtoHelper<ResourceHandle> {
};
template <>
+struct ProtoHelper<Variant> {
+ static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator
+ Begin(const TensorProto& proto) {
+ return proto.variant_val().begin();
+ }
+ static size_t NumElements(const TensorProto& proto) {
+ return proto.variant_val().size();
+ }
+ static void Fill(const Variant* data, size_t n, TensorProto* proto) {
+ auto* variant_values = proto->mutable_variant_val();
+ variant_values->Clear();
+ for (size_t i = 0; i < n; ++i) {
+ VariantTensorData tmp;
+ data[i].Encode(&tmp);
+ tmp.ToProto(variant_values->Add());
+ }
+ }
+};
+
+template <>
struct ProtoHelper<complex64> {
typedef Helper<float>::RepeatedFieldType FieldType;
static const complex64* Begin(const TensorProto& proto) {
@@ -421,6 +474,30 @@ TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) {
return buf;
}
+template <>
+TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in,
+ int64 n) {
+ CHECK_GT(n, 0);
+ Buffer<Variant>* buf = new Buffer<Variant>(a, n);
+ Variant* data = buf->template base<Variant>();
+ if (data == nullptr) {
+ buf->Unref();
+ return nullptr;
+ }
+ const int64 in_n = ProtoHelper<Variant>::NumElements(in);
+ if (in_n <= 0) {
+ std::fill_n(data, n, Variant());
+ } else {
+ for (int64 i = 0; i < in_n; ++i) {
+ data[i] = in.variant_val(i);
+ }
+ for (int64 i = in_n; i < n; ++i) {
+ data[i] = Variant();
+ }
+ }
+ return buf;
+}
+
// fp16 is opaque to the protobuf, so we deserialize these identical to uint16
// but with data stored in half_val instead of int_val (ie., we don't use
// ProtoHelper<uint16>).
@@ -571,6 +648,7 @@ bool Tensor::RefCountIsOne() const {
CASE(bfloat16, SINGLE_ARG(STMTS)) \
CASE(Eigen::half, SINGLE_ARG(STMTS)) \
CASE(ResourceHandle, SINGLE_ARG(STMTS)) \
+ CASE(Variant, SINGLE_ARG(STMTS)) \
case DT_INVALID: \
INVALID; \
break; \
diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto
index 98ecef225a..7e4af7a645 100644
--- a/tensorflow/core/framework/tensor.proto
+++ b/tensorflow/core/framework/tensor.proto
@@ -72,4 +72,17 @@ message TensorProto {
// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;
+
+ // DT_VARIANT
+ repeated VariantTensorDataProto variant_val = 15;
};
+
+// Protocol buffer representing the serialization format of DT_VARIANT tensors.
+message VariantTensorDataProto {
+ // Name of the type of objects being serialized.
+ string type_name = 1;
+ // Portions of the object that are not Tensors.
+ bytes metadata = 2;
+ // Tensors contained within objects being serialized.
+ repeated TensorProto tensors = 3;
+}
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 6c9c803af6..9aaf00853d 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -37,6 +38,35 @@ inline bool operator==(const ResourceHandle& a, const ResourceHandle& b) {
a.maybe_type_name() == b.maybe_type_name();
}
+inline bool operator==(const Variant& a, const Variant& b) {
+ if (a.is_empty()) {
+ return b.is_empty();
+ }
+
+ if (a.TypeId() != b.TypeId()) return false;
+ if (a.TypeName() != b.TypeName()) return false;
+
+ VariantTensorData a_data, b_data;
+ a.Encode(&a_data);
+ b.Encode(&b_data);
+
+ if (a_data.metadata != b_data.metadata) return false;
+
+ if (a_data.tensors.size() != b_data.tensors.size()) return false;
+
+ for (int i = 0; i < a_data.tensors.size(); ++i) {
+ TensorProto a_proto, b_proto;
+ a_data.tensors[i].AsProtoTensorContent(&a_proto);
+ b_data.tensors[i].AsProtoTensorContent(&b_proto);
+ string a_str, b_str;
+ a_proto.SerializeToString(&a_str);
+ b_proto.SerializeToString(&b_str);
+ if (a_str != b_str) return false;
+ }
+
+ return true;
+}
+
TEST(TensorTest, Default) {
Tensor t;
EXPECT_EQ(t.dtype(), DT_FLOAT);
@@ -159,6 +189,74 @@ TEST(Tensor_ResourceHandle, Simple) {
TestCopies<ResourceHandle>(t);
}
+TEST(Tensor_Variant, Simple) {
+ Tensor t(DT_VARIANT, TensorShape({}));
+ Tensor value(DT_FLOAT, TensorShape({}));
+ value.flat<float>()(0) = 42.0f;
+ t.flat<Variant>()(0) = value;
+ // All the tests in TestCopies except the ones that serialize and deserialize
+ // the tensor. The consumer of a serialized Variant Tensor should know what
+ // type is stored in the Tensor, so not testing the generic
+ // serialize/deserialize case here.
+ {
+ LOG(INFO) << "CopyFrom()";
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.CopyFrom(t, t.shape()));
+ test::ExpectTensorEqual<Variant>(t, t2);
+ }
+ {
+ LOG(INFO) << "operator=()";
+ Tensor t2(t.dtype());
+ t2 = t;
+ test::ExpectTensorEqual<Variant>(t, t2);
+ }
+ {
+ LOG(INFO) << "deep copy";
+ Tensor t2(t.dtype(), t.shape());
+ t2.flat<Variant>() = t.flat<Variant>();
+ test::ExpectTensorEqual<Variant>(t, t2);
+ }
+ {
+ LOG(INFO) << "AsTensor";
+ gtl::ArraySlice<Variant> values(t.flat<Variant>().data(), t.NumElements());
+ Tensor t2 = test::AsTensor(values, t.shape());
+ test::ExpectTensorEqual<Variant>(t, t2);
+ }
+ {
+ LOG(INFO) << "Move constructor";
+ Tensor t2 = t;
+ Tensor t3(std::move(t2));
+ test::ExpectTensorEqual<Variant>(t, t3);
+ EXPECT_TRUE(t3.IsInitialized());
+ EXPECT_FALSE(t2.IsInitialized());
+ }
+ {
+ LOG(INFO) << "Move assignment";
+ Tensor t2 = t;
+ Tensor t3 = std::move(t2);
+ Tensor* t4 = &t3;
+ *t4 = std::move(t3);
+ test::ExpectTensorEqual<Variant>(t, t3);
+ EXPECT_TRUE(t3.IsInitialized());
+ EXPECT_FALSE(t2.IsInitialized());
+ }
+}
+
+TEST(Tensor_Variant, Marshal) {
+ Tensor t(DT_VARIANT, TensorShape({}));
+
+ Tensor internal(DT_FLOAT, TensorShape({}));
+ internal.flat<float>()(0) = 42.0f;
+ t.flat<Variant>()(0) = internal;
+
+ LOG(INFO) << "AsProtoField()";
+ TensorProto proto;
+ t.AsProtoField(&proto);
+
+ Tensor t2(t.dtype());
+ EXPECT_TRUE(t2.FromProto(proto));
+}
+
TEST(Tensor_UInt16, Simple) {
Tensor t(DT_UINT16, TensorShape({2, 2}));
EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2})));
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index 28b5e8be11..39dd5b435e 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -87,6 +87,8 @@ string DataTypeString(DataType dtype) {
return "half";
case DT_RESOURCE:
return "resource";
+ case DT_VARIANT:
+ return "variant";
default:
LOG(ERROR) << "Unrecognized DataType enum value " << dtype;
return strings::StrCat("unknown dtype enum (", dtype, ")");
@@ -165,6 +167,9 @@ bool DataTypeFromString(StringPiece sp, DataType* dt) {
} else if (sp == "resource") {
*dt = DT_RESOURCE;
return true;
+ } else if (sp == "variant") {
+ *dt = DT_VARIANT;
+ return true;
}
return false;
}
@@ -186,7 +191,7 @@ DataTypeVector AllTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
- DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE};
+ DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT};
}
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index 889056647c..9127750d68 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -181,6 +182,7 @@ MATCH_TYPE_AND_ENUM(qint32, DT_QINT32);
MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16);
MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF);
MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE);
+MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT);
#undef MATCH_TYPE_AND_ENUM
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
index b80e2b31dc..1beb2a1aa2 100644
--- a/tensorflow/core/framework/types.proto
+++ b/tensorflow/core/framework/types.proto
@@ -34,6 +34,7 @@ enum DataType {
DT_COMPLEX128 = 18; // Double-precision complex
DT_HALF = 19;
DT_RESOURCE = 20;
+ DT_VARIANT = 21; // Arbitrary C++ data types
// TODO(josh11b): DT_GENERIC_PROTO = ??;
// TODO(jeff,josh11b): DT_UINT64? DT_UINT32?
@@ -60,5 +61,6 @@ enum DataType {
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
DT_RESOURCE_REF = 120;
+ DT_VARIANT_REF = 121;
}
// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
new file mode 100644
index 0000000000..6a11f86c23
--- /dev/null
+++ b/tensorflow/core/framework/variant.cc
@@ -0,0 +1,87 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+
+template <>
+void* Variant::get() {
+ if (is_empty()) {
+ return nullptr;
+ }
+ return value_->RawPtr();
+}
+
+template <>
+const void* Variant::get() const {
+ if (is_empty()) {
+ return nullptr;
+ }
+ return value_->RawPtr();
+}
+
+void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
+ proto->set_type_name(type_name);
+ proto->set_metadata(metadata);
+ proto->clear_tensors();
+ for (int i = 0; i < tensors.size(); ++i) {
+ tensors[i].AsProtoField(proto->mutable_tensors()->Add());
+ }
+}
+
+bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+ type_name = proto.type_name();
+ metadata = proto.metadata();
+ tensors.clear();
+ for (int i = 0; i < proto.tensors_size(); ++i) {
+ Tensor tmp;
+ if (!tmp.FromProto(proto.tensors(i))) return false;
+ tensors.push_back(tmp);
+ }
+ return true;
+}
+
+template <>
+string TypeNameVariant(VariantTensorDataProto&& value) {
+ return value.GetTypeName();
+}
+
+template <>
+void EncodeVariant(VariantTensorDataProto&& value, VariantTensorData* data) {
+ data->FromProto(value);
+}
+
+template <>
+bool DecodeVariant(const VariantTensorData& data,
+ VariantTensorDataProto* value) {
+ data.ToProto(value);
+ return true;
+}
+
+template <>
+void EncodeVariant(VariantTensorDataProto&& value, string* buf) {
+ value.SerializeToString(buf);
+}
+
+template <>
+bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(buf);
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
new file mode 100644
index 0000000000..df6d903143
--- /dev/null
+++ b/tensorflow/core/framework/variant.h
@@ -0,0 +1,287 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_
+#define TENSORFLOW_FRAMEWORK_VARIANT_H_
+
+#include <functional>
+#include <iostream>
+#include <memory>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+
+#include "tensorflow/core/framework/type_index.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+struct VariantTensorData;
+
+template <typename T>
+string TypeNameVariant(T&& value);
+
+template <typename T>
+void EncodeVariant(T&& value, VariantTensorData* data);
+
+template <typename T>
+bool DecodeVariant(const VariantTensorData& data, T* value);
+
+template <typename T>
+void EncodeVariant(T&& value, string* buf);
+
+template <typename T>
+bool DecodeVariant(const string& buf, T* value);
+
+// This is an implementation of a type-erased container that can store an
+// object of any type. The implementation is very similar to std::any, but has
+// restrictions on the types of objects that can be stored, and eschews some of
+// the fancier constructors available for std::any. An object of
+// tensorflow::Variant is intended to be used as the value that will be stored
+// in a tensorflow::Tensor object when its type is DT_VARIANT.
+//
+// tensorflow::Variant can store an object of a class that satisfies the
+// following constraints:
+//
+// * The class is CopyConstructible.
+// * The class has a default constructor.
+// * It's either a protocol buffer, a tensorflow::Tensor, or defines the
+// following functions:
+//
+// string TypeName() const;
+// void Encode(VariantTensorData* data) const;
+// void Decode(const VariantTensorData& data);
+//
+// Simple POD types can elide the Encode/Decode functions, they are provided by
+// helper methods.
+// Here are some typical usage patterns:
+//
+// Variant x = 10;
+// EXPECT_EQ(*x.get<int>(), 10);
+//
+// Tensor t(DT_FLOAT, TensorShape({}));
+// t.flat<float>()(0) = 42.0f;
+// Variant x = t;
+// EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
+//
+// Accessing the stored object:
+//
+// The get<T> function is the main mechanism to access the object stored in the
+// contained. It is type-safe, that is, calling get<T> when the stored object's
+// type is not T, returns a nullptr. A raw pointer to the stored object can be
+// obtained by calling get<void>().
+//
+// Serializing/deserializing Variant object:
+//
+// The Variant class delegates serializing and deserializing operations to the
+// contained object. Helper functions to do these operations are provided for
+// POD data types, tensorflow::Tensor, and protocol buffer objects. However,
+// other classes have to provide Encode/Decode functions to do handle
+// serialization.
+//
+// Objects stored in a Variant object often contain references to other
+// tensorflow::Tensors of primitive types (Eg., a list of tensorflow::Tensors).
+// To efficiently support those use cases, a structure is imposed on the
+// serialization format. Namely, classes should serialize their contents in to a
+// VariantTensorData object:
+//
+// struct VariantTensorData {
+// string type_name;
+// string metadata;
+// std::vector<Tensor> tensors;
+// };
+//
+// Objects with references to other Tensors can simply store those tensors in
+// the `tensors` field, and serialize other metadata content in to the
+// `metadata` field.
+//
+// Serialization example:
+//
+// Foo f = Foo {...};
+// Variant x = f;
+// string serialized_f;
+// x.Encode(&serialized_f);
+//
+// Variant y = Foo(); // default constructed Foo.
+// y.Decode(&serialized_f);
+// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
+//
+class Variant {
+ public:
+ constexpr Variant() noexcept = default;
+
+ Variant(const Variant& other)
+ : value_(other.is_empty() ? std::unique_ptr<ValueInterface>()
+ : other.value_->Clone()) {}
+
+ Variant(Variant&& other) noexcept = default;
+
+ // Make sure that the type is CopyConstructible and not a tensorflow::Variant
+ // object itself. We want the copy constructor to be chosen for the
+ // tensorflow::Variant case.
+ template <typename T, typename VT = typename std::decay<T>::type,
+ typename std::enable_if<!std::is_same<Variant, VT>::value &&
+ std::is_copy_constructible<VT>::value,
+ void>::type* = nullptr>
+ Variant(T&& value) // NOLINT
+ : value_(new Value<VT>(in_place, std::forward<T>(value))) {}
+
+ Variant& operator=(const Variant& rhs) {
+ Variant(rhs).swap(*this);
+ return *this;
+ }
+
+ Variant& operator=(Variant&& rhs) noexcept {
+ Variant(std::move(rhs)).swap(*this);
+ return *this;
+ }
+
+ bool is_empty() const { return value_ == nullptr; }
+
+ void clear() noexcept { value_.reset(); }
+
+ void swap(Variant& other) noexcept { value_.swap(other.value_); }
+
+ TypeIndex TypeId() const {
+ const TypeIndex VoidTypeIndex = MakeTypeIndex<void>();
+ if (is_empty()) {
+ return VoidTypeIndex;
+ }
+ return value_->TypeId();
+ }
+
+ template <typename T>
+ T* get() {
+ const TypeIndex TTypeIndex = MakeTypeIndex<T>();
+ if (is_empty() || (TTypeIndex != TypeId())) {
+ return nullptr;
+ }
+ return std::addressof(static_cast<Variant::Value<T>*>(value_.get())->value);
+ }
+
+ template <typename T>
+ const T* get() const {
+ const TypeIndex TTypeIndex = MakeTypeIndex<T>();
+ if (is_empty() || (TTypeIndex != TypeId())) {
+ return nullptr;
+ }
+ return std::addressof(
+ static_cast<const Variant::Value<T>*>(value_.get())->value);
+ }
+
+ string TypeName() const {
+ if (is_empty()) {
+ return "";
+ }
+ return value_->TypeName();
+ }
+
+ // Serialize the contents of the stored object into `data`.
+ void Encode(VariantTensorData* data) const {
+ if (!is_empty()) {
+ value_->Encode(data);
+ }
+ }
+
+ // Deserialize `data` and update the stored object.
+ bool Decode(const VariantTensorData& data) {
+ if (!is_empty()) {
+ return value_->Decode(data);
+ }
+ return true;
+ }
+
+ // Helper methods to directly serialize/deserialize from strings.
+ void Encode(string* buf) const {
+ if (!is_empty()) {
+ value_->Encode(buf);
+ }
+ }
+ bool Decode(const string& buf) {
+ if (!is_empty()) {
+ return value_->Decode(buf);
+ }
+ return true;
+ }
+
+ private:
+ struct in_place_t {};
+ static constexpr in_place_t in_place{};
+
+ struct ValueInterface {
+ virtual ~ValueInterface() = default;
+ virtual TypeIndex TypeId() const = 0;
+ virtual void* RawPtr() = 0;
+ virtual const void* RawPtr() const = 0;
+ virtual std::unique_ptr<ValueInterface> Clone() const = 0;
+ virtual string TypeName() const = 0;
+ virtual void Encode(VariantTensorData* data) const = 0;
+ virtual bool Decode(const VariantTensorData& data) = 0;
+ virtual void Encode(string* buf) const = 0;
+ virtual bool Decode(const string& data) = 0;
+ };
+
+ template <typename T>
+ struct Value : ValueInterface {
+ template <class... Args>
+ explicit Value(in_place_t /*tag*/, Args&&... args)
+ : value(std::forward<Args>(args)...) {}
+
+ TypeIndex TypeId() const override {
+ const TypeIndex value_type_index =
+ MakeTypeIndex<typename std::decay<T>::type>();
+ return value_type_index;
+ }
+
+ void* RawPtr() override { return &value; }
+
+ const void* RawPtr() const override { return &value; }
+
+ std::unique_ptr<ValueInterface> Clone() const override {
+ return std::unique_ptr<ValueInterface>(new Value(in_place, value));
+ }
+
+ string TypeName() const override { return TypeNameVariant(value); }
+
+ void Encode(VariantTensorData* data) const override {
+ EncodeVariant(value, data);
+ }
+
+ bool Decode(const VariantTensorData& data) override {
+ return DecodeVariant(data, &value);
+ }
+
+ void Encode(string* buf) const override { EncodeVariant(value, buf); }
+
+ bool Decode(const string& buf) override {
+ return DecodeVariant(buf, &value);
+ }
+
+ T value;
+ };
+
+ std::unique_ptr<ValueInterface> value_;
+};
+
+template <>
+void* Variant::get();
+
+template <>
+const void* Variant::get() const;
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_VARIANT_H_
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
new file mode 100644
index 0000000000..09309aa7e5
--- /dev/null
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -0,0 +1,220 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#define TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+
+#include <iostream>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+
+// The serialization format for Variant objects. Objects with references to
+// other Tensors can simply store those tensors in the `tensors` field, and
+// serialize other metadata content in to the `metadata` field. Objects can
+// optionally set the `type_name` for type-checking before deserializing an
+// object.
+struct VariantTensorData {
+ string type_name;
+ string metadata;
+ std::vector<Tensor> tensors;
+ void ToProto(VariantTensorDataProto* proto) const;
+ bool FromProto(const VariantTensorDataProto& proto);
+};
+
+// Type used for tag-dispatch of the Encode/Decode Variant implementations. This
+// template can determine whether the first type parameter `T` is one of the
+// following:
+//
+// * A POD type (TypeResolver<T, true>)
+// * A tensorflow::Tensor (TypeResolver<T, false, true>)
+// * A protocol buffer (TypeResolver<T, false, false, true>)
+// * None of the above (TypeResolver<T, false, false, false>)
+//
+template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value,
+ bool = std::is_same<typename std::decay<T>::type,
+ ::tensorflow::Tensor>::value,
+ bool = std::is_base_of<protobuf::MessageLite,
+ typename std::decay<T>::type>::value>
+struct TypeResolver {};
+
+// Specialization for POD type
+template <typename T>
+void EncodeVariantImpl(T&& value, TypeResolver<T, true /* is_pod */>,
+ VariantTensorData* data) {
+ data->metadata.assign(reinterpret_cast<const char*>(&value), sizeof(value));
+}
+
+// Specialization for tensorflow::Tensor
+template <typename T>
+void EncodeVariantImpl(T&& value,
+ TypeResolver<T, false /* is_pod */, true /* Tensor */>,
+ VariantTensorData* data) {
+ data->tensors.clear();
+ data->tensors.push_back(value);
+}
+
+// Specialization for protobuf
+template <typename T>
+void EncodeVariantImpl(T&& value,
+ TypeResolver<T, false /* is_pod */, false /* Tensor */,
+ true /* protobuf */>,
+ VariantTensorData* data) {
+ value.SerializeToString(&data->metadata);
+}
+
+// Specialization for other types
+template <typename T>
+void EncodeVariantImpl(T&& value,
+ TypeResolver<T, false /* is_pod */, false /* Tensor */,
+ false /* protobuf */>,
+ VariantTensorData* data) {
+ value.Encode(data);
+}
+
+// Specialization for POD type
+template <typename T>
+bool DecodeVariantImpl(const VariantTensorData& data,
+ TypeResolver<T, true /* is_pod */>, T* value) {
+ std::copy_n(data.metadata.data(), sizeof(*value),
+ reinterpret_cast<char*>(value));
+ return true;
+}
+
+// Specialization for tensorflow::Tensor
+template <typename T>
+bool DecodeVariantImpl(const VariantTensorData& data,
+ TypeResolver<T, false /* is_pod */, true /* Tensor */>,
+ T* value) {
+ *value = data.tensors[0];
+ return true;
+}
+
+// Specialization for protobuf
+template <typename T>
+bool DecodeVariantImpl(const VariantTensorData& data,
+ TypeResolver<T, false /* is_pod */, false /* Tensor */,
+ true /* protobuf */>,
+ T* value) {
+ return value->ParseFromString(data.metadata);
+}
+
+// Specialization for other types
+template <typename T>
+bool DecodeVariantImpl(const VariantTensorData& data,
+ TypeResolver<T, false /* is_pod */, false /* Tensor */,
+ false /* protobuf */>,
+ T* value) {
+ return value->Decode(data);
+}
+
+template <typename C, typename = void>
+struct has_type_name : std::false_type {};
+
+template <typename C>
+struct has_type_name<
+ C, typename std::enable_if<std::is_same<
+ decltype(std::declval<C>().TypeName()), string>::value>::type>
+ : std::true_type {};
+
+template <typename T, bool = has_type_name<typename std::decay<T>::type>::value,
+ bool = std::is_same<typename std::decay<T>::type,
+ ::tensorflow::Tensor>::value,
+ bool = std::is_base_of<protobuf::MessageLite,
+ typename std::decay<T>::type>::value>
+struct TypeNameResolver {};
+
+template <typename T>
+string TypeNameVariantImpl(T&& value,
+ TypeNameResolver<T, true /* has_type_name */>) {
+ return value.TypeName();
+}
+
+template <typename T>
+string TypeNameVariantImpl(
+ T&& value,
+ TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) {
+ return "tensorflow::Tensor";
+}
+
+template <typename T>
+string TypeNameVariantImpl(
+ T&& value, TypeNameResolver<T, false /* has_type_name */,
+ false /* Tensor */, true /* protobuf */>) {
+ return value.GetTypeName();
+}
+
+template <typename T>
+string TypeNameVariantImpl(
+ T&& value, TypeNameResolver<T, false /* has_type_name */,
+ false /* Tensor */, false /* protobuf */>) {
+ return value.TypeName();
+}
+
+template <typename T>
+string TypeNameVariant(T&& value) {
+ return TypeNameVariantImpl(std::forward<T>(value), TypeNameResolver<T>());
+}
+
+template <typename T>
+void EncodeVariant(T&& value, VariantTensorData* data) {
+ EncodeVariantImpl(std::forward<T>(value), TypeResolver<T>(), data);
+}
+
+template <typename T>
+bool DecodeVariant(const VariantTensorData& data, T* value) {
+ return DecodeVariantImpl(data, TypeResolver<T>(), value);
+}
+
+template <typename T>
+void EncodeVariant(T&& value, string* buf) {
+ VariantTensorData data;
+ EncodeVariantImpl(std::forward<T>(value), TypeResolver<T>(), &data);
+ VariantTensorDataProto proto;
+ data.ToProto(&proto);
+ proto.SerializeToString(buf);
+}
+
+template <typename T>
+bool DecodeVariant(const string& buf, T* value) {
+ VariantTensorDataProto proto;
+ if (!proto.ParseFromString(buf)) return false;
+ VariantTensorData data;
+ if (!data.FromProto(proto)) return false;
+ if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+ return true;
+}
+
+// Specializations for VariantTensorDataProto
+template <>
+string TypeNameVariant(VariantTensorDataProto&& value);
+template <>
+void EncodeVariant(VariantTensorDataProto&& value, VariantTensorData* data);
+template <>
+bool DecodeVariant(const VariantTensorData& data,
+ VariantTensorDataProto* value);
+template <>
+void EncodeVariant(VariantTensorDataProto&& value, string* buf);
+template <>
+bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
new file mode 100644
index 0000000000..c7ffdd28f4
--- /dev/null
+++ b/tensorflow/core/framework/variant_test.cc
@@ -0,0 +1,249 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+template <typename T>
+struct Wrapper {
+ T value;
+ string TypeName() const { return "POD"; }
+};
+
+using Int = Wrapper<int>;
+using Float = Wrapper<float>;
+
+} // end namespace
+
+TEST(VariantTest, Basic) {
+ Variant x;
+ EXPECT_EQ(x.get<void>(), nullptr);
+
+ x = Int{42};
+
+ EXPECT_NE(x.get<void>(), nullptr);
+ EXPECT_NE(x.get<Int>(), nullptr);
+ EXPECT_EQ(x.get<Int>()->value, 42);
+ EXPECT_EQ(x.TypeName(), "POD");
+}
+
+TEST(VariantTest, ConstGet) {
+ Variant x;
+ EXPECT_EQ(x.get<void>(), nullptr);
+
+ x = Int{42};
+
+ const Variant y = x;
+
+ EXPECT_NE(y.get<void>(), nullptr);
+ EXPECT_NE(y.get<Int>(), nullptr);
+ EXPECT_EQ(y.get<Int>()->value, 42);
+}
+
+TEST(VariantTest, Clear) {
+ Variant x;
+ EXPECT_EQ(x.get<void>(), nullptr);
+
+ x = Int{42};
+
+ EXPECT_NE(x.get<void>(), nullptr);
+ EXPECT_NE(x.get<Int>(), nullptr);
+ EXPECT_EQ(x.get<Int>()->value, 42);
+
+ x.clear();
+ EXPECT_EQ(x.get<void>(), nullptr);
+}
+
+TEST(VariantTest, Tensor) {
+ Variant x;
+ Tensor t(DT_FLOAT, {});
+ t.flat<float>()(0) = 42.0f;
+ x = t;
+
+ EXPECT_NE(x.get<Tensor>(), nullptr);
+ EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f);
+ x.get<Tensor>()->flat<float>()(0) += 1.0f;
+ EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 43.0f);
+ EXPECT_EQ(x.TypeName(), "tensorflow::Tensor");
+}
+
+TEST(VariantTest, TensorProto) {
+ Variant x;
+ TensorProto t;
+ t.set_dtype(DT_FLOAT);
+ t.mutable_tensor_shape()->set_unknown_rank(true);
+ x = t;
+
+ EXPECT_EQ(x.TypeName(), "tensorflow.TensorProto");
+ EXPECT_NE(x.get<TensorProto>(), nullptr);
+ EXPECT_EQ(x.get<TensorProto>()->dtype(), DT_FLOAT);
+ EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true);
+}
+
+TEST(VariantTest, CopyValue) {
+ Variant x, y;
+ x = Int{10};
+ y = x;
+
+ EXPECT_EQ(x.get<Int>()->value, 10);
+ EXPECT_EQ(x.get<Int>()->value, y.get<Int>()->value);
+}
+
+TEST(VariantTest, MoveValue) {
+ Variant x;
+ x = []() -> Variant {
+ Variant y;
+ y = Int{10};
+ return y;
+ }();
+ EXPECT_EQ(x.get<Int>()->value, 10);
+}
+
+TEST(VariantTest, TypeMismatch) {
+ Variant x;
+ x = Int{10};
+ EXPECT_EQ(x.get<float>(), nullptr);
+ EXPECT_EQ(x.get<int>(), nullptr);
+ EXPECT_NE(x.get<Int>(), nullptr);
+}
+
+struct TensorList {
+ void Encode(VariantTensorData* data) const { data->tensors = vec; }
+
+ bool Decode(const VariantTensorData& data) {
+ vec = data.tensors;
+ return true;
+ }
+
+ string TypeName() const { return "TensorList"; }
+
+ std::vector<Tensor> vec;
+};
+
+TEST(VariantTest, TensorListTest) {
+ Variant x;
+
+ TensorList vec;
+ for (int i = 0; i < 4; ++i) {
+ Tensor elem(DT_INT32, {1});
+ elem.flat<int>()(0) = i;
+ vec.vec.push_back(elem);
+ }
+
+ for (int i = 0; i < 4; ++i) {
+ Tensor elem(DT_FLOAT, {1});
+ elem.flat<float>()(0) = 2 * i;
+ vec.vec.push_back(elem);
+ }
+
+ x = vec;
+
+ EXPECT_EQ(x.TypeName(), "TensorList");
+ const TensorList& stored_vec = *x.get<TensorList>();
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(stored_vec.vec[i].flat<int>()(0), i);
+ }
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(stored_vec.vec[i + 4].flat<float>()(0), 2 * i);
+ }
+
+ VariantTensorData serialized;
+ x.Encode(&serialized);
+
+ Variant y = TensorList();
+ y.Decode(serialized);
+
+ const TensorList& decoded_vec = *x.get<TensorList>();
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(decoded_vec.vec[i].flat<int>()(0), i);
+ }
+ for (int i = 0; i < 4; ++i) {
+ EXPECT_EQ(decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
+ }
+}
+
+TEST(VariantTest, VariantArray) {
+ Variant x[2];
+ x[0] = Int{2};
+ x[1] = Float{2.0f};
+
+ EXPECT_EQ(x[0].get<Int>()->value, 2);
+ EXPECT_EQ(x[1].get<Float>()->value, 2.0f);
+}
+
+TEST(VariantTest, PodUpdate) {
+ struct Pod {
+ int x;
+ float y;
+
+ string TypeName() const { return "POD"; }
+ };
+
+ Variant x = Pod{10, 20.f};
+ EXPECT_NE(x.get<Pod>(), nullptr);
+ EXPECT_EQ(x.TypeName(), "POD");
+
+ x.get<Pod>()->x += x.get<Pod>()->y;
+ EXPECT_EQ(x.get<Pod>()->x, 30);
+}
+
+TEST(VariantTest, EncodeDecodePod) {
+ struct Pod {
+ int x;
+ float y;
+
+ string TypeName() const { return "POD"; }
+ };
+
+ Variant x;
+ Pod p{10, 20.0f};
+ x = p;
+
+ VariantTensorData serialized;
+ x.Encode(&serialized);
+
+ Variant y;
+ y = Pod();
+ y.Decode(serialized);
+
+ EXPECT_EQ(p.x, y.get<Pod>()->x);
+ EXPECT_EQ(p.y, y.get<Pod>()->y);
+}
+
+TEST(VariantTest, EncodeDecodeTensor) {
+ Variant x;
+ Tensor t(DT_INT32, {});
+ t.flat<int>()(0) = 42;
+ x = t;
+
+ VariantTensorData serialized;
+ x.Encode(&serialized);
+
+ Variant y = Tensor();
+ y.Decode(serialized);
+ EXPECT_EQ(x.get<Tensor>()->flat<int>()(0), y.get<Tensor>()->flat<int>()(0));
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index d12ad8a04f..5db0fe423c 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -130,6 +130,14 @@ def tf_additional_lib_srcs(exclude = []):
], exclude = exclude),
})
+# pylint: disable=unused-argument
+def tf_additional_framework_hdrs(exclude = []):
+ return []
+
+def tf_additional_framework_srcs(exclude = []):
+ return []
+# pylint: enable=unused-argument
+
def tf_additional_minimal_lib_srcs():
return [
"platform/default/integral_types.h",
diff --git a/tensorflow/core/platform/variant_coding.cc b/tensorflow/core/platform/variant_coding.cc
new file mode 100644
index 0000000000..4bcde4f581
--- /dev/null
+++ b/tensorflow/core/platform/variant_coding.cc
@@ -0,0 +1,63 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/variant_coding.h"
+
+#include <vector>
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace port {
+
+void EncodeVariantList(const Variant* variant_array, int64 n, string* out) {
+ out->clear();
+ string rest;
+ for (int i = 0; i < n; ++i) {
+ string s;
+ variant_array[i].Encode(&s);
+ core::PutVarint32(out, s.length());
+ strings::StrAppend(&rest, s);
+ }
+ strings::StrAppend(out, rest);
+}
+
+bool DecodeVariantList(const string& in, Variant* variant_array, int64 n) {
+ std::vector<uint32> sizes(n);
+ StringPiece reader(in);
+ int64 total = 0;
+ for (auto& size : sizes) {
+ if (!core::GetVarint32(&reader, &size)) return false;
+ total += size;
+ }
+ if (total != static_cast<int64>(reader.size())) {
+ return false;
+ }
+
+ for (int i = 0; i < n; ++i) {
+ if (variant_array[i].is_empty()) {
+ variant_array[i] = VariantTensorDataProto();
+ }
+ string str(reader.data(), sizes[i]);
+ if (!variant_array[i].Decode(str)) return false;
+ reader.remove_prefix(sizes[i]);
+ }
+ return true;
+}
+
+} // end namespace port
+} // end namespace tensorflow
diff --git a/tensorflow/core/platform/variant_coding.h b/tensorflow/core/platform/variant_coding.h
new file mode 100644
index 0000000000..34c2481149
--- /dev/null
+++ b/tensorflow/core/platform/variant_coding.h
@@ -0,0 +1,39 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PLATFORM_VARIANT_CODING_H_
+#define TENSORFLOW_PLATFORM_VARIANT_CODING_H_
+
+#include "tensorflow/core/framework/variant.h"
+
+#ifdef PLATFORM_GOOGLE
+#include "tensorflow/core/platform/google/variant_cord_coding.h"
+#endif
+
+namespace tensorflow {
+namespace port {
+
+// Encodes an array of Variant objects in to the given string.
+// `variant_array` is assumed to point to an array of `n` Variant objects.
+void EncodeVariantList(const Variant* variant_array, int64 n, string* out);
+
+// Decodes an array of Variant objects from the given string.
+// `variant_array` is assumed to point to an array of `n` Variant objects.
+bool DecodeVariantList(const string& in, Variant* variant_array, int64 n);
+
+} // end namespace port
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_PLATFORM_VARIANT_CODING_H_
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 4a60c736b5..8b8909a6f8 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -227,6 +227,7 @@ var types = []struct {
{reflect.TypeOf(uint16(0)), C.TF_UINT16},
{reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128},
// TODO(apassos): support DT_RESOURCE representation in go.
+ // TODO(keveman): support DT_VARIANT representation in go.
}
// shapeAndDataTypeOf returns the data type and shape of the Tensor
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 827e4c53ee..2fd8fc8688 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -208,6 +208,7 @@ _allowed_symbols.extend([
'uint16',
'uint8',
'resource',
+ 'variant',
])
# Export modules and constants.
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 00bfae213a..43535a593e 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -48,6 +48,7 @@ class DType(object):
* `tf.quint16`: Quantized 16-bit unsigned integer.
* `tf.qint32`: Quantized 32-bit signed integer.
* `tf.resource`: Handle to a mutable resource.
+ * `tf.variant`: Values of arbitrary types.
In addition, variants of these types with the `_ref` suffix are
defined for reference-typed tensors.
@@ -113,8 +114,11 @@ class DType(object):
@property
def is_numpy_compatible(self):
- return (self._type_enum != types_pb2.DT_RESOURCE and
- self._type_enum != types_pb2.DT_RESOURCE_REF)
+ numpy_incompatible = [types_pb2.DT_VARIANT,
+ types_pb2.DT_VARIANT_REF,
+ types_pb2.DT_RESOURCE,
+ types_pb2.DT_RESOURCE_REF]
+ return self._type_enum not in numpy_incompatible
@property
def as_numpy_dtype(self):
@@ -284,7 +288,8 @@ class DType(object):
@property
def size(self):
- if self._type_enum == types_pb2.DT_RESOURCE:
+ if (self._type_enum == types_pb2.DT_VARIANT or
+ self._type_enum == types_pb2.DT_RESOURCE):
return 1
return np.dtype(self.as_numpy_dtype).itemsize
@@ -304,6 +309,7 @@ dtype_range = {np.bool_: (False, True),
# Define standard wrappers for the types_pb2.DataType enum.
resource = DType(types_pb2.DT_RESOURCE)
+variant = DType(types_pb2.DT_VARIANT)
float16 = DType(types_pb2.DT_HALF)
half = float16
float32 = DType(types_pb2.DT_FLOAT)
@@ -325,6 +331,7 @@ qint16 = DType(types_pb2.DT_QINT16)
quint16 = DType(types_pb2.DT_QUINT16)
qint32 = DType(types_pb2.DT_QINT32)
resource_ref = DType(types_pb2.DT_RESOURCE_REF)
+variant_ref = DType(types_pb2.DT_VARIANT_REF)
bfloat16 = DType(types_pb2.DT_BFLOAT16)
float16_ref = DType(types_pb2.DT_HALF_REF)
half_ref = float16_ref
@@ -372,6 +379,7 @@ _INTERN_TABLE = {
types_pb2.DT_QINT32: qint32,
types_pb2.DT_BFLOAT16: bfloat16,
types_pb2.DT_RESOURCE: resource,
+ types_pb2.DT_VARIANT: variant,
types_pb2.DT_HALF_REF: float16_ref,
types_pb2.DT_FLOAT_REF: float32_ref,
types_pb2.DT_DOUBLE_REF: float64_ref,
@@ -392,6 +400,7 @@ _INTERN_TABLE = {
types_pb2.DT_QINT32_REF: qint32_ref,
types_pb2.DT_BFLOAT16_REF: bfloat16_ref,
types_pb2.DT_RESOURCE_REF: resource_ref,
+ types_pb2.DT_VARIANT_REF: variant_ref,
}
@@ -417,6 +426,7 @@ _TYPE_TO_STRING = {
types_pb2.DT_QINT32: "qint32",
types_pb2.DT_BFLOAT16: "bfloat16",
types_pb2.DT_RESOURCE: "resource",
+ types_pb2.DT_VARIANT: "variant",
types_pb2.DT_HALF_REF: "float16_ref",
types_pb2.DT_FLOAT_REF: "float32_ref",
types_pb2.DT_DOUBLE_REF: "float64_ref",
@@ -437,6 +447,7 @@ _TYPE_TO_STRING = {
types_pb2.DT_QINT32_REF: "qint32_ref",
types_pb2.DT_BFLOAT16_REF: "bfloat16_ref",
types_pb2.DT_RESOURCE_REF: "resource_ref",
+ types_pb2.DT_VARIANT_REF: "variant_ref",
}
_STRING_TO_TF = {value: _INTERN_TABLE[key]
for key, value in _TYPE_TO_STRING.items()}
diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py
index d5a1e61a6f..1e84f1b656 100644
--- a/tensorflow/python/framework/dtypes_test.py
+++ b/tensorflow/python/framework/dtypes_test.py
@@ -27,9 +27,12 @@ from tensorflow.python.platform import googletest
def _is_numeric_dtype_enum(datatype_enum):
- return (datatype_enum != types_pb2.DT_INVALID and
- datatype_enum != types_pb2.DT_RESOURCE and
- datatype_enum != types_pb2.DT_RESOURCE_REF)
+ non_numeric_dtypes = [types_pb2.DT_VARIANT,
+ types_pb2.DT_VARIANT_REF,
+ types_pb2.DT_INVALID,
+ types_pb2.DT_RESOURCE,
+ types_pb2.DT_RESOURCE_REF]
+ return datatype_enum not in non_numeric_dtypes
class TypesTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 6a035aba63..314449bb73 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -493,6 +493,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "variant"
+ mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
+ }
+ member {
name: "zeros_initializer"
mtype: "<type \'type\'>"
}