diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-09-08 11:14:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-08 11:18:21 -0700 |
commit | 477a221a2ffee7261220ad6c0f4f8c76a5eb7931 (patch) | |
tree | 280e78e8f928390d886e6b0e5a03a71c588a51a5 /tensorflow/core/framework/variant_op_registry.h | |
parent | 96828c9f5276a759717e0d9574b34bcd456d11a5 (diff) |
Modify variant registry to have UnaryOp and BinaryOp registrations. Speed up registry lookup.
* Op type is described as an enum (separate enums for unary and binary ops).
* Modified ZerosLike registrations to unary registrations with ZEROS_LIKE enum.
* Added Add(a,b) registrations as binary registrations with ADD enum.
* AddN op uses ADD BinaryOp registrations and ZerosLike op modified to use
ZEROS_LIKE UnaryOp registrations.
* Modified the registry tables' keys from string type to StringPiece type.
The reduced copying should speed up registry lookups by ops. Required creating
a backing store for device and type_name strings passed in at registration.
PiperOrigin-RevId: 168020449
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.h')
-rw-r--r-- | tensorflow/core/framework/variant_op_registry.h | 275 |
1 files changed, 213 insertions, 62 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 37e54f82c0..2e9f2243ad 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -17,11 +17,13 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ #include <string> +#include <unordered_set> #include <vector> #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { @@ -30,49 +32,110 @@ class OpKernelContext; // for different variant types. To be used by ShapeOp, RankOp, and // SizeOp, decoding, etc. +enum VariantUnaryOp { + INVALID_VARIANT_UNARY_OP = 0, + ZEROS_LIKE_VARIANT_UNARY_OP = 1, +}; + +enum VariantBinaryOp { + INVALID_VARIANT_BINARY_OP = 0, + ADD_VARIANT_BINARY_OP = 1, +}; + class UnaryVariantOpRegistry { public: typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn; typedef std::function<bool(Variant*)> VariantDecodeFn; typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)> - VariantZerosLikeFn; + VariantUnaryOpFn; + typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&, + Variant*)> + VariantBinaryOpFn; // Add a shape lookup function to the registry. void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); // Returns nullptr if no shape function was found for the given TypeName. - VariantShapeFn* GetShapeFn(const string& type_name); + VariantShapeFn* GetShapeFn(StringPiece type_name); // Add a decode function to the registry. void RegisterDecodeFn(const string& type_name, const VariantDecodeFn& decode_fn); // Returns nullptr if no decode function was found for the given TypeName. - VariantDecodeFn* GetDecodeFn(const string& type_name); + VariantDecodeFn* GetDecodeFn(StringPiece type_name); + + // Add a unary op function to the registry. + void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, + const string& type_name, + const VariantUnaryOpFn& unary_op_fn); - // Add a zeros-like function to the registry. - void RegisterZerosLikeFn(const string& device, const string& type_name, - const VariantZerosLikeFn& zeros_like_fn); + // Returns nullptr if no unary op function was found for the given + // op, device, and TypeName. + VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, + StringPiece type_name); - // Returns nullptr if no zeros-like function was found for the given - // device and TypeName. - VariantZerosLikeFn* GetZerosLikeFn(const string& device, - const string& type_name); + // Add a binary op function to the registry. + void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, + const string& type_name, + const VariantBinaryOpFn& add_fn); + // Returns nullptr if no binary op function was found for the given + // op, device and TypeName. + VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, + StringPiece type_name); + + // Get a pointer to a global UnaryVariantOpRegistry object static UnaryVariantOpRegistry* Global(); + // Get a pointer to a global persistent string storage object. + // ISO/IEC C++ working draft N4296 clarifies that insertion into an + // std::unordered_set does not invalidate memory locations of + // *values* inside the set (though it may invalidate existing + // iterators). In other words, one may safely point a StringPiece to + // a value in the set without that StringPiece being invalidated by + // future insertions. + static std::unordered_set<string>* PersistentStringStorage(); + private: - std::unordered_map<string, VariantShapeFn> shape_fns; - std::unordered_map<string, VariantDecodeFn> decode_fns; - // Map std::pair<device, type_name> to function. - struct PairHash { - template <typename T, typename U> - std::size_t operator()(const std::pair<T, U>& x) const { - return std::hash<T>()(x.first) ^ std::hash<U>()(x.second); + std::unordered_map<StringPiece, VariantShapeFn, StringPiece::Hasher> + shape_fns; + std::unordered_map<StringPiece, VariantDecodeFn, StringPiece::Hasher> + decode_fns; + + // Map std::tuple<Op, device, type_name> to function. + struct TupleHash { + template <typename Op> + std::size_t operator()( + const std::tuple<Op, StringPiece, StringPiece>& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); + StringPiece::Hasher sp_hasher; + ret = Hash64Combine(ret, sp_hasher(std::get<1>(x))); + ret = Hash64Combine(ret, sp_hasher(std::get<2>(x))); + return ret; } }; - std::unordered_map<std::pair<string, string>, VariantZerosLikeFn, PairHash> - zeros_like_fns; + std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>, + VariantUnaryOpFn, TupleHash> + unary_op_fns; + std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>, + VariantBinaryOpFn, TupleHash> + binary_op_fns; + + // Find or insert a string into a persistent string storage + // container; return the StringPiece pointing to the permanent + // string location. + static StringPiece GetPersistentStringPiece(const string& str) { + const auto string_storage = PersistentStringStorage(); + auto found = string_storage->find(str); + if (found == string_storage->end()) { + auto inserted = string_storage->insert(str); + return StringPiece(*inserted.first); + } else { + return StringPiece(*found); + } + } }; // Gets a TensorShape from a Tensor containing a scalar Variant. @@ -94,26 +157,57 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape); // bool DecodeUnaryVariant(Variant* variant); -// Sets *z_out = zeros_like(v). The variant v must have a registered -// ZerosLike function for the given Device. Returns an Internal error -// if v does not have a registered zeros_like function for this device, or if -// ZerosLike fails. +// Sets *v_out = unary_op(v). The variant v must have a registered +// UnaryOp function for the given Device. Returns an Internal error +// if v does not have a registered unary_op function for this device, or if +// UnaryOp fails. // // REQUIRES: // v_out is not null. // template <typename Device> -Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v, - Variant* v_out) { +Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, + Variant* v_out) { const string& device = DeviceName<Device>::value; - UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn = - UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName()); - if (zeros_like_fn == nullptr) { + UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName()); + if (unary_op_fn == nullptr) { return errors::Internal( - "No unary variant zeros_like function found for Variant type_name: ", - v.TypeName(), " for device type: ", device); + "No unary variant unary_op function found for unary variant op enum: ", + op, " Variant type_name: ", v.TypeName(), " for device type: ", device); } - return (*zeros_like_fn)(ctx, v, v_out); + return (*unary_op_fn)(ctx, v, v_out); +} + +// Sets *out = binary_op(a, b). The variants a and b must be the same type +// and have a registered binary_op function for the given Device. Returns an +// Internal error if a and b are not the same type_name or if +// if a does not have a registered op function for this device, or if +// BinaryOp fails. +// +// REQUIRES: +// out is not null. +// +template <typename Device> +Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, + const Variant& a, const Variant& b, Variant* out) { + if (a.TypeName() != b.TypeName()) { + return errors::Internal( + "BianryOpVariants: Variants a and b have different " + "type names: '", + a.TypeName(), "' vs. '", b.TypeName(), "'"); + } + const string& device = DeviceName<Device>::value; + UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName()); + if (binary_op_fn == nullptr) { + return errors::Internal( + "No unary variant binary_op function found for binary variant op " + "enum: ", + op, " Variant type_name: '", a.TypeName(), + "' for device type: ", device); + } + return (*binary_op_fn)(ctx, a, b, out); } namespace variant_op_registry_fn_registration { @@ -165,30 +259,65 @@ class UnaryVariantDecodeRegistration { }; template <typename T> -class UnaryVariantZerosLikeRegistration { +class UnaryVariantUnaryOpRegistration { typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)> - LocalVariantZerosLikeFn; + LocalVariantUnaryOpFn; public: - UnaryVariantZerosLikeRegistration( - const string& device, const string& type_name, - const LocalVariantZerosLikeFn& zeros_like_fn) { - auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx, - const Variant& v, - Variant* v_out) -> Status { + UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device, + const string& type_name, + const LocalVariantUnaryOpFn& unary_op_fn) { + auto wrapped_fn = [type_name, unary_op_fn](OpKernelContext* ctx, + const Variant& v, + Variant* v_out) -> Status { CHECK_NOTNULL(v_out); *v_out = T(); if (v.get<T>() == nullptr) { return errors::Internal( - "VariantZerosLikeFn: Could not access object, type_name: ", + "VariantUnaryOpFn: Could not access object, type_name: ", type_name); } const T& t = *v.get<T>(); T* t_out = v_out->get<T>(); - return zeros_like_fn(ctx, t, t_out); + return unary_op_fn(ctx, t, t_out); + }; + UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(op, device, type_name, + wrapped_fn); + } +}; + +template <typename T> +class UnaryVariantBinaryOpRegistration { + typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b, + T* out)> + LocalVariantBinaryOpFn; + + public: + UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device, + const string& type_name, + const LocalVariantBinaryOpFn& binary_op_fn) { + auto wrapped_fn = [type_name, binary_op_fn]( + OpKernelContext* ctx, const Variant& a, + const Variant& b, Variant* out) -> Status { + CHECK_NOTNULL(out); + *out = T(); + if (a.get<T>() == nullptr) { + return errors::Internal( + "VariantBinaryOpFn: Could not access object 'a', type_name: ", + type_name); + } + if (b.get<T>() == nullptr) { + return errors::Internal( + "VariantBinaryOpFn: Could not access object 'b', type_name: ", + type_name); + } + const T& t_a = *a.get<T>(); + const T& t_b = *b.get<T>(); + T* t_out = out->get<T>(); + return binary_op_fn(ctx, t_a, t_b, t_out); }; - UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name, - wrapped_fn); + UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(op, device, type_name, + wrapped_fn); } }; @@ -223,25 +352,47 @@ class UnaryVariantZerosLikeRegistration { T> \ register_unary_variant_op_decoder_fn_##ctr(type_name) -// Register a unary zeros_like variant function with the signature: -// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out); -// to Variants having TypeName type_name, for device string device. -#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \ - zeros_like_function) \ - REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, device, T, type_name, zeros_like_function) - -#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \ - ctr, device, T, type_name, zeros_like_function) \ - REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \ - zeros_like_function) - -#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \ - ctr, device, T, type_name, zeros_like_function) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantZerosLikeRegistration<T> \ - register_unary_variant_op_decoder_fn_##ctr(device, type_name, \ - zeros_like_function) +// Register a unary unary_op variant function with the signature: +// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); +// to Variants having TypeName type_name, for device string device, +// for UnaryVariantOp enum op. +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, type_name, unary_op_function) + +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_name, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \ + unary_op_function) + +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, unary_op_function) \ + static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \ + T> \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + unary_op_function) + +// Register a binary_op variant function with the signature: +// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); +// to Variants having TypeName type_name, for device string device, +// for BinaryVariantOp enum OP. +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, type_name, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_name, binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration<T> \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + binary_op_function) } // end namespace tensorflow |