aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant_op_registry.h
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-09-08 11:14:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-08 11:18:21 -0700
commit477a221a2ffee7261220ad6c0f4f8c76a5eb7931 (patch)
tree280e78e8f928390d886e6b0e5a03a71c588a51a5 /tensorflow/core/framework/variant_op_registry.h
parent96828c9f5276a759717e0d9574b34bcd456d11a5 (diff)
Modify variant registry to have UnaryOp and BinaryOp registrations. Speed up registry lookup.
* Op type is described as an enum (separate enums for unary and binary ops). * Modified ZerosLike registrations to unary registrations with ZEROS_LIKE enum. * Added Add(a,b) registrations as binary registrations with ADD enum. * AddN op uses ADD BinaryOp registrations and ZerosLike op modified to use ZEROS_LIKE UnaryOp registrations. * Modified the registry tables' keys from string type to StringPiece type. The reduced copying should speed up registry lookups by ops. Required creating a backing store for device and type_name strings passed in at registration. PiperOrigin-RevId: 168020449
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.h')
-rw-r--r--tensorflow/core/framework/variant_op_registry.h275
1 files changed, 213 insertions, 62 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index 37e54f82c0..2e9f2243ad 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -17,11 +17,13 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#include <string>
+#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
@@ -30,49 +32,110 @@ class OpKernelContext;
// for different variant types. To be used by ShapeOp, RankOp, and
// SizeOp, decoding, etc.
+enum VariantUnaryOp {
+ INVALID_VARIANT_UNARY_OP = 0,
+ ZEROS_LIKE_VARIANT_UNARY_OP = 1,
+};
+
+enum VariantBinaryOp {
+ INVALID_VARIANT_BINARY_OP = 0,
+ ADD_VARIANT_BINARY_OP = 1,
+};
+
class UnaryVariantOpRegistry {
public:
typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
typedef std::function<bool(Variant*)> VariantDecodeFn;
typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
- VariantZerosLikeFn;
+ VariantUnaryOpFn;
+ typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
+ Variant*)>
+ VariantBinaryOpFn;
// Add a shape lookup function to the registry.
void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
// Returns nullptr if no shape function was found for the given TypeName.
- VariantShapeFn* GetShapeFn(const string& type_name);
+ VariantShapeFn* GetShapeFn(StringPiece type_name);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
const VariantDecodeFn& decode_fn);
// Returns nullptr if no decode function was found for the given TypeName.
- VariantDecodeFn* GetDecodeFn(const string& type_name);
+ VariantDecodeFn* GetDecodeFn(StringPiece type_name);
+
+ // Add a unary op function to the registry.
+ void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
+ const string& type_name,
+ const VariantUnaryOpFn& unary_op_fn);
- // Add a zeros-like function to the registry.
- void RegisterZerosLikeFn(const string& device, const string& type_name,
- const VariantZerosLikeFn& zeros_like_fn);
+ // Returns nullptr if no unary op function was found for the given
+ // op, device, and TypeName.
+ VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
+ StringPiece type_name);
- // Returns nullptr if no zeros-like function was found for the given
- // device and TypeName.
- VariantZerosLikeFn* GetZerosLikeFn(const string& device,
- const string& type_name);
+ // Add a binary op function to the registry.
+ void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
+ const string& type_name,
+ const VariantBinaryOpFn& add_fn);
+ // Returns nullptr if no binary op function was found for the given
+ // op, device and TypeName.
+ VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
+ StringPiece type_name);
+
+ // Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
+ // Get a pointer to a global persistent string storage object.
+ // ISO/IEC C++ working draft N4296 clarifies that insertion into an
+ // std::unordered_set does not invalidate memory locations of
+ // *values* inside the set (though it may invalidate existing
+ // iterators). In other words, one may safely point a StringPiece to
+ // a value in the set without that StringPiece being invalidated by
+ // future insertions.
+ static std::unordered_set<string>* PersistentStringStorage();
+
private:
- std::unordered_map<string, VariantShapeFn> shape_fns;
- std::unordered_map<string, VariantDecodeFn> decode_fns;
- // Map std::pair<device, type_name> to function.
- struct PairHash {
- template <typename T, typename U>
- std::size_t operator()(const std::pair<T, U>& x) const {
- return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
+ std::unordered_map<StringPiece, VariantShapeFn, StringPiece::Hasher>
+ shape_fns;
+ std::unordered_map<StringPiece, VariantDecodeFn, StringPiece::Hasher>
+ decode_fns;
+
+ // Map std::tuple<Op, device, type_name> to function.
+ struct TupleHash {
+ template <typename Op>
+ std::size_t operator()(
+ const std::tuple<Op, StringPiece, StringPiece>& x) const {
+ // The hash of an enum is just its value as a std::size_t.
+ std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
+ StringPiece::Hasher sp_hasher;
+ ret = Hash64Combine(ret, sp_hasher(std::get<1>(x)));
+ ret = Hash64Combine(ret, sp_hasher(std::get<2>(x)));
+ return ret;
}
};
- std::unordered_map<std::pair<string, string>, VariantZerosLikeFn, PairHash>
- zeros_like_fns;
+ std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
+ VariantUnaryOpFn, TupleHash>
+ unary_op_fns;
+ std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
+ VariantBinaryOpFn, TupleHash>
+ binary_op_fns;
+
+ // Find or insert a string into a persistent string storage
+ // container; return the StringPiece pointing to the permanent
+ // string location.
+ static StringPiece GetPersistentStringPiece(const string& str) {
+ const auto string_storage = PersistentStringStorage();
+ auto found = string_storage->find(str);
+ if (found == string_storage->end()) {
+ auto inserted = string_storage->insert(str);
+ return StringPiece(*inserted.first);
+ } else {
+ return StringPiece(*found);
+ }
+ }
};
// Gets a TensorShape from a Tensor containing a scalar Variant.
@@ -94,26 +157,57 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
//
bool DecodeUnaryVariant(Variant* variant);
-// Sets *z_out = zeros_like(v). The variant v must have a registered
-// ZerosLike function for the given Device. Returns an Internal error
-// if v does not have a registered zeros_like function for this device, or if
-// ZerosLike fails.
+// Sets *v_out = unary_op(v). The variant v must have a registered
+// UnaryOp function for the given Device. Returns an Internal error
+// if v does not have a registered unary_op function for this device, or if
+// UnaryOp fails.
//
// REQUIRES:
// v_out is not null.
//
template <typename Device>
-Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v,
- Variant* v_out) {
+Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
+ Variant* v_out) {
const string& device = DeviceName<Device>::value;
- UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn =
- UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName());
- if (zeros_like_fn == nullptr) {
+ UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
+ UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+ if (unary_op_fn == nullptr) {
return errors::Internal(
- "No unary variant zeros_like function found for Variant type_name: ",
- v.TypeName(), " for device type: ", device);
+ "No unary variant unary_op function found for unary variant op enum: ",
+ op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
}
- return (*zeros_like_fn)(ctx, v, v_out);
+ return (*unary_op_fn)(ctx, v, v_out);
+}
+
+// Sets *out = binary_op(a, b). The variants a and b must be the same type
+// and have a registered binary_op function for the given Device. Returns an
+// Internal error if a and b are not the same type_name or if
+// if a does not have a registered op function for this device, or if
+// BinaryOp fails.
+//
+// REQUIRES:
+// out is not null.
+//
+template <typename Device>
+Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
+ const Variant& a, const Variant& b, Variant* out) {
+ if (a.TypeName() != b.TypeName()) {
+ return errors::Internal(
+ "BianryOpVariants: Variants a and b have different "
+ "type names: '",
+ a.TypeName(), "' vs. '", b.TypeName(), "'");
+ }
+ const string& device = DeviceName<Device>::value;
+ UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
+ UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+ if (binary_op_fn == nullptr) {
+ return errors::Internal(
+ "No unary variant binary_op function found for binary variant op "
+ "enum: ",
+ op, " Variant type_name: '", a.TypeName(),
+ "' for device type: ", device);
+ }
+ return (*binary_op_fn)(ctx, a, b, out);
}
namespace variant_op_registry_fn_registration {
@@ -165,30 +259,65 @@ class UnaryVariantDecodeRegistration {
};
template <typename T>
-class UnaryVariantZerosLikeRegistration {
+class UnaryVariantUnaryOpRegistration {
typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
- LocalVariantZerosLikeFn;
+ LocalVariantUnaryOpFn;
public:
- UnaryVariantZerosLikeRegistration(
- const string& device, const string& type_name,
- const LocalVariantZerosLikeFn& zeros_like_fn) {
- auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx,
- const Variant& v,
- Variant* v_out) -> Status {
+ UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
+ const string& type_name,
+ const LocalVariantUnaryOpFn& unary_op_fn) {
+ auto wrapped_fn = [type_name, unary_op_fn](OpKernelContext* ctx,
+ const Variant& v,
+ Variant* v_out) -> Status {
CHECK_NOTNULL(v_out);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
- "VariantZerosLikeFn: Could not access object, type_name: ",
+ "VariantUnaryOpFn: Could not access object, type_name: ",
type_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
- return zeros_like_fn(ctx, t, t_out);
+ return unary_op_fn(ctx, t, t_out);
+ };
+ UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(op, device, type_name,
+ wrapped_fn);
+ }
+};
+
+template <typename T>
+class UnaryVariantBinaryOpRegistration {
+ typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
+ T* out)>
+ LocalVariantBinaryOpFn;
+
+ public:
+ UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
+ const string& type_name,
+ const LocalVariantBinaryOpFn& binary_op_fn) {
+ auto wrapped_fn = [type_name, binary_op_fn](
+ OpKernelContext* ctx, const Variant& a,
+ const Variant& b, Variant* out) -> Status {
+ CHECK_NOTNULL(out);
+ *out = T();
+ if (a.get<T>() == nullptr) {
+ return errors::Internal(
+ "VariantBinaryOpFn: Could not access object 'a', type_name: ",
+ type_name);
+ }
+ if (b.get<T>() == nullptr) {
+ return errors::Internal(
+ "VariantBinaryOpFn: Could not access object 'b', type_name: ",
+ type_name);
+ }
+ const T& t_a = *a.get<T>();
+ const T& t_b = *b.get<T>();
+ T* t_out = out->get<T>();
+ return binary_op_fn(ctx, t_a, t_b, t_out);
};
- UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name,
- wrapped_fn);
+ UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(op, device, type_name,
+ wrapped_fn);
}
};
@@ -223,25 +352,47 @@ class UnaryVariantZerosLikeRegistration {
T> \
register_unary_variant_op_decoder_fn_##ctr(type_name)
-// Register a unary zeros_like variant function with the signature:
-// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device.
-#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \
- zeros_like_function) \
- REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, device, T, type_name, zeros_like_function)
-
-#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
- ctr, device, T, type_name, zeros_like_function) \
- REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \
- zeros_like_function)
-
-#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \
- ctr, device, T, type_name, zeros_like_function) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantZerosLikeRegistration<T> \
- register_unary_variant_op_decoder_fn_##ctr(device, type_name, \
- zeros_like_function)
+// Register a unary unary_op variant function with the signature:
+// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
+// to Variants having TypeName type_name, for device string device,
+// for UnaryVariantOp enum op.
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
+ unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, type_name, unary_op_function)
+
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_name, unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
+ unary_op_function)
+
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_name, unary_op_function) \
+ static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
+ T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ unary_op_function)
+
+// Register a binary_op variant function with the signature:
+// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
+// to Variants having TypeName type_name, for device string device,
+// for BinaryVariantOp enum OP.
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
+ binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, type_name, binary_op_function)
+
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_name, binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_name, binary_op_function)
+
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_name, binary_op_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantBinaryOpRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ binary_op_function)
} // end namespace tensorflow