diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-09-17 21:01:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 21:05:50 -0700 |
commit | b91e27a9c33d038af79a0944eb9046b926d483c8 (patch) | |
tree | a80f832bfc33ff9dfd0e208ff4816b145a86a1cb /tensorflow/compiler/tf2xla | |
parent | eeb477cf661a16ee39e0621fd225d1f15859ffc8 (diff) |
Refactor out the metadata_ops set from const_analysis to a per-op bit; NFC
PiperOrigin-RevId: 213389224
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/const_analysis.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/shape_op.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_registry.cc | 24 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_registry.h | 12 |
4 files changed, 43 insertions, 13 deletions
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 922ae7c79a..027ca6d2d2 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g, std::vector<bool>* compile_time_const_arg_indices, std::vector<bool>* compile_time_const_nodes, std::function<bool(const Edge&)> edge_filter) { - // Operators that don't look at the data of their inputs, just the shapes. - const std::unordered_set<string> metadata_ops = { - "Rank", - "Shape", - "ShapeN", - "Size", - }; - std::vector<bool> compile_time_const_nodes_impl; if (compile_time_const_nodes) { CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids()); @@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g, if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. - if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return; + if (XlaOpRegistry::IsMetadataOp(node->type_string())) { + return; + } // If this node must be const, and it isn't a metadata op, then all of its // parents must be const. diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 2e0a69b70e..c8a0f31a03 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel { DataType out_dtype_; }; -REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp); +REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); class ShapeNOp : public XlaOpKernel { public: @@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel { private: DataType out_dtype_; }; -REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp); +REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp); class RankOp : public XlaOpKernel { public: @@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp); +REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp); class SizeOp : public XlaOpKernel { public: @@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp); +REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp); class ExpandDimsOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index b0eeee3174..91d48125f1 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible compile time constant inputs."; return false; } + if (x.is_metadata_op != y.is_metadata_op) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible values for is_metadata_op."; + return false; + } return true; } @@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) { return &it->second.front()->compile_time_constant_inputs; } +/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end() || it->second.empty()) { + return false; + } + + // The test in IsCompatible ensures that if there are multiple matching + // registrations for this op name, they all have the same value of + // is_metadata_op, so only the first match is returned. + return it->second.front()->is_metadata_op; +} + std::vector<string> XlaOpRegistry::BackendNames() { std::vector<string> names; XlaOpRegistry& registry = Instance(); @@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() { + registration_->is_metadata_op = true; + return *this; +} + std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 34e22a4510..a4b624820a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -136,6 +136,10 @@ class XlaOpRegistry { static const std::unordered_set<string>* CompileTimeConstantInputs( const string& op); + // Returns true if `op` is a "metadata" op, one that only looks at the shapes + // of its operands and not their values. + static bool IsMetadataOp(const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -192,6 +196,10 @@ class XlaOpRegistry { // Names of arguments that must be compile-time constants. std::unordered_set<string> compile_time_constant_inputs; + // True if this is a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + bool is_metadata_op = false; + // Factory used to build OpKernels that perform symbolic execution. Factory factory; }; @@ -256,6 +264,10 @@ class XlaOpRegistrationBuilder { // Mark 'input_name' as an argument whose value must be known at compile-time. XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name); + // Mark this op as a "metadata" op, one that only looks at the shapes of its + // operands and not their values. + XlaOpRegistrationBuilder& IsMetadataOp(); + std::unique_ptr<XlaOpRegistry::OpRegistration> Build( XlaOpRegistry::Factory factory); |