diff options
22 files changed, 1225 insertions, 9 deletions
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 3aeafb4685..46758408c4 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -117,6 +117,7 @@ typedef enum TF_DataType { TF_COMPLEX128 = 18, // Double-precision complex TF_HALF = 19, TF_RESOURCE = 20, + TF_VARIANT = 21, } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index a15621ba6d..b66eb35339 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -156,6 +156,11 @@ if (NOT tensorflow_ENABLE_GPU) "${tensorflow_source_dir}/tensorflow/core/platform/default/cuda_libdevice_path.*") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) endif() + +file(GLOB tf_core_platform_exclude_srcs + "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc") +list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_exclude_srcs}) + list(APPEND tf_core_lib_srcs ${tf_core_platform_srcs}) if(UNIX) @@ -225,6 +230,8 @@ set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.c file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/framework/*.h" "${tensorflow_source_dir}/tensorflow/core/framework/*.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.h" + "${tensorflow_source_dir}/tensorflow/core/platform/variant_coding.cc" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.h" "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 5885d4ed52..271b2492fe 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -94,6 +94,8 @@ load( "tf_additional_lib_deps", "tf_additional_lib_hdrs", "tf_additional_lib_srcs", + "tf_additional_framework_hdrs", + "tf_additional_framework_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_proto_hdrs", "tf_additional_proto_srcs", @@ -343,6 +345,8 @@ tf_cuda_library( hdrs = [ "example/feature_util.h", "framework/allocator.h", + "framework/variant.h", + "framework/variant_encode_decode.h", "framework/allocator_registry.h", "framework/attr_value_util.h", "framework/bfloat16.h", @@ -1265,6 +1269,8 @@ LIB_INTERNAL_WINDOWS_DEPS = glob( "platform/**/cuda_libdevice_path.cc", "platform/**/stream_executor.h", "platform/load_library.cc", + "platform/variant_coding.cc", + "platform/**/variant_cord_coding.cc", ], ) @@ -1286,6 +1292,8 @@ cc_library( ], exclude = [ "**/*test*", + "framework/variant.cc", + "platform/variant_coding.cc", "lib/hash/crc32c_accelerate.cc", "lib/gif/**/*", "lib/jpeg/**/*", @@ -1296,6 +1304,8 @@ cc_library( "platform/**/cuda_libdevice_path.cc", "platform/**/stream_executor.h", "platform/**/gpu_tracer.cc", + "platform/variant_coding.cc", + "platform/**/variant_cord_coding.cc", ], ), }) + tf_additional_lib_srcs( @@ -1306,6 +1316,8 @@ cc_library( "platform/**/stream_executor.h", "platform/**/env_time.cc", "platform/**/gpu_tracer.cc", + "platform/variant_coding.cc", + "platform/**/variant_cord_coding.cc", ] + # Protobuf deps already included through the ":lib_proto_parsing" # dependency. @@ -1462,6 +1474,8 @@ tf_cuda_library( "util/**/*.h", "util/**/*.cc", ] + [ + "platform/variant_coding.cc", + "platform/variant_coding.h", "graph/edgeset.h", "graph/edgeset.cc", "graph/graph.h", @@ -1490,20 +1504,22 @@ tf_cuda_library( "util/memmapped_file_system_writer.h", "util/memmapped_file_system_writer.cc", ], - }), + }) + tf_additional_framework_srcs(), hdrs = [ + "framework/variant.h", "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", + "platform/variant_coding.h", "util/command_line_flags.h", "util/env_var.h", "util/equal_graph_def.h", "util/presized_cuckoo_map.h", "util/tensor_slice_set.h", "util/tensor_slice_util.h", - ], + ] + tf_additional_framework_hdrs(), copts = tf_copts(), linkopts = select({ "//tensorflow:freebsd": [], @@ -1517,6 +1533,7 @@ tf_cuda_library( ":proto_text", ":protos_all_cc", ":version_lib", + "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/kernels:bounds_check", "//third_party/eigen3", ] + if_mkl(["//third_party/mkl:intel_binary_blob"]), @@ -2183,6 +2200,7 @@ tf_cc_tests( "common_runtime/simple_placer_test.cc", "example/feature_util_test.cc", "framework/allocator_test.cc", + "framework/variant_test.cc", "framework/attr_value_util_test.cc", "framework/bfloat16_test.cc", "framework/cancellation_test.cc", diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 1cb183d81e..868335b072 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -229,6 +230,14 @@ class Allocator { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } + virtual void RunVariantCtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); + } + + virtual void RunVariantDtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); + } + // TODO(jeff): Maybe provide some interface to give info about // current allocation state (total number of bytes available for // allocation, number of bytes free on device, etc.) @@ -256,6 +265,16 @@ inline void Allocator::RunDtor(ResourceHandle* p, size_t n) { RunResourceDtor(p, n); } +template <> +inline void Allocator::RunCtor(Variant* p, size_t n) { + RunVariantCtor(p, n); +} + +template <> +inline void Allocator::RunDtor(Variant* p, size_t n) { + RunVariantDtor(p, n); +} + // An implementation of Allocator that delegates all calls to another Allocator. // // Useful to clients who want to override part of the functionality of another diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 80b98fb9c6..0a85894071 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -36,6 +36,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor_description.pb.h" #include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -47,6 +49,7 @@ limitations under the License. #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/platform/variant_coding.h" namespace tensorflow { namespace { @@ -220,6 +223,36 @@ struct Helper<ResourceHandle> { } }; +template <> +struct Helper<Variant> { + // Encodes "n" elements of type Variant stored in "in" into destination + // "out", which is usually the TensorProto::tensor_content. + template <typename Destination> + static void Encode(TensorBuffer* in, int64 n, Destination* out) { + port::EncodeVariantList(in->base<const Variant>(), n, out); + } + + // Decodes "n" elements of type Variant from "in" and constructs a + // buffer out of it. Returns nullptr if the decoding fails. "in" is + // usually the TensorProto::tensor_content. + template <typename Source> + static TensorBuffer* Decode(Allocator* a, const Source& in, int64 n) { + auto* buf = new Buffer<Variant>(a, n); + Variant* ps = buf->template base<Variant>(); + if (ps == nullptr || !port::DecodeVariantList(in, ps, n)) { + buf->Unref(); + return nullptr; + } + return buf; + } + + // Returns the estimated memory usage of "n" elements of type T + // stored in buffer "in". + static int64 TotalBytes(TensorBuffer* in, int n) { + return n * sizeof(Variant); + } +}; + template <typename T> struct ProtoHelper {}; @@ -290,6 +323,26 @@ struct ProtoHelper<ResourceHandle> { }; template <> +struct ProtoHelper<Variant> { + static protobuf::RepeatedPtrField<VariantTensorDataProto>::const_iterator + Begin(const TensorProto& proto) { + return proto.variant_val().begin(); + } + static size_t NumElements(const TensorProto& proto) { + return proto.variant_val().size(); + } + static void Fill(const Variant* data, size_t n, TensorProto* proto) { + auto* variant_values = proto->mutable_variant_val(); + variant_values->Clear(); + for (size_t i = 0; i < n; ++i) { + VariantTensorData tmp; + data[i].Encode(&tmp); + tmp.ToProto(variant_values->Add()); + } + } +}; + +template <> struct ProtoHelper<complex64> { typedef Helper<float>::RepeatedFieldType FieldType; static const complex64* Begin(const TensorProto& proto) { @@ -421,6 +474,30 @@ TensorBuffer* FromProtoField(Allocator* a, const TensorProto& in, int64 n) { return buf; } +template <> +TensorBuffer* FromProtoField<Variant>(Allocator* a, const TensorProto& in, + int64 n) { + CHECK_GT(n, 0); + Buffer<Variant>* buf = new Buffer<Variant>(a, n); + Variant* data = buf->template base<Variant>(); + if (data == nullptr) { + buf->Unref(); + return nullptr; + } + const int64 in_n = ProtoHelper<Variant>::NumElements(in); + if (in_n <= 0) { + std::fill_n(data, n, Variant()); + } else { + for (int64 i = 0; i < in_n; ++i) { + data[i] = in.variant_val(i); + } + for (int64 i = in_n; i < n; ++i) { + data[i] = Variant(); + } + } + return buf; +} + // fp16 is opaque to the protobuf, so we deserialize these identical to uint16 // but with data stored in half_val instead of int_val (ie., we don't use // ProtoHelper<uint16>). @@ -571,6 +648,7 @@ bool Tensor::RefCountIsOne() const { CASE(bfloat16, SINGLE_ARG(STMTS)) \ CASE(Eigen::half, SINGLE_ARG(STMTS)) \ CASE(ResourceHandle, SINGLE_ARG(STMTS)) \ + CASE(Variant, SINGLE_ARG(STMTS)) \ case DT_INVALID: \ INVALID; \ break; \ diff --git a/tensorflow/core/framework/tensor.proto b/tensorflow/core/framework/tensor.proto index 98ecef225a..7e4af7a645 100644 --- a/tensorflow/core/framework/tensor.proto +++ b/tensorflow/core/framework/tensor.proto @@ -72,4 +72,17 @@ message TensorProto { // DT_RESOURCE repeated ResourceHandleProto resource_handle_val = 14; + + // DT_VARIANT + repeated VariantTensorDataProto variant_val = 15; }; + +// Protocol buffer representing the serialization format of DT_VARIANT tensors. +message VariantTensorDataProto { + // Name of the type of objects being serialized. + string type_name = 1; + // Portions of the object that are not Tensors. + bytes metadata = 2; + // Tensors contained within objects being serialized. + repeated TensorProto tensors = 3; +} diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 6c9c803af6..9aaf00853d 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" @@ -37,6 +38,35 @@ inline bool operator==(const ResourceHandle& a, const ResourceHandle& b) { a.maybe_type_name() == b.maybe_type_name(); } +inline bool operator==(const Variant& a, const Variant& b) { + if (a.is_empty()) { + return b.is_empty(); + } + + if (a.TypeId() != b.TypeId()) return false; + if (a.TypeName() != b.TypeName()) return false; + + VariantTensorData a_data, b_data; + a.Encode(&a_data); + b.Encode(&b_data); + + if (a_data.metadata != b_data.metadata) return false; + + if (a_data.tensors.size() != b_data.tensors.size()) return false; + + for (int i = 0; i < a_data.tensors.size(); ++i) { + TensorProto a_proto, b_proto; + a_data.tensors[i].AsProtoTensorContent(&a_proto); + b_data.tensors[i].AsProtoTensorContent(&b_proto); + string a_str, b_str; + a_proto.SerializeToString(&a_str); + b_proto.SerializeToString(&b_str); + if (a_str != b_str) return false; + } + + return true; +} + TEST(TensorTest, Default) { Tensor t; EXPECT_EQ(t.dtype(), DT_FLOAT); @@ -159,6 +189,74 @@ TEST(Tensor_ResourceHandle, Simple) { TestCopies<ResourceHandle>(t); } +TEST(Tensor_Variant, Simple) { + Tensor t(DT_VARIANT, TensorShape({})); + Tensor value(DT_FLOAT, TensorShape({})); + value.flat<float>()(0) = 42.0f; + t.flat<Variant>()(0) = value; + // All the tests in TestCopies except the ones that serialize and deserialize + // the tensor. The consumer of a serialized Variant Tensor should know what + // type is stored in the Tensor, so not testing the generic + // serialize/deserialize case here. + { + LOG(INFO) << "CopyFrom()"; + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.CopyFrom(t, t.shape())); + test::ExpectTensorEqual<Variant>(t, t2); + } + { + LOG(INFO) << "operator=()"; + Tensor t2(t.dtype()); + t2 = t; + test::ExpectTensorEqual<Variant>(t, t2); + } + { + LOG(INFO) << "deep copy"; + Tensor t2(t.dtype(), t.shape()); + t2.flat<Variant>() = t.flat<Variant>(); + test::ExpectTensorEqual<Variant>(t, t2); + } + { + LOG(INFO) << "AsTensor"; + gtl::ArraySlice<Variant> values(t.flat<Variant>().data(), t.NumElements()); + Tensor t2 = test::AsTensor(values, t.shape()); + test::ExpectTensorEqual<Variant>(t, t2); + } + { + LOG(INFO) << "Move constructor"; + Tensor t2 = t; + Tensor t3(std::move(t2)); + test::ExpectTensorEqual<Variant>(t, t3); + EXPECT_TRUE(t3.IsInitialized()); + EXPECT_FALSE(t2.IsInitialized()); + } + { + LOG(INFO) << "Move assignment"; + Tensor t2 = t; + Tensor t3 = std::move(t2); + Tensor* t4 = &t3; + *t4 = std::move(t3); + test::ExpectTensorEqual<Variant>(t, t3); + EXPECT_TRUE(t3.IsInitialized()); + EXPECT_FALSE(t2.IsInitialized()); + } +} + +TEST(Tensor_Variant, Marshal) { + Tensor t(DT_VARIANT, TensorShape({})); + + Tensor internal(DT_FLOAT, TensorShape({})); + internal.flat<float>()(0) = 42.0f; + t.flat<Variant>()(0) = internal; + + LOG(INFO) << "AsProtoField()"; + TensorProto proto; + t.AsProtoField(&proto); + + Tensor t2(t.dtype()); + EXPECT_TRUE(t2.FromProto(proto)); +} + TEST(Tensor_UInt16, Simple) { Tensor t(DT_UINT16, TensorShape({2, 2})); EXPECT_TRUE(t.shape().IsSameSize(TensorShape({2, 2}))); diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc index 28b5e8be11..39dd5b435e 100644 --- a/tensorflow/core/framework/types.cc +++ b/tensorflow/core/framework/types.cc @@ -87,6 +87,8 @@ string DataTypeString(DataType dtype) { return "half"; case DT_RESOURCE: return "resource"; + case DT_VARIANT: + return "variant"; default: LOG(ERROR) << "Unrecognized DataType enum value " << dtype; return strings::StrCat("unknown dtype enum (", dtype, ")"); @@ -165,6 +167,9 @@ bool DataTypeFromString(StringPiece sp, DataType* dt) { } else if (sp == "resource") { *dt = DT_RESOURCE; return true; + } else if (sp == "variant") { + *dt = DT_VARIANT; + return true; } return false; } @@ -186,7 +191,7 @@ DataTypeVector AllTypes() { return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128, DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16, - DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE}; + DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT}; } #if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION) diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 889056647c..9127750d68 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -181,6 +182,7 @@ MATCH_TYPE_AND_ENUM(qint32, DT_QINT32); MATCH_TYPE_AND_ENUM(bfloat16, DT_BFLOAT16); MATCH_TYPE_AND_ENUM(Eigen::half, DT_HALF); MATCH_TYPE_AND_ENUM(ResourceHandle, DT_RESOURCE); +MATCH_TYPE_AND_ENUM(Variant, DT_VARIANT); #undef MATCH_TYPE_AND_ENUM diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index b80e2b31dc..1beb2a1aa2 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -34,6 +34,7 @@ enum DataType { DT_COMPLEX128 = 18; // Double-precision complex DT_HALF = 19; DT_RESOURCE = 20; + DT_VARIANT = 21; // Arbitrary C++ data types // TODO(josh11b): DT_GENERIC_PROTO = ??; // TODO(jeff,josh11b): DT_UINT64? DT_UINT32? @@ -60,5 +61,6 @@ enum DataType { DT_COMPLEX128_REF = 118; DT_HALF_REF = 119; DT_RESOURCE_REF = 120; + DT_VARIANT_REF = 121; } // LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc new file mode 100644 index 0000000000..6a11f86c23 --- /dev/null +++ b/tensorflow/core/framework/variant.cc @@ -0,0 +1,87 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { + +template <> +void* Variant::get() { + if (is_empty()) { + return nullptr; + } + return value_->RawPtr(); +} + +template <> +const void* Variant::get() const { + if (is_empty()) { + return nullptr; + } + return value_->RawPtr(); +} + +void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { + proto->set_type_name(type_name); + proto->set_metadata(metadata); + proto->clear_tensors(); + for (int i = 0; i < tensors.size(); ++i) { + tensors[i].AsProtoField(proto->mutable_tensors()->Add()); + } +} + +bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { + type_name = proto.type_name(); + metadata = proto.metadata(); + tensors.clear(); + for (int i = 0; i < proto.tensors_size(); ++i) { + Tensor tmp; + if (!tmp.FromProto(proto.tensors(i))) return false; + tensors.push_back(tmp); + } + return true; +} + +template <> +string TypeNameVariant(VariantTensorDataProto&& value) { + return value.GetTypeName(); +} + +template <> +void EncodeVariant(VariantTensorDataProto&& value, VariantTensorData* data) { + data->FromProto(value); +} + +template <> +bool DecodeVariant(const VariantTensorData& data, + VariantTensorDataProto* value) { + data.ToProto(value); + return true; +} + +template <> +void EncodeVariant(VariantTensorDataProto&& value, string* buf) { + value.SerializeToString(buf); +} + +template <> +bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { + return value->ParseFromString(buf); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h new file mode 100644 index 0000000000..df6d903143 --- /dev/null +++ b/tensorflow/core/framework/variant.h @@ -0,0 +1,287 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_ +#define TENSORFLOW_FRAMEWORK_VARIANT_H_ + +#include <functional> +#include <iostream> +#include <memory> +#include <type_traits> +#include <unordered_map> +#include <utility> + +#include "tensorflow/core/framework/type_index.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +struct VariantTensorData; + +template <typename T> +string TypeNameVariant(T&& value); + +template <typename T> +void EncodeVariant(T&& value, VariantTensorData* data); + +template <typename T> +bool DecodeVariant(const VariantTensorData& data, T* value); + +template <typename T> +void EncodeVariant(T&& value, string* buf); + +template <typename T> +bool DecodeVariant(const string& buf, T* value); + +// This is an implementation of a type-erased container that can store an +// object of any type. The implementation is very similar to std::any, but has +// restrictions on the types of objects that can be stored, and eschews some of +// the fancier constructors available for std::any. An object of +// tensorflow::Variant is intended to be used as the value that will be stored +// in a tensorflow::Tensor object when its type is DT_VARIANT. +// +// tensorflow::Variant can store an object of a class that satisfies the +// following constraints: +// +// * The class is CopyConstructible. +// * The class has a default constructor. +// * It's either a protocol buffer, a tensorflow::Tensor, or defines the +// following functions: +// +// string TypeName() const; +// void Encode(VariantTensorData* data) const; +// void Decode(const VariantTensorData& data); +// +// Simple POD types can elide the Encode/Decode functions, they are provided by +// helper methods. +// Here are some typical usage patterns: +// +// Variant x = 10; +// EXPECT_EQ(*x.get<int>(), 10); +// +// Tensor t(DT_FLOAT, TensorShape({})); +// t.flat<float>()(0) = 42.0f; +// Variant x = t; +// EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f); +// +// Accessing the stored object: +// +// The get<T> function is the main mechanism to access the object stored in the +// contained. It is type-safe, that is, calling get<T> when the stored object's +// type is not T, returns a nullptr. A raw pointer to the stored object can be +// obtained by calling get<void>(). +// +// Serializing/deserializing Variant object: +// +// The Variant class delegates serializing and deserializing operations to the +// contained object. Helper functions to do these operations are provided for +// POD data types, tensorflow::Tensor, and protocol buffer objects. However, +// other classes have to provide Encode/Decode functions to do handle +// serialization. +// +// Objects stored in a Variant object often contain references to other +// tensorflow::Tensors of primitive types (Eg., a list of tensorflow::Tensors). +// To efficiently support those use cases, a structure is imposed on the +// serialization format. Namely, classes should serialize their contents in to a +// VariantTensorData object: +// +// struct VariantTensorData { +// string type_name; +// string metadata; +// std::vector<Tensor> tensors; +// }; +// +// Objects with references to other Tensors can simply store those tensors in +// the `tensors` field, and serialize other metadata content in to the +// `metadata` field. +// +// Serialization example: +// +// Foo f = Foo {...}; +// Variant x = f; +// string serialized_f; +// x.Encode(&serialized_f); +// +// Variant y = Foo(); // default constructed Foo. +// y.Decode(&serialized_f); +// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>()); +// +class Variant { + public: + constexpr Variant() noexcept = default; + + Variant(const Variant& other) + : value_(other.is_empty() ? std::unique_ptr<ValueInterface>() + : other.value_->Clone()) {} + + Variant(Variant&& other) noexcept = default; + + // Make sure that the type is CopyConstructible and not a tensorflow::Variant + // object itself. We want the copy constructor to be chosen for the + // tensorflow::Variant case. + template <typename T, typename VT = typename std::decay<T>::type, + typename std::enable_if<!std::is_same<Variant, VT>::value && + std::is_copy_constructible<VT>::value, + void>::type* = nullptr> + Variant(T&& value) // NOLINT + : value_(new Value<VT>(in_place, std::forward<T>(value))) {} + + Variant& operator=(const Variant& rhs) { + Variant(rhs).swap(*this); + return *this; + } + + Variant& operator=(Variant&& rhs) noexcept { + Variant(std::move(rhs)).swap(*this); + return *this; + } + + bool is_empty() const { return value_ == nullptr; } + + void clear() noexcept { value_.reset(); } + + void swap(Variant& other) noexcept { value_.swap(other.value_); } + + TypeIndex TypeId() const { + const TypeIndex VoidTypeIndex = MakeTypeIndex<void>(); + if (is_empty()) { + return VoidTypeIndex; + } + return value_->TypeId(); + } + + template <typename T> + T* get() { + const TypeIndex TTypeIndex = MakeTypeIndex<T>(); + if (is_empty() || (TTypeIndex != TypeId())) { + return nullptr; + } + return std::addressof(static_cast<Variant::Value<T>*>(value_.get())->value); + } + + template <typename T> + const T* get() const { + const TypeIndex TTypeIndex = MakeTypeIndex<T>(); + if (is_empty() || (TTypeIndex != TypeId())) { + return nullptr; + } + return std::addressof( + static_cast<const Variant::Value<T>*>(value_.get())->value); + } + + string TypeName() const { + if (is_empty()) { + return ""; + } + return value_->TypeName(); + } + + // Serialize the contents of the stored object into `data`. + void Encode(VariantTensorData* data) const { + if (!is_empty()) { + value_->Encode(data); + } + } + + // Deserialize `data` and update the stored object. + bool Decode(const VariantTensorData& data) { + if (!is_empty()) { + return value_->Decode(data); + } + return true; + } + + // Helper methods to directly serialize/deserialize from strings. + void Encode(string* buf) const { + if (!is_empty()) { + value_->Encode(buf); + } + } + bool Decode(const string& buf) { + if (!is_empty()) { + return value_->Decode(buf); + } + return true; + } + + private: + struct in_place_t {}; + static constexpr in_place_t in_place{}; + + struct ValueInterface { + virtual ~ValueInterface() = default; + virtual TypeIndex TypeId() const = 0; + virtual void* RawPtr() = 0; + virtual const void* RawPtr() const = 0; + virtual std::unique_ptr<ValueInterface> Clone() const = 0; + virtual string TypeName() const = 0; + virtual void Encode(VariantTensorData* data) const = 0; + virtual bool Decode(const VariantTensorData& data) = 0; + virtual void Encode(string* buf) const = 0; + virtual bool Decode(const string& data) = 0; + }; + + template <typename T> + struct Value : ValueInterface { + template <class... Args> + explicit Value(in_place_t /*tag*/, Args&&... args) + : value(std::forward<Args>(args)...) {} + + TypeIndex TypeId() const override { + const TypeIndex value_type_index = + MakeTypeIndex<typename std::decay<T>::type>(); + return value_type_index; + } + + void* RawPtr() override { return &value; } + + const void* RawPtr() const override { return &value; } + + std::unique_ptr<ValueInterface> Clone() const override { + return std::unique_ptr<ValueInterface>(new Value(in_place, value)); + } + + string TypeName() const override { return TypeNameVariant(value); } + + void Encode(VariantTensorData* data) const override { + EncodeVariant(value, data); + } + + bool Decode(const VariantTensorData& data) override { + return DecodeVariant(data, &value); + } + + void Encode(string* buf) const override { EncodeVariant(value, buf); } + + bool Decode(const string& buf) override { + return DecodeVariant(buf, &value); + } + + T value; + }; + + std::unique_ptr<ValueInterface> value_; +}; + +template <> +void* Variant::get(); + +template <> +const void* Variant::get() const; + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VARIANT_H_ diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h new file mode 100644 index 0000000000..09309aa7e5 --- /dev/null +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -0,0 +1,220 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ +#define TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ + +#include <iostream> +#include <type_traits> +#include <utility> +#include <vector> + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { + +// The serialization format for Variant objects. Objects with references to +// other Tensors can simply store those tensors in the `tensors` field, and +// serialize other metadata content in to the `metadata` field. Objects can +// optionally set the `type_name` for type-checking before deserializing an +// object. +struct VariantTensorData { + string type_name; + string metadata; + std::vector<Tensor> tensors; + void ToProto(VariantTensorDataProto* proto) const; + bool FromProto(const VariantTensorDataProto& proto); +}; + +// Type used for tag-dispatch of the Encode/Decode Variant implementations. This +// template can determine whether the first type parameter `T` is one of the +// following: +// +// * A POD type (TypeResolver<T, true>) +// * A tensorflow::Tensor (TypeResolver<T, false, true>) +// * A protocol buffer (TypeResolver<T, false, false, true>) +// * None of the above (TypeResolver<T, false, false, false>) +// +template <typename T, bool = std::is_pod<typename std::decay<T>::type>::value, + bool = std::is_same<typename std::decay<T>::type, + ::tensorflow::Tensor>::value, + bool = std::is_base_of<protobuf::MessageLite, + typename std::decay<T>::type>::value> +struct TypeResolver {}; + +// Specialization for POD type +template <typename T> +void EncodeVariantImpl(T&& value, TypeResolver<T, true /* is_pod */>, + VariantTensorData* data) { + data->metadata.assign(reinterpret_cast<const char*>(&value), sizeof(value)); +} + +// Specialization for tensorflow::Tensor +template <typename T> +void EncodeVariantImpl(T&& value, + TypeResolver<T, false /* is_pod */, true /* Tensor */>, + VariantTensorData* data) { + data->tensors.clear(); + data->tensors.push_back(value); +} + +// Specialization for protobuf +template <typename T> +void EncodeVariantImpl(T&& value, + TypeResolver<T, false /* is_pod */, false /* Tensor */, + true /* protobuf */>, + VariantTensorData* data) { + value.SerializeToString(&data->metadata); +} + +// Specialization for other types +template <typename T> +void EncodeVariantImpl(T&& value, + TypeResolver<T, false /* is_pod */, false /* Tensor */, + false /* protobuf */>, + VariantTensorData* data) { + value.Encode(data); +} + +// Specialization for POD type +template <typename T> +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver<T, true /* is_pod */>, T* value) { + std::copy_n(data.metadata.data(), sizeof(*value), + reinterpret_cast<char*>(value)); + return true; +} + +// Specialization for tensorflow::Tensor +template <typename T> +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver<T, false /* is_pod */, true /* Tensor */>, + T* value) { + *value = data.tensors[0]; + return true; +} + +// Specialization for protobuf +template <typename T> +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver<T, false /* is_pod */, false /* Tensor */, + true /* protobuf */>, + T* value) { + return value->ParseFromString(data.metadata); +} + +// Specialization for other types +template <typename T> +bool DecodeVariantImpl(const VariantTensorData& data, + TypeResolver<T, false /* is_pod */, false /* Tensor */, + false /* protobuf */>, + T* value) { + return value->Decode(data); +} + +template <typename C, typename = void> +struct has_type_name : std::false_type {}; + +template <typename C> +struct has_type_name< + C, typename std::enable_if<std::is_same< + decltype(std::declval<C>().TypeName()), string>::value>::type> + : std::true_type {}; + +template <typename T, bool = has_type_name<typename std::decay<T>::type>::value, + bool = std::is_same<typename std::decay<T>::type, + ::tensorflow::Tensor>::value, + bool = std::is_base_of<protobuf::MessageLite, + typename std::decay<T>::type>::value> +struct TypeNameResolver {}; + +template <typename T> +string TypeNameVariantImpl(T&& value, + TypeNameResolver<T, true /* has_type_name */>) { + return value.TypeName(); +} + +template <typename T> +string TypeNameVariantImpl( + T&& value, + TypeNameResolver<T, false /* has_type_name */, true /* Tensor */>) { + return "tensorflow::Tensor"; +} + +template <typename T> +string TypeNameVariantImpl( + T&& value, TypeNameResolver<T, false /* has_type_name */, + false /* Tensor */, true /* protobuf */>) { + return value.GetTypeName(); +} + +template <typename T> +string TypeNameVariantImpl( + T&& value, TypeNameResolver<T, false /* has_type_name */, + false /* Tensor */, false /* protobuf */>) { + return value.TypeName(); +} + +template <typename T> +string TypeNameVariant(T&& value) { + return TypeNameVariantImpl(std::forward<T>(value), TypeNameResolver<T>()); +} + +template <typename T> +void EncodeVariant(T&& value, VariantTensorData* data) { + EncodeVariantImpl(std::forward<T>(value), TypeResolver<T>(), data); +} + +template <typename T> +bool DecodeVariant(const VariantTensorData& data, T* value) { + return DecodeVariantImpl(data, TypeResolver<T>(), value); +} + +template <typename T> +void EncodeVariant(T&& value, string* buf) { + VariantTensorData data; + EncodeVariantImpl(std::forward<T>(value), TypeResolver<T>(), &data); + VariantTensorDataProto proto; + data.ToProto(&proto); + proto.SerializeToString(buf); +} + +template <typename T> +bool DecodeVariant(const string& buf, T* value) { + VariantTensorDataProto proto; + if (!proto.ParseFromString(buf)) return false; + VariantTensorData data; + if (!data.FromProto(proto)) return false; + if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false; + return true; +} + +// Specializations for VariantTensorDataProto +template <> +string TypeNameVariant(VariantTensorDataProto&& value); +template <> +void EncodeVariant(VariantTensorDataProto&& value, VariantTensorData* data); +template <> +bool DecodeVariant(const VariantTensorData& data, + VariantTensorDataProto* value); +template <> +void EncodeVariant(VariantTensorDataProto&& value, string* buf); +template <> +bool DecodeVariant(const string& buf, VariantTensorDataProto* value); + +} // end namespace tensorflow + +#endif // TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_ diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc new file mode 100644 index 0000000000..c7ffdd28f4 --- /dev/null +++ b/tensorflow/core/framework/variant_test.cc @@ -0,0 +1,249 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <vector> + +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +template <typename T> +struct Wrapper { + T value; + string TypeName() const { return "POD"; } +}; + +using Int = Wrapper<int>; +using Float = Wrapper<float>; + +} // end namespace + +TEST(VariantTest, Basic) { + Variant x; + EXPECT_EQ(x.get<void>(), nullptr); + + x = Int{42}; + + EXPECT_NE(x.get<void>(), nullptr); + EXPECT_NE(x.get<Int>(), nullptr); + EXPECT_EQ(x.get<Int>()->value, 42); + EXPECT_EQ(x.TypeName(), "POD"); +} + +TEST(VariantTest, ConstGet) { + Variant x; + EXPECT_EQ(x.get<void>(), nullptr); + + x = Int{42}; + + const Variant y = x; + + EXPECT_NE(y.get<void>(), nullptr); + EXPECT_NE(y.get<Int>(), nullptr); + EXPECT_EQ(y.get<Int>()->value, 42); +} + +TEST(VariantTest, Clear) { + Variant x; + EXPECT_EQ(x.get<void>(), nullptr); + + x = Int{42}; + + EXPECT_NE(x.get<void>(), nullptr); + EXPECT_NE(x.get<Int>(), nullptr); + EXPECT_EQ(x.get<Int>()->value, 42); + + x.clear(); + EXPECT_EQ(x.get<void>(), nullptr); +} + +TEST(VariantTest, Tensor) { + Variant x; + Tensor t(DT_FLOAT, {}); + t.flat<float>()(0) = 42.0f; + x = t; + + EXPECT_NE(x.get<Tensor>(), nullptr); + EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 42.0f); + x.get<Tensor>()->flat<float>()(0) += 1.0f; + EXPECT_EQ(x.get<Tensor>()->flat<float>()(0), 43.0f); + EXPECT_EQ(x.TypeName(), "tensorflow::Tensor"); +} + +TEST(VariantTest, TensorProto) { + Variant x; + TensorProto t; + t.set_dtype(DT_FLOAT); + t.mutable_tensor_shape()->set_unknown_rank(true); + x = t; + + EXPECT_EQ(x.TypeName(), "tensorflow.TensorProto"); + EXPECT_NE(x.get<TensorProto>(), nullptr); + EXPECT_EQ(x.get<TensorProto>()->dtype(), DT_FLOAT); + EXPECT_EQ(x.get<TensorProto>()->tensor_shape().unknown_rank(), true); +} + +TEST(VariantTest, CopyValue) { + Variant x, y; + x = Int{10}; + y = x; + + EXPECT_EQ(x.get<Int>()->value, 10); + EXPECT_EQ(x.get<Int>()->value, y.get<Int>()->value); +} + +TEST(VariantTest, MoveValue) { + Variant x; + x = []() -> Variant { + Variant y; + y = Int{10}; + return y; + }(); + EXPECT_EQ(x.get<Int>()->value, 10); +} + +TEST(VariantTest, TypeMismatch) { + Variant x; + x = Int{10}; + EXPECT_EQ(x.get<float>(), nullptr); + EXPECT_EQ(x.get<int>(), nullptr); + EXPECT_NE(x.get<Int>(), nullptr); +} + +struct TensorList { + void Encode(VariantTensorData* data) const { data->tensors = vec; } + + bool Decode(const VariantTensorData& data) { + vec = data.tensors; + return true; + } + + string TypeName() const { return "TensorList"; } + + std::vector<Tensor> vec; +}; + +TEST(VariantTest, TensorListTest) { + Variant x; + + TensorList vec; + for (int i = 0; i < 4; ++i) { + Tensor elem(DT_INT32, {1}); + elem.flat<int>()(0) = i; + vec.vec.push_back(elem); + } + + for (int i = 0; i < 4; ++i) { + Tensor elem(DT_FLOAT, {1}); + elem.flat<float>()(0) = 2 * i; + vec.vec.push_back(elem); + } + + x = vec; + + EXPECT_EQ(x.TypeName(), "TensorList"); + const TensorList& stored_vec = *x.get<TensorList>(); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(stored_vec.vec[i].flat<int>()(0), i); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(stored_vec.vec[i + 4].flat<float>()(0), 2 * i); + } + + VariantTensorData serialized; + x.Encode(&serialized); + + Variant y = TensorList(); + y.Decode(serialized); + + const TensorList& decoded_vec = *x.get<TensorList>(); + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(decoded_vec.vec[i].flat<int>()(0), i); + } + for (int i = 0; i < 4; ++i) { + EXPECT_EQ(decoded_vec.vec[i + 4].flat<float>()(0), 2 * i); + } +} + +TEST(VariantTest, VariantArray) { + Variant x[2]; + x[0] = Int{2}; + x[1] = Float{2.0f}; + + EXPECT_EQ(x[0].get<Int>()->value, 2); + EXPECT_EQ(x[1].get<Float>()->value, 2.0f); +} + +TEST(VariantTest, PodUpdate) { + struct Pod { + int x; + float y; + + string TypeName() const { return "POD"; } + }; + + Variant x = Pod{10, 20.f}; + EXPECT_NE(x.get<Pod>(), nullptr); + EXPECT_EQ(x.TypeName(), "POD"); + + x.get<Pod>()->x += x.get<Pod>()->y; + EXPECT_EQ(x.get<Pod>()->x, 30); +} + +TEST(VariantTest, EncodeDecodePod) { + struct Pod { + int x; + float y; + + string TypeName() const { return "POD"; } + }; + + Variant x; + Pod p{10, 20.0f}; + x = p; + + VariantTensorData serialized; + x.Encode(&serialized); + + Variant y; + y = Pod(); + y.Decode(serialized); + + EXPECT_EQ(p.x, y.get<Pod>()->x); + EXPECT_EQ(p.y, y.get<Pod>()->y); +} + +TEST(VariantTest, EncodeDecodeTensor) { + Variant x; + Tensor t(DT_INT32, {}); + t.flat<int>()(0) = 42; + x = t; + + VariantTensorData serialized; + x.Encode(&serialized); + + Variant y = Tensor(); + y.Decode(serialized); + EXPECT_EQ(x.get<Tensor>()->flat<int>()(0), y.get<Tensor>()->flat<int>()(0)); +} + +} // end namespace tensorflow diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index d12ad8a04f..5db0fe423c 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -130,6 +130,14 @@ def tf_additional_lib_srcs(exclude = []): ], exclude = exclude), }) +# pylint: disable=unused-argument +def tf_additional_framework_hdrs(exclude = []): + return [] + +def tf_additional_framework_srcs(exclude = []): + return [] +# pylint: enable=unused-argument + def tf_additional_minimal_lib_srcs(): return [ "platform/default/integral_types.h", diff --git a/tensorflow/core/platform/variant_coding.cc b/tensorflow/core/platform/variant_coding.cc new file mode 100644 index 0000000000..4bcde4f581 --- /dev/null +++ b/tensorflow/core/platform/variant_coding.cc @@ -0,0 +1,63 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/platform/variant_coding.h" + +#include <vector> +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/lib/core/coding.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace port { + +void EncodeVariantList(const Variant* variant_array, int64 n, string* out) { + out->clear(); + string rest; + for (int i = 0; i < n; ++i) { + string s; + variant_array[i].Encode(&s); + core::PutVarint32(out, s.length()); + strings::StrAppend(&rest, s); + } + strings::StrAppend(out, rest); +} + +bool DecodeVariantList(const string& in, Variant* variant_array, int64 n) { + std::vector<uint32> sizes(n); + StringPiece reader(in); + int64 total = 0; + for (auto& size : sizes) { + if (!core::GetVarint32(&reader, &size)) return false; + total += size; + } + if (total != static_cast<int64>(reader.size())) { + return false; + } + + for (int i = 0; i < n; ++i) { + if (variant_array[i].is_empty()) { + variant_array[i] = VariantTensorDataProto(); + } + string str(reader.data(), sizes[i]); + if (!variant_array[i].Decode(str)) return false; + reader.remove_prefix(sizes[i]); + } + return true; +} + +} // end namespace port +} // end namespace tensorflow diff --git a/tensorflow/core/platform/variant_coding.h b/tensorflow/core/platform/variant_coding.h new file mode 100644 index 0000000000..34c2481149 --- /dev/null +++ b/tensorflow/core/platform/variant_coding.h @@ -0,0 +1,39 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_PLATFORM_VARIANT_CODING_H_ +#define TENSORFLOW_PLATFORM_VARIANT_CODING_H_ + +#include "tensorflow/core/framework/variant.h" + +#ifdef PLATFORM_GOOGLE +#include "tensorflow/core/platform/google/variant_cord_coding.h" +#endif + +namespace tensorflow { +namespace port { + +// Encodes an array of Variant objects in to the given string. +// `variant_array` is assumed to point to an array of `n` Variant objects. +void EncodeVariantList(const Variant* variant_array, int64 n, string* out); + +// Decodes an array of Variant objects from the given string. +// `variant_array` is assumed to point to an array of `n` Variant objects. +bool DecodeVariantList(const string& in, Variant* variant_array, int64 n); + +} // end namespace port +} // end namespace tensorflow + +#endif // TENSORFLOW_PLATFORM_VARIANT_CODING_H_ diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 4a60c736b5..8b8909a6f8 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -227,6 +227,7 @@ var types = []struct { {reflect.TypeOf(uint16(0)), C.TF_UINT16}, {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128}, // TODO(apassos): support DT_RESOURCE representation in go. + // TODO(keveman): support DT_VARIANT representation in go. } // shapeAndDataTypeOf returns the data type and shape of the Tensor diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 827e4c53ee..2fd8fc8688 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -208,6 +208,7 @@ _allowed_symbols.extend([ 'uint16', 'uint8', 'resource', + 'variant', ]) # Export modules and constants. diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py index 00bfae213a..43535a593e 100644 --- a/tensorflow/python/framework/dtypes.py +++ b/tensorflow/python/framework/dtypes.py @@ -48,6 +48,7 @@ class DType(object): * `tf.quint16`: Quantized 16-bit unsigned integer. * `tf.qint32`: Quantized 32-bit signed integer. * `tf.resource`: Handle to a mutable resource. + * `tf.variant`: Values of arbitrary types. In addition, variants of these types with the `_ref` suffix are defined for reference-typed tensors. @@ -113,8 +114,11 @@ class DType(object): @property def is_numpy_compatible(self): - return (self._type_enum != types_pb2.DT_RESOURCE and - self._type_enum != types_pb2.DT_RESOURCE_REF) + numpy_incompatible = [types_pb2.DT_VARIANT, + types_pb2.DT_VARIANT_REF, + types_pb2.DT_RESOURCE, + types_pb2.DT_RESOURCE_REF] + return self._type_enum not in numpy_incompatible @property def as_numpy_dtype(self): @@ -284,7 +288,8 @@ class DType(object): @property def size(self): - if self._type_enum == types_pb2.DT_RESOURCE: + if (self._type_enum == types_pb2.DT_VARIANT or + self._type_enum == types_pb2.DT_RESOURCE): return 1 return np.dtype(self.as_numpy_dtype).itemsize @@ -304,6 +309,7 @@ dtype_range = {np.bool_: (False, True), # Define standard wrappers for the types_pb2.DataType enum. resource = DType(types_pb2.DT_RESOURCE) +variant = DType(types_pb2.DT_VARIANT) float16 = DType(types_pb2.DT_HALF) half = float16 float32 = DType(types_pb2.DT_FLOAT) @@ -325,6 +331,7 @@ qint16 = DType(types_pb2.DT_QINT16) quint16 = DType(types_pb2.DT_QUINT16) qint32 = DType(types_pb2.DT_QINT32) resource_ref = DType(types_pb2.DT_RESOURCE_REF) +variant_ref = DType(types_pb2.DT_VARIANT_REF) bfloat16 = DType(types_pb2.DT_BFLOAT16) float16_ref = DType(types_pb2.DT_HALF_REF) half_ref = float16_ref @@ -372,6 +379,7 @@ _INTERN_TABLE = { types_pb2.DT_QINT32: qint32, types_pb2.DT_BFLOAT16: bfloat16, types_pb2.DT_RESOURCE: resource, + types_pb2.DT_VARIANT: variant, types_pb2.DT_HALF_REF: float16_ref, types_pb2.DT_FLOAT_REF: float32_ref, types_pb2.DT_DOUBLE_REF: float64_ref, @@ -392,6 +400,7 @@ _INTERN_TABLE = { types_pb2.DT_QINT32_REF: qint32_ref, types_pb2.DT_BFLOAT16_REF: bfloat16_ref, types_pb2.DT_RESOURCE_REF: resource_ref, + types_pb2.DT_VARIANT_REF: variant_ref, } @@ -417,6 +426,7 @@ _TYPE_TO_STRING = { types_pb2.DT_QINT32: "qint32", types_pb2.DT_BFLOAT16: "bfloat16", types_pb2.DT_RESOURCE: "resource", + types_pb2.DT_VARIANT: "variant", types_pb2.DT_HALF_REF: "float16_ref", types_pb2.DT_FLOAT_REF: "float32_ref", types_pb2.DT_DOUBLE_REF: "float64_ref", @@ -437,6 +447,7 @@ _TYPE_TO_STRING = { types_pb2.DT_QINT32_REF: "qint32_ref", types_pb2.DT_BFLOAT16_REF: "bfloat16_ref", types_pb2.DT_RESOURCE_REF: "resource_ref", + types_pb2.DT_VARIANT_REF: "variant_ref", } _STRING_TO_TF = {value: _INTERN_TABLE[key] for key, value in _TYPE_TO_STRING.items()} diff --git a/tensorflow/python/framework/dtypes_test.py b/tensorflow/python/framework/dtypes_test.py index d5a1e61a6f..1e84f1b656 100644 --- a/tensorflow/python/framework/dtypes_test.py +++ b/tensorflow/python/framework/dtypes_test.py @@ -27,9 +27,12 @@ from tensorflow.python.platform import googletest def _is_numeric_dtype_enum(datatype_enum): - return (datatype_enum != types_pb2.DT_INVALID and - datatype_enum != types_pb2.DT_RESOURCE and - datatype_enum != types_pb2.DT_RESOURCE_REF) + non_numeric_dtypes = [types_pb2.DT_VARIANT, + types_pb2.DT_VARIANT_REF, + types_pb2.DT_INVALID, + types_pb2.DT_RESOURCE, + types_pb2.DT_RESOURCE_REF] + return datatype_enum not in non_numeric_dtypes class TypesTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 6a035aba63..314449bb73 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -493,6 +493,10 @@ tf_module { mtype: "<type \'type\'>" } member { + name: "variant" + mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>" + } + member { name: "zeros_initializer" mtype: "<type \'type\'>" } |