diff options
author | Michael Case <mikecase@google.com> | 2018-02-07 14:36:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-07 14:39:49 -0800 |
commit | d90054e7c0f41f4bab81df0548577a73b939a87a (patch) | |
tree | a15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/core/framework/variant_op_registry.h | |
parent | 8461760f9f6cde8ed97507484d2a879140141032 (diff) |
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/core/framework/variant_op_registry.h')
-rw-r--r-- | tensorflow/core/framework/variant_op_registry.h | 41 |
1 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 13f6908cae..e94100e994 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -166,6 +166,21 @@ class UnaryVariantOpRegistry { device_copy_fns; // Map std::tuple<Op, device, type_name> to function. + + // this breaks by falling victim to "too perfect forwarding" + // see https://stackoverflow.com/questions/44475317/variadic-template-issue + // and references therein + template <typename Op> + struct FuncTuple { + FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname) + : op_type_(op), device_(dev), typename_(tname){}; + Op op_type_; + StringPiece device_, typename_; + }; + // friend declaration for operator== + // needed for clang + template <typename Op> + friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r); struct TupleHash { template <typename Op> std::size_t operator()( @@ -176,18 +191,25 @@ class UnaryVariantOpRegistry { ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); return ret; } + + template <typename Op> + std::size_t operator()(const FuncTuple<Op>& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast<std::size_t>(x.op_type_); + ret = Hash64Combine(ret, sp_hasher_(x.device_)); + ret = Hash64Combine(ret, sp_hasher_(x.typename_)); + return ret; + } StringPieceHasher sp_hasher_; }; - std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>, - VariantUnaryOpFn, TupleHash> + std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> unary_op_fns; - std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>, - VariantBinaryOpFn, TupleHash> + std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> binary_op_fns; // Find or insert a string into a persistent string storage - // container; return the StringPiece pointing to the permanent - // string location. + // container; return the StringPiece pointing to the permanent string + // location. static StringPiece GetPersistentStringPiece(const string& str) { const auto string_storage = PersistentStringStorage(); auto found = string_storage->find(str); @@ -199,7 +221,12 @@ class UnaryVariantOpRegistry { } } }; - +template <typename Op> +inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs, + const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) { + return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && + (lhs.typename_ == rhs.typename_); +} // Gets a TensorShape from a Tensor containing a scalar Variant. // Returns an Internal error if the Variant does not have a registered shape // function, or if it's a serialized Variant that cannot be decoded. |