diff options
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.h')
-rw-r--r-- | tensorflow/core/framework/variant_op_registry.h | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 389b049fa0..37e54f82c0 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -19,11 +19,13 @@ limitations under the License. #include <string> #include <vector> +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" namespace tensorflow { +class OpKernelContext; // A global UnaryVariantOpRegistry is used to hold callback functions // for different variant types. To be used by ShapeOp, RankOp, and // SizeOp, decoding, etc. @@ -32,6 +34,8 @@ class UnaryVariantOpRegistry { public: typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn; typedef std::function<bool(Variant*)> VariantDecodeFn; + typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)> + VariantZerosLikeFn; // Add a shape lookup function to the registry. void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); @@ -46,11 +50,29 @@ class UnaryVariantOpRegistry { // Returns nullptr if no decode function was found for the given TypeName. VariantDecodeFn* GetDecodeFn(const string& type_name); + // Add a zeros-like function to the registry. + void RegisterZerosLikeFn(const string& device, const string& type_name, + const VariantZerosLikeFn& zeros_like_fn); + + // Returns nullptr if no zeros-like function was found for the given + // device and TypeName. + VariantZerosLikeFn* GetZerosLikeFn(const string& device, + const string& type_name); + static UnaryVariantOpRegistry* Global(); private: std::unordered_map<string, VariantShapeFn> shape_fns; std::unordered_map<string, VariantDecodeFn> decode_fns; + // Map std::pair<device, type_name> to function. + struct PairHash { + template <typename T, typename U> + std::size_t operator()(const std::pair<T, U>& x) const { + return std::hash<T>()(x.first) ^ std::hash<U>()(x.second); + } + }; + std::unordered_map<std::pair<string, string>, VariantZerosLikeFn, PairHash> + zeros_like_fns; }; // Gets a TensorShape from a Tensor containing a scalar Variant. @@ -72,6 +94,28 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape); // bool DecodeUnaryVariant(Variant* variant); +// Sets *z_out = zeros_like(v). The variant v must have a registered +// ZerosLike function for the given Device. Returns an Internal error +// if v does not have a registered zeros_like function for this device, or if +// ZerosLike fails. +// +// REQUIRES: +// v_out is not null. +// +template <typename Device> +Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v, + Variant* v_out) { + const string& device = DeviceName<Device>::value; + UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn = + UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName()); + if (zeros_like_fn == nullptr) { + return errors::Internal( + "No unary variant zeros_like function found for Variant type_name: ", + v.TypeName(), " for device type: ", device); + } + return (*zeros_like_fn)(ctx, v, v_out); +} + namespace variant_op_registry_fn_registration { template <typename T> @@ -120,6 +164,34 @@ class UnaryVariantDecodeRegistration { } }; +template <typename T> +class UnaryVariantZerosLikeRegistration { + typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)> + LocalVariantZerosLikeFn; + + public: + UnaryVariantZerosLikeRegistration( + const string& device, const string& type_name, + const LocalVariantZerosLikeFn& zeros_like_fn) { + auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx, + const Variant& v, + Variant* v_out) -> Status { + CHECK_NOTNULL(v_out); + *v_out = T(); + if (v.get<T>() == nullptr) { + return errors::Internal( + "VariantZerosLikeFn: Could not access object, type_name: ", + type_name); + } + const T& t = *v.get<T>(); + T* t_out = v_out->get<T>(); + return zeros_like_fn(ctx, t, t_out); + }; + UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name, + wrapped_fn); + } +}; + }; // namespace variant_op_registry_fn_registration // Register a unary shape variant function with the signature: @@ -151,6 +223,26 @@ class UnaryVariantDecodeRegistration { T> \ register_unary_variant_op_decoder_fn_##ctr(type_name) +// Register a unary zeros_like variant function with the signature: +// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out); +// to Variants having TypeName type_name, for device string device. +#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \ + zeros_like_function) \ + REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, device, T, type_name, zeros_like_function) + +#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \ + ctr, device, T, type_name, zeros_like_function) \ + REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \ + zeros_like_function) + +#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \ + ctr, device, T, type_name, zeros_like_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantZerosLikeRegistration<T> \ + register_unary_variant_op_decoder_fn_##ctr(device, type_name, \ + zeros_like_function) + } // end namespace tensorflow #endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ |