aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/variant_op_registry.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.h')
-rw-r--r--tensorflow/core/framework/variant_op_registry.h41
1 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index 13f6908cae..e94100e994 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -166,6 +166,21 @@ class UnaryVariantOpRegistry {
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
+
+ // this breaks by falling victim to "too perfect forwarding"
+ // see https://stackoverflow.com/questions/44475317/variadic-template-issue
+ // and references therein
+ template <typename Op>
+ struct FuncTuple {
+ FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
+ : op_type_(op), device_(dev), typename_(tname){};
+ Op op_type_;
+ StringPiece device_, typename_;
+ };
+ // friend declaration for operator==
+ // needed for clang
+ template <typename Op>
+ friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
struct TupleHash {
template <typename Op>
std::size_t operator()(
@@ -176,18 +191,25 @@ class UnaryVariantOpRegistry {
ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
return ret;
}
+
+ template <typename Op>
+ std::size_t operator()(const FuncTuple<Op>& x) const {
+ // The hash of an enum is just its value as a std::size_t.
+ std::size_t ret = static_cast<std::size_t>(x.op_type_);
+ ret = Hash64Combine(ret, sp_hasher_(x.device_));
+ ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ return ret;
+ }
StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
- VariantUnaryOpFn, TupleHash>
+ std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
- VariantBinaryOpFn, TupleHash>
+ std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
- // container; return the StringPiece pointing to the permanent
- // string location.
+ // container; return the StringPiece pointing to the permanent string
+ // location.
static StringPiece GetPersistentStringPiece(const string& str) {
const auto string_storage = PersistentStringStorage();
auto found = string_storage->find(str);
@@ -199,7 +221,12 @@ class UnaryVariantOpRegistry {
}
}
};
-
+template <typename Op>
+inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
+ const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
+ return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
+ (lhs.typename_ == rhs.typename_);
+}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
// function, or if it's a serialized Variant that cannot be decoded.