aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant_op_registry.h
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-09-01 13:32:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-01 13:36:03 -0700
commit0acf5bb38a8f208c6d9f048579a076d5bc6ff0be (patch)
tree04b72d4f0d0d50452098399a42611ef3aecf9c63 /tensorflow/core/framework/variant_op_registry.h
parentf7733742d51dba09d4f222b3eb027c27c2c4d130 (diff)
Added registry for variants for op ZerosLike.
* Updated tf.zeros_like python wrapper to avoid calling tf.zeros on Variants. * tf.zeros_like for variants calls the appropriate callback for the given device and type_name. PiperOrigin-RevId: 167317274
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.h')
-rw-r--r--tensorflow/core/framework/variant_op_registry.h92
1 files changed, 92 insertions, 0 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index 389b049fa0..37e54f82c0 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -19,11 +19,13 @@ limitations under the License.
#include <string>
#include <vector>
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
namespace tensorflow {
+class OpKernelContext;
// A global UnaryVariantOpRegistry is used to hold callback functions
// for different variant types. To be used by ShapeOp, RankOp, and
// SizeOp, decoding, etc.
@@ -32,6 +34,8 @@ class UnaryVariantOpRegistry {
public:
typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
typedef std::function<bool(Variant*)> VariantDecodeFn;
+ typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
+ VariantZerosLikeFn;
// Add a shape lookup function to the registry.
void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
@@ -46,11 +50,29 @@ class UnaryVariantOpRegistry {
// Returns nullptr if no decode function was found for the given TypeName.
VariantDecodeFn* GetDecodeFn(const string& type_name);
+ // Add a zeros-like function to the registry.
+ void RegisterZerosLikeFn(const string& device, const string& type_name,
+ const VariantZerosLikeFn& zeros_like_fn);
+
+ // Returns nullptr if no zeros-like function was found for the given
+ // device and TypeName.
+ VariantZerosLikeFn* GetZerosLikeFn(const string& device,
+ const string& type_name);
+
static UnaryVariantOpRegistry* Global();
private:
std::unordered_map<string, VariantShapeFn> shape_fns;
std::unordered_map<string, VariantDecodeFn> decode_fns;
+ // Map std::pair<device, type_name> to function.
+ struct PairHash {
+ template <typename T, typename U>
+ std::size_t operator()(const std::pair<T, U>& x) const {
+ return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
+ }
+ };
+ std::unordered_map<std::pair<string, string>, VariantZerosLikeFn, PairHash>
+ zeros_like_fns;
};
// Gets a TensorShape from a Tensor containing a scalar Variant.
@@ -72,6 +94,28 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
//
bool DecodeUnaryVariant(Variant* variant);
+// Sets *z_out = zeros_like(v). The variant v must have a registered
+// ZerosLike function for the given Device. Returns an Internal error
+// if v does not have a registered zeros_like function for this device, or if
+// ZerosLike fails.
+//
+// REQUIRES:
+// v_out is not null.
+//
+template <typename Device>
+Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v,
+ Variant* v_out) {
+ const string& device = DeviceName<Device>::value;
+ UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn =
+ UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName());
+ if (zeros_like_fn == nullptr) {
+ return errors::Internal(
+ "No unary variant zeros_like function found for Variant type_name: ",
+ v.TypeName(), " for device type: ", device);
+ }
+ return (*zeros_like_fn)(ctx, v, v_out);
+}
+
namespace variant_op_registry_fn_registration {
template <typename T>
@@ -120,6 +164,34 @@ class UnaryVariantDecodeRegistration {
}
};
+template <typename T>
+class UnaryVariantZerosLikeRegistration {
+ typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
+ LocalVariantZerosLikeFn;
+
+ public:
+ UnaryVariantZerosLikeRegistration(
+ const string& device, const string& type_name,
+ const LocalVariantZerosLikeFn& zeros_like_fn) {
+ auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx,
+ const Variant& v,
+ Variant* v_out) -> Status {
+ CHECK_NOTNULL(v_out);
+ *v_out = T();
+ if (v.get<T>() == nullptr) {
+ return errors::Internal(
+ "VariantZerosLikeFn: Could not access object, type_name: ",
+ type_name);
+ }
+ const T& t = *v.get<T>();
+ T* t_out = v_out->get<T>();
+ return zeros_like_fn(ctx, t, t_out);
+ };
+ UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name,
+ wrapped_fn);
+ }
+};
+
}; // namespace variant_op_registry_fn_registration
// Register a unary shape variant function with the signature:
@@ -151,6 +223,26 @@ class UnaryVariantDecodeRegistration {
T> \
register_unary_variant_op_decoder_fn_##ctr(type_name)
+// Register a unary zeros_like variant function with the signature:
+// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out);
+// to Variants having TypeName type_name, for device string device.
+#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \
+ zeros_like_function) \
+ REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, device, T, type_name, zeros_like_function)
+
+#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
+ ctr, device, T, type_name, zeros_like_function) \
+ REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \
+ zeros_like_function)
+
+#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \
+ ctr, device, T, type_name, zeros_like_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantZerosLikeRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(device, type_name, \
+ zeros_like_function)
+
} // end namespace tensorflow
#endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_