diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-09-11 10:41:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 10:51:01 -0700 |
commit | 36e1a5ea5ba2dd5eaa7f4cfc84a61f8ce3ea20e1 (patch) | |
tree | 4f1671f78f5971b02dc2af66f57eabbf01005112 /tensorflow/core/framework | |
parent | 36d7b12357df667dcd427c070e21779ed83f4ec9 (diff) |
[TF] Variant improvements.
1. Change Variant Decode to accept VariantTensorData (non-ref).
This should allow some optimization in the future.
In the meantime it means removing the variant.h include from tensor.h, since
variant_encode_decode.h now relies on tensor.h, and variant.h now relies on
variant_encode_decode.h.
It also means we found a bunch of places where tensor.proto.h, variant.h, and
mutex.h were being imported through tensor.h (along with a bunch of other crap);
so now we directly import them in order to compile.
2. Move Variant registry to use TypeIndex instead of a TypeName string; this should
speed up registry lookups.
PiperOrigin-RevId: 212478896
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/allocator.cc | 9 | ||||
-rw-r--r-- | tensorflow/core/framework/allocator.h | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/allocator_registry.h | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/attr_value_util_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor.h | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/tensor_util.h | 1 | ||||
-rw-r--r-- | tensorflow/core/framework/types.h | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/variant.cc | 25 | ||||
-rw-r--r-- | tensorflow/core/framework/variant.h | 60 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_encode_decode.h | 32 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_op_copy_test.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_op_registry.cc | 85 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_op_registry.h | 216 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_op_registry_test.cc | 96 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_tensor_data.cc | 22 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_tensor_data.h | 10 | ||||
-rw-r--r-- | tensorflow/core/framework/variant_test.cc | 15 |
18 files changed, 302 insertions, 295 deletions
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 888ed0c57b..2a7ee16a16 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } +void Allocator::RunVariantCtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); +} + +void Allocator::RunVariantDtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); +} + // If true, cpu allocator collects more stats. static bool cpu_allocator_collect_stats = false; // If true, cpu allocator collects full stats. diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 774b1fe137..ded120b704 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -23,12 +23,13 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { +class Variant; + // Attributes for a single allocation call. Different calls to the same // allocator could potentially have different allocation attributes. 
struct AllocationAttributes { @@ -228,13 +229,9 @@ class Allocator { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } - virtual void RunVariantCtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); - } + virtual void RunVariantCtor(Variant* p, size_t n); - virtual void RunVariantDtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); - } + virtual void RunVariantDtor(Variant* p, size_t n); // TODO(jeff): Maybe provide some interface to give info about // current allocation state (total number of bytes available for diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h index 24f282ce84..e907c52ba9 100644 --- a/tensorflow/core/framework/allocator_registry.h +++ b/tensorflow/core/framework/allocator_registry.h @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" namespace tensorflow { diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index 1a3994736c..4ffd732f8e 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <numeric> #include <vector> #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 1b19ab5da3..696fd277cd 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -37,11 +37,12 @@ namespace tensorflow { class AllocationDescription; class Allocator; class OpKernelContext; +class Tensor; class TensorBuffer; class TensorCApi; class TensorDescription; class TensorProto; -class VariantTensorData; + namespace batch_util { Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 84a373c196..9a78cdc91e 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/math/math_util.h" diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 4bda8f9eb8..a7cf600bab 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include <vector> diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 15b1add2c1..2e96b05787 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -39,6 +38,8 @@ limitations under the License. namespace tensorflow { +class Variant; + // MemoryType is used to describe whether input or output Tensors of // an OpKernel should reside in "Host memory" (e.g., CPU memory) or // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index 5a507804b0..d43e3c72ec 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -23,11 +23,11 @@ limitations under the License. 
namespace tensorflow { -bool Variant::TryDecode(Variant* out) const { - const VariantTensorDataProto* p = get<VariantTensorDataProto>(); - if (p == nullptr) return false; - VariantTensorData data(*p); - return out->Decode(data); +bool Variant::Decode(VariantTensorData data) { + if (!is_empty()) { + return value_->Decode(std::move(data)); + } + return true; } template <> @@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) { template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) { - data->FromProto(value); + data->FromConstProto(value); } template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value) { - data.ToProto(value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) { + data->ToProto(value); return true; } @@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) { } template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { - return value->ParseFromString(buf); +bool DecodeVariant(string* buf, VariantTensorDataProto* value) { + return value->ParseFromString(*buf); } void EncodeVariantList(const Variant* variant_array, int64 n, @@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d, if (variant_array[i].is_empty()) { variant_array[i] = VariantTensorDataProto(); } + // TODO(ebrevdo): Replace with StringPiece? Any way to make this a + // zero-copy operation that keeps a reference to the data in d? 
string str(d->Data(sizes[i]), sizes[i]); - if (!variant_array[i].Decode(str)) return false; + if (!variant_array[i].Decode(std::move(str))) return false; if (!DecodeUnaryVariant(&variant_array[i])) { LOG(ERROR) << "Could not decode variant with type_name: \"" << variant_array[i].TypeName() diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index 52732801a0..10eabbc85f 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -23,7 +23,6 @@ limitations under the License. #include <unordered_map> #include <utility> -#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/core/status.h" @@ -38,17 +37,19 @@ string TypeNameVariant(const T& value); template <typename T> string DebugStringVariant(const T& value); +// Allows for specializations of Variant Decoding. `data` may be modified in +// the process of decoding to `value`. template <typename T> -void EncodeVariant(const T& value, VariantTensorData* data); +bool DecodeVariant(VariantTensorData* data, T* value); template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value); +bool DecodeVariant(string* buf, T* value); template <typename T> -void EncodeVariant(const T& value, string* buf); +void EncodeVariant(const T& value, VariantTensorData* data); template <typename T> -bool DecodeVariant(const string& buf, T* value); +void EncodeVariant(const T& value, string* buf); // This is an implementation of a type-erased container that can store an // object of any type. 
The implementation is very similar to std::any, but has @@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value); // // string TypeName() const; // void Encode(VariantTensorData* data) const; -// void Decode(const VariantTensorData& data); +// void Decode(VariantTensorData data); // // Simple POD types can elide the Encode/Decode functions, they are provided by // helper methods. @@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value); // x.Encode(&serialized_f); // // Variant y = Foo(); // default constructed Foo. -// y.Decode(&serialized_f); +// y.Decode(std::move(serialized_f)); // EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>()); // // @@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value); // EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. // EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(), // y_type_unknown.TypeId()); -// // Decode and get y_type_unknown; compare to value in x. -// Foo f_decoded; -// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded)); -// EXPECT_EQ(f_decoded, f); // class Variant { public: @@ -241,12 +238,7 @@ class Variant { } // Deserialize `data` and update the stored object. - bool Decode(const VariantTensorData& data) { - if (!is_empty()) { - return value_->Decode(data); - } - return true; - } + bool Decode(VariantTensorData data); // Helper methods to directly serialize/deserialize from strings. 
void Encode(string* buf) const { @@ -254,31 +246,13 @@ class Variant { value_->Encode(buf); } } - bool Decode(const string& buf) { + bool Decode(string buf) { if (!is_empty()) { - return value_->Decode(buf); + return value_->Decode(std::move(buf)); } return true; } - template <typename T> - bool MaybeDecodeAndCopy(T* out) const { - const T* ret = get<T>(); - if (ret != nullptr) { - *out = std::move(*ret); - return true; - }; - Variant decoded = T(); - if (!TryDecode(&decoded)) return false; - T* decoded_ret = decoded.get<T>(); - CHECK_NOTNULL(decoded_ret); - *out = std::move(*decoded_ret); - return true; - } - - private: - bool TryDecode(Variant* out) const; - private: struct in_place_t {}; static constexpr in_place_t in_place{}; @@ -292,9 +266,9 @@ class Variant { virtual string TypeName() const = 0; virtual string DebugString() const = 0; virtual void Encode(VariantTensorData* data) const = 0; - virtual bool Decode(const VariantTensorData& data) = 0; + virtual bool Decode(VariantTensorData data) = 0; virtual void Encode(string* buf) const = 0; - virtual bool Decode(const string& data) = 0; + virtual bool Decode(string data) = 0; }; template <typename T> @@ -325,15 +299,13 @@ class Variant { EncodeVariant(value, data); } - bool Decode(const VariantTensorData& data) override { - return DecodeVariant(data, &value); + bool Decode(VariantTensorData data) override { + return DecodeVariant(&data, &value); } void Encode(string* buf) const override { EncodeVariant(value, buf); } - bool Decode(const string& buf) override { - return DecodeVariant(buf, &value); - } + bool Decode(string buf) override { return DecodeVariant(&buf, &value); } T value; }; diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h index f155aa4892..5e08e5a7a6 100644 --- a/tensorflow/core/framework/variant_encode_decode.h +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -22,6 +22,7 @@ limitations under the License. 
#include <vector> #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" @@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value, // Specialization for POD type template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, true /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { @@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for tensorflow::Tensor template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, true /* Tensor */, false /* protobuf */>, T* value) { @@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for protobuf template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, true /* protobuf */>, T* value) { @@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for other types template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { - return value->Decode(data); + return value->Decode(std::move(data)); } template <typename C, typename = void> @@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) { } template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value) { - return DecodeVariantImpl(data, TypeResolver<T>(), value); +bool DecodeVariant(VariantTensorData* data, T* value) { + return DecodeVariantImpl(std::move(*data), 
TypeResolver<T>(), value); } template <typename T> @@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) { } template <typename T> -bool DecodeVariant(const string& buf, T* value) { +bool DecodeVariant(string* buf, T* value) { VariantTensorData data; - if (!data.ParseFromString(buf)) return false; - if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false; + if (!data.ParseFromString(*buf)) return false; + if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) { + return false; + } return true; } // Specializations for VariantTensorDataProto template <> string TypeNameVariant(const VariantTensorDataProto& value); + template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data); + template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); + template <> void EncodeVariant(const VariantTensorDataProto& value, string* buf); + template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value); +bool DecodeVariant(string* buf, VariantTensorDataProto* value); // Encodes an array of Variant objects in to the given StringListEncoder. // `variant_array` is assumed to point to an array of `n` Variant objects. 
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 60fa7bd559..daa744e877 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue"); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyCPUToGPU); + StoredTensorValue::CopyCPUToGPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST, - "StoredTensorValue", StoredTensorValue::CopyGPUToCPU); + StoredTensorValue::CopyGPUToCPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyGPUToGPU); + StoredTensorValue::CopyGPUToGPU); REGISTER_OP("CreateTestVariant") .Input("input: T") diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index ee07db1aee..ef5b240aea 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { } UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( - StringPiece type_name) { - auto found = shape_fns.find(type_name); + const TypeIndex& type_index) { + auto found = shape_fns.find(type_index); if (found == shape_fns.end()) return nullptr; return &found->second; } -void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, +void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index, const VariantShapeFn& shape_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape"; - VariantShapeFn* existing = GetShapeFn(type_name); + 
VariantShapeFn* existing = GetShapeFn(type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantShapeFn for type_name: " << type_name - << " already registered"; - shape_fns.insert(std::pair<StringPiece, VariantShapeFn>( - GetPersistentStringPiece(type_name), shape_fn)); + << "Unary VariantShapeFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered"; + shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn)); } Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { @@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { CHECK_EQ(variant_tensor.dims(), 0); const Variant& v = variant_tensor.scalar<Variant>()(); UnaryVariantOpRegistry::VariantShapeFn* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName()); + UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId()); if (shape_fn == nullptr) { return errors::Internal( - "No unary variant shape function found for Variant type_name: ", - v.TypeName()); + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(v.TypeId().name())); } return (*shape_fn)(v, shape); } @@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) { } // namespace #define REGISTER_VARIANT_SHAPE_TYPE(T) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>); + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>); // No encode/shape registered for std::complex<> and Eigen::half // objects yet. 
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double); UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* UnaryVariantOpRegistry::GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name) { - auto found = device_copy_fns.find(std::make_pair(direction, type_name)); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index) { + auto found = device_copy_fns.find(std::make_pair(direction, type_index)); if (found == device_copy_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterDeviceCopyFn( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy"; - AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name); + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index); CHECK_EQ(existing, nullptr) << "UnaryVariantDeviceCopy for direction: " << direction - << " and type_name: " << type_name << " already registered"; + << " and type_index: " << port::MaybeAbiDemangle(type_index.name()) + << " already registered"; device_copy_fns.insert( - std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn>( - std::make_pair(direction, GetPersistentStringPiece(type_name)), - device_copy_fn)); + std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index), + device_copy_fn)); } Status VariantDeviceCopy( @@ -170,35 +167,34 @@ Status VariantDeviceCopy( const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction, - from.TypeName()); + from.TypeId()); if (device_copy_fn == nullptr) { return errors::Internal( "No unary variant device 
copy function found for direction: ", - direction, " and Variant type_name: ", from.TypeName()); + direction, " and Variant type_index: ", + port::MaybeAbiDemangle(from.TypeId().name())); } return (*device_copy_fn)(from, to, copy_fn); } // Special casing UnaryOpFn per op and per device. UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( - VariantUnaryOp op, StringPiece device, StringPiece type_name) { - auto found = unary_op_fns.find({op, device, type_name}); + VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) { + auto found = unary_op_fns.find({op, device, type_index}); if (found == unary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterUnaryOpFn( - VariantUnaryOp op, const string& device, const string& type_name, + VariantUnaryOp op, const string& device, const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; - VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantUnaryOpFn for type_name: " << type_name + << "Unary VariantUnaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - unary_op_fn)); + {op, GetPersistentStringPiece(device), type_index}, unary_op_fn)); } namespace { @@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ - DEVICE_CPU, T, TF_STR(T), \ + DEVICE_CPU, T, \ ZerosLikeVariantPrimitiveType<T>); // No zeros_like registered for std::complex<> or Eigen::half objects yet. 
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); // Special casing BinaryOpFn per op and per device. UnaryVariantOpRegistry::VariantBinaryOpFn* UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name) { - auto found = binary_op_fns.find({op, device, type_name}); + const TypeIndex& type_index) { + auto found = binary_op_fns.find({op, device, type_index}); if (found == binary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterBinaryOpFn( - VariantBinaryOp op, const string& device, const string& type_name, + VariantBinaryOp op, const string& device, const TypeIndex& type_index, const VariantBinaryOpFn& add_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; - VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantBinaryOpFn for type_name: " << type_name + << "Unary VariantBinaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - add_fn)); + {op, GetPersistentStringPiece(device), type_index}, add_fn)); } namespace { @@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, #define REGISTER_VARIANT_ADD_TYPE(T) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ - T, TF_STR(T), \ - AddVariantPrimitiveType<T>); + T, AddVariantPrimitiveType<T>); // No add registered for std::complex<> or Eigen::half objects yet. 
REGISTER_VARIANT_ADD_TYPE(int); diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index e6a2665a56..7eb37e859f 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -22,10 +22,14 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/abi.h" namespace tensorflow { @@ -90,10 +94,11 @@ class UnaryVariantOpRegistry { AsyncVariantDeviceCopyFn; // Add a shape lookup function to the registry. - void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); + void RegisterShapeFn(const TypeIndex& type_index, + const VariantShapeFn& shape_fn); - // Returns nullptr if no shape function was found for the given TypeName. - VariantShapeFn* GetShapeFn(StringPiece type_name); + // Returns nullptr if no shape function was found for the given TypeIndex. + VariantShapeFn* GetShapeFn(const TypeIndex& type_index); // Add a decode function to the registry. void RegisterDecodeFn(const string& type_name, @@ -104,33 +109,33 @@ class UnaryVariantOpRegistry { // Add a copy-to-GPU function to the registry. void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, - const string& type_name, + const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn); // Returns nullptr if no copy function was found for the given // TypeName and direction. 
AsyncVariantDeviceCopyFn* GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index); // Add a unary op function to the registry. void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn); // Returns nullptr if no unary op function was found for the given // op, device, and TypeName. VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Add a binary op function to the registry. void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantBinaryOpFn& add_fn); // Returns nullptr if no binary op function was found for the given // op, device and TypeName. VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Get a pointer to a global UnaryVariantOpRegistry object static UnaryVariantOpRegistry* Global(); @@ -145,24 +150,26 @@ class UnaryVariantOpRegistry { static std::unordered_set<string>* PersistentStringStorage(); private: - std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns; - std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher> - decode_fns; + struct TypeIndexHash { + std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); } + }; + + gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns; + gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns; // Map std::pair<Direction, type_name> to function. struct PairHash { template <typename Direction> - std::size_t operator()(const std::pair<Direction, StringPiece>& x) const { + std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const { // The hash of an enum is just its value as a std::size_t. 
std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); - ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + ret = Hash64Combine(ret, std::get<1>(x).hash_code()); return ret; } - StringPieceHasher sp_hasher_; }; - std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn, PairHash> + gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn, PairHash> device_copy_fns; // Map std::tuple<Op, device, type_name> to function. @@ -172,10 +179,11 @@ class UnaryVariantOpRegistry { // and references therein template <typename Op> struct FuncTuple { - FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname) - : op_type_(op), device_(dev), typename_(tname){}; + FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index) + : op_type_(op), device_(dev), type_index_(type_index) {} Op op_type_; - StringPiece device_, typename_; + StringPiece device_; + TypeIndex type_index_; }; // friend declaration for operator== // needed for clang @@ -184,11 +192,11 @@ class UnaryVariantOpRegistry { struct TupleHash { template <typename Op> std::size_t operator()( - const std::tuple<Op, StringPiece, StringPiece>& x) const { + const std::tuple<Op, StringPiece, TypeIndex>& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); - ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); + ret = Hash64Combine(ret, std::get<2>(x).hash_code()); return ret; } @@ -197,14 +205,14 @@ class UnaryVariantOpRegistry { // The hash of an enum is just its value as a std::size_t. 
std::size_t ret = static_cast<std::size_t>(x.op_type_); ret = Hash64Combine(ret, sp_hasher_(x.device_)); - ret = Hash64Combine(ret, sp_hasher_(x.typename_)); + ret = Hash64Combine(ret, x.type_index_.hash_code()); return ret; } StringPieceHasher sp_hasher_; }; - std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> + gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> unary_op_fns; - std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> + gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> binary_op_fns; // Find or insert a string into a persistent string storage @@ -225,7 +233,7 @@ template <typename Op> inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs, const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) { return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && - (lhs.typename_ == rhs.typename_); + (lhs.type_index_ == rhs.type_index_); } // Gets a TensorShape from a Tensor containing a scalar Variant. 
// Returns an Internal error if the Variant does not have a registered shape @@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, Variant* v_out) { const string& device = DeviceName<Device>::value; UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = - UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName()); + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); if (unary_op_fn == nullptr) { return errors::Internal( "No unary variant unary_op function found for unary variant op enum: ", @@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, template <typename Device> Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, const Variant& a, const Variant& b, Variant* out) { - if (a.TypeName() != b.TypeName()) { + if (a.TypeId() != b.TypeId()) { return errors::Internal( "BianryOpVariants: Variants a and b have different " - "type names: '", + "type ids. Type names: '", a.TypeName(), "' vs. 
'", b.TypeName(), "'"); } const string& device = DeviceName<Device>::value; UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = - UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName()); + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId()); if (binary_op_fn == nullptr) { return errors::Internal( "No unary variant binary_op function found for binary variant op " @@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration { public: typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn; - UnaryVariantShapeRegistration(const string& type_name, + UnaryVariantShapeRegistration(const TypeIndex& type_index, const LocalVariantShapeFn& shape_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterShapeFn( - type_name, - [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status { + type_index, + [type_index_name, shape_fn](const Variant& v, + TensorShape* s) -> Status { const T* t = v.get<T>(); if (t == nullptr) { return errors::Internal( - "VariantShapeFn: Could not access object, type_name: ", - type_name); + "VariantShapeFn: Could not access object, type_index: ", + type_index_name); } return shape_fn(*t, s); }); @@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration { return false; } Variant decoded = T(); - VariantTensorData data(*t); - if (!decoded.Decode(data)) { + VariantTensorData data(std::move(*t)); + if (!decoded.Decode(std::move(data))) { return false; } - *v = std::move(decoded); + std::swap(decoded, *v); return true; }); } @@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration { UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)> LocalVariantDeviceCopyFn; UnaryVariantDeviceCopyRegistration( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const LocalVariantDeviceCopyFn& device_copy_fn) { + const string type_index_name 
= port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn( - direction, type_name, - [type_name, device_copy_fn]( + direction, type_index, + [type_index_name, device_copy_fn]( const Variant& from, Variant* to, UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn device_copy_tensor_fn) -> Status { @@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration { *to = T(); if (from.get<T>() == nullptr) { return errors::Internal( - "VariantCopyToGPUFn: Could not access object, type_name: ", - type_name); + "VariantCopyToGPUFn: Could not access object, type_index: ", + type_index_name); } const T& t = *from.get<T>(); T* t_out = to->get<T>(); @@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration { public: UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantUnaryOpFn& unary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( - op, device, type_name, - [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, - Variant* v_out) -> Status { + op, device, type_index, + [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, + Variant* v_out) -> Status { DCHECK_NE(v_out, nullptr); *v_out = T(); if (v.get<T>() == nullptr) { return errors::Internal( - "VariantUnaryOpFn: Could not access object, type_name: ", - type_name); + "VariantUnaryOpFn: Could not access object, type_index: ", + type_index_name); } const T& t = *v.get<T>(); T* t_out = v_out->get<T>(); @@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration { public: UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantBinaryOpFn& binary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn( - op, 
device, type_name, - [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, - const Variant& b, Variant* out) -> Status { + op, device, type_index, + [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, + const Variant& b, + Variant* out) -> Status { DCHECK_NE(out, nullptr); *out = T(); if (a.get<T>() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'a', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'a', type_index: ", + type_index_name); } if (b.get<T>() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'b', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'b', type_index: ", + type_index_name); } const T& t_a = *a.get<T>(); const T& t_b = *b.get<T>(); @@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration { // Register a unary shape variant function with the signature: // Status ShapeFn(const T& t, TensorShape* s); -// to Variants having TypeName type_name. -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \ - shape_function) +// to Variants having TypeIndex type_index. 
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, MakeTypeIndex<T>(), shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \ - shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function) +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \ + shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \ +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \ shape_function) \ static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \ - register_unary_variant_op_shape_registration_fn_##ctr(type_name, \ + register_unary_variant_op_shape_registration_fn_##ctr(type_index, \ shape_function) // Register a unary decode variant function for the given type. @@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration { // ****** NOTE ****** // FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. 
// ****** NOTE ****** -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - T, direction, type_name, device_copy_fn) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, T, direction, type_name, device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \ + device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn) #define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - ctr, T, direction, type_name, device_copy_fn) \ + ctr, T, direction, type_index, device_copy_fn) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) + ctr, T, direction, type_index, device_copy_fn) -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantDeviceCopyRegistration<T> \ - register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \ - device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_index, device_copy_fn) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantDeviceCopyRegistration<T> \ + register_unary_variant_op_device_copy_fn_##ctr( \ + direction, type_index, device_copy_fn) // Register a unary unary_op variant function with the signature: // Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for UnaryVariantOp enum op. 
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \ - unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function) -#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \ - unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_index, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \ + type_index, unary_op_function) #define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, unary_op_function) \ + ctr, op, device, T, type_index, unary_op_function) \ static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \ T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ unary_op_function) // Register a binary_op variant function with the signature: // Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for BinaryVariantOp enum OP. 
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ - binary_op_function) \ - REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, binary_op_function) +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function) #define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, binary_op_function) \ + ctr, op, device, T, type_index, binary_op_function) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) + ctr, op, device, T, type_index, binary_op_function) -#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantBinaryOpRegistration<T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration<T> \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ binary_op_function) } // end namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 7055e62c0e..b2443e8676 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -89,41 +89,37 @@ struct VariantValue { int value; }; -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", - VariantValue::ShapeFn); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn); REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); 
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "TEST VariantValue", VariantValue::CPUToGPUCopyFn); + VariantValue::CPUToGPUCopyFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, VariantValue, - "TEST VariantValue", VariantValue::CPUZerosLikeFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, VariantValue, - "TEST VariantValue", VariantValue::GPUZerosLikeFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - VariantValue, "TEST VariantValue", - VariantValue::CPUAddFn); + VariantValue, VariantValue::CPUAddFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - VariantValue, "TEST VariantValue", - VariantValue::GPUAddFn); + VariantValue, VariantValue::GPUAddFn); } // namespace TEST(VariantOpShapeRegistryTest, TestBasic) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), + class Blah {}; + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()), nullptr); - auto* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); + auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn( + MakeTypeIndex<VariantValue>()); EXPECT_NE(shape_fn, nullptr); TensorShape shape; @@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { TEST(VariantOpShapeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantShapeFn f; - string kTypeName = "fjfjfj"; - registry.RegisterShapeFn(kTypeName, f); - EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), - "fjfjfj already registered"); + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); + registry.RegisterShapeFn(kTypeIndex, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpDecodeRegistryTest, TestBasic) { @@ -180,13 +177,14 @@ 
TEST(VariantOpDecodeRegistryTest, TestDuplicate) { TEST(VariantOpCopyToGPURegistryTest, TestBasic) { // No registered copy fn for GPU<->GPU. - EXPECT_EQ( - UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), - nullptr); + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, + MakeTypeIndex<VariantValue>()), + nullptr); auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, + MakeTypeIndex<VariantValue>()); EXPECT_NE(copy_to_gpu_fn, nullptr); VariantValue vv{true /* early_exit */}; @@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) { TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, - kTypeName, f); + kTypeIndex, f); EXPECT_DEATH(registry.RegisterDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), - "fjfjfj already registered"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU 
SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantUnaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_CPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_CPU, kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_GPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_GPU, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpAddRegistryTest, TestBasicCPU) { - return; + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpAddRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) 
{ TEST(VariantOpAddRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantBinaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); } } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc index 99712dc114..3e67e4a864 100644 --- a/tensorflow/core/framework/variant_tensor_data.cc +++ b/tensorflow/core/framework/variant_tensor_data.cc @@ -22,8 +22,8 @@ namespace tensorflow { VariantTensorData::VariantTensorData() {} -VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) { - FromProto(proto); +VariantTensorData::VariantTensorData(VariantTensorDataProto proto) { + FromProto(std::move(proto)); } VariantTensorData::~VariantTensorData() {} @@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { } } -bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { +bool VariantTensorData::FromProto(VariantTensorDataProto proto) { + // TODO(ebrevdo): Do this lazily. 
+ set_type_name(proto.type_name()); + set_metadata(proto.metadata()); + for (const auto& tensor : proto.tensors()) { + Tensor tmp; + if (!tmp.FromProto(tensor)) return false; + tensors_.push_back(tmp); + } + return true; +} + +bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) { set_type_name(proto.type_name()); set_metadata(proto.metadata()); for (const auto& tensor : proto.tensors()) { @@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) { return proto.SerializeToString(buf); } -bool VariantTensorData::ParseFromString(const string& s) { +bool VariantTensorData::ParseFromString(string s) { VariantTensorDataProto proto; const bool status = proto.ParseFromString(s); - if (status) FromProto(proto); + if (status) FromProto(std::move(proto)); return status; } diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h index 7500e77d43..8a240ee1e3 100644 --- a/tensorflow/core/framework/variant_tensor_data.h +++ b/tensorflow/core/framework/variant_tensor_data.h @@ -19,13 +19,13 @@ limitations under the License. #include <algorithm> #include <vector> +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { class VariantTensorDataProto; -class Tensor; // The serialization format for Variant objects. Objects with references to // other Tensors can simply store those tensors in the `tensors` field, and @@ -38,7 +38,7 @@ class Tensor; class VariantTensorData { public: VariantTensorData(); - VariantTensorData(const VariantTensorDataProto& proto); + VariantTensorData(VariantTensorDataProto proto); ~VariantTensorData(); // Name of the type of objects being serialized. 
@@ -68,12 +68,14 @@ class VariantTensorData { // Conversion to and from VariantTensorDataProto void ToProto(VariantTensorDataProto* proto) const; - bool FromProto(const VariantTensorDataProto& proto); + // This allows optimizations via std::move. + bool FromProto(VariantTensorDataProto proto); + bool FromConstProto(const VariantTensorDataProto& proto); // Serialization via VariantTensorDataProto string SerializeAsString() const; bool SerializeToString(string* buf); - bool ParseFromString(const string& s); + bool ParseFromString(string s); string DebugString() const; diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc index eef5c47d15..08d09de7b8 100644 --- a/tensorflow/core/framework/variant_test.cc +++ b/tensorflow/core/framework/variant_test.cc @@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) { struct TensorList { void Encode(VariantTensorData* data) const { data->tensors_ = vec; } - bool Decode(const VariantTensorData& data) { - vec = data.tensors_; + bool Decode(VariantTensorData data) { + vec = std::move(data.tensors_); return true; } @@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) { x.Encode(&serialized); Variant y = TensorList(); - y.Decode(serialized); + y.Decode(std::move(serialized)); const TensorList& decoded_vec = *y.get<TensorList>(); for (int i = 0; i < 4; ++i) { @@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) { EXPECT_EQ(y_unknown.DebugString(), strings::StrCat( "Variant<type: TensorList value: ", data.DebugString(), ">")); - - TensorList unknown_decoded_vec; - EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec)); - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i); - } - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i); - } } TEST(VariantTest, VariantArray) { |