aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-09-17 21:01:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 21:05:50 -0700
commitb91e27a9c33d038af79a0944eb9046b926d483c8 (patch)
treea80f832bfc33ff9dfd0e208ff4816b145a86a1cb /tensorflow/compiler/tf2xla
parenteeb477cf661a16ee39e0621fd225d1f15859ffc8 (diff)
Refactor out the metadata_ops set from const_analysis to a per-op bit; NFC
PiperOrigin-RevId: 213389224
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/const_analysis.cc12
-rw-r--r--tensorflow/compiler/tf2xla/kernels/shape_op.cc8
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.cc24
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_registry.h12
4 files changed, 43 insertions, 13 deletions
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index 922ae7c79a..027ca6d2d2 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -29,14 +29,6 @@ Status BackwardsConstAnalysis(const Graph& g,
std::vector<bool>* compile_time_const_arg_indices,
std::vector<bool>* compile_time_const_nodes,
std::function<bool(const Edge&)> edge_filter) {
- // Operators that don't look at the data of their inputs, just the shapes.
- const std::unordered_set<string> metadata_ops = {
- "Rank",
- "Shape",
- "ShapeN",
- "Size",
- };
-
std::vector<bool> compile_time_const_nodes_impl;
if (compile_time_const_nodes) {
CHECK_EQ(compile_time_const_nodes->size(), g.num_node_ids());
@@ -50,7 +42,9 @@ Status BackwardsConstAnalysis(const Graph& g,
if (!status.ok()) return;
// If this is a metadata-only op, don't propagate the const requirement.
- if (metadata_ops.find(node->type_string()) != metadata_ops.end()) return;
+ if (XlaOpRegistry::IsMetadataOp(node->type_string())) {
+ return;
+ }
// If this node must be const, and it isn't a metadata op, then all of its
// parents must be const.
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 2e0a69b70e..c8a0f31a03 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -44,7 +44,7 @@ class ShapeOp : public XlaOpKernel {
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("Shape").CompilationOnly(), ShapeOp);
+REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
@@ -66,7 +66,7 @@ class ShapeNOp : public XlaOpKernel {
private:
DataType out_dtype_;
};
-REGISTER_XLA_OP(Name("ShapeN").CompilationOnly(), ShapeNOp);
+REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);
class RankOp : public XlaOpKernel {
public:
@@ -82,7 +82,7 @@ class RankOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Rank").CompilationOnly(), RankOp);
+REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
class SizeOp : public XlaOpKernel {
public:
@@ -101,7 +101,7 @@ class SizeOp : public XlaOpKernel {
}
};
-REGISTER_XLA_OP(Name("Size").CompilationOnly(), SizeOp);
+REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);
class ExpandDimsOp : public XlaOpKernel {
public:
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index b0eeee3174..91d48125f1 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -90,6 +90,11 @@ XlaOpRegistry::~XlaOpRegistry() = default;
<< " have incompatible compile time constant inputs.";
return false;
}
+ if (x.is_metadata_op != y.is_metadata_op) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible values for is_metadata_op.";
+ return false;
+ }
return true;
}
@@ -350,6 +355,20 @@ XlaOpRegistry::CompileTimeConstantInputs(const string& op) {
return &it->second.front()->compile_time_constant_inputs;
}
+/*static*/ bool XlaOpRegistry::IsMetadataOp(const string& op) {
+ XlaOpRegistry& registry = Instance();
+ mutex_lock lock(registry.mutex_);
+ auto it = registry.ops_.find(op);
+ if (it == registry.ops_.end() || it->second.empty()) {
+ return false;
+ }
+
+ // The test in IsCompatible ensures that if there are multiple matching
+ // registrations for this op name, they all have the same value of
+ // is_metadata_op, so only the first match is returned.
+ return it->second.front()->is_metadata_op;
+}
+
std::vector<string> XlaOpRegistry::BackendNames() {
std::vector<string> names;
XlaOpRegistry& registry = Instance();
@@ -432,6 +451,11 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput(
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::IsMetadataOp() {
+ registration_->is_metadata_op = true;
+ return *this;
+}
+
std::unique_ptr<XlaOpRegistry::OpRegistration> XlaOpRegistrationBuilder::Build(
XlaOpRegistry::Factory factory) {
registration_->factory = factory;
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 34e22a4510..a4b624820a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -136,6 +136,10 @@ class XlaOpRegistry {
static const std::unordered_set<string>* CompileTimeConstantInputs(
const string& op);
+ // Returns true if `op` is a "metadata" op, one that only looks at the shapes
+ // of its operands and not their values.
+ static bool IsMetadataOp(const string& op);
+
private:
friend class XlaBackendRegistrar;
friend class XlaOpRegistrar;
@@ -192,6 +196,10 @@ class XlaOpRegistry {
// Names of arguments that must be compile-time constants.
std::unordered_set<string> compile_time_constant_inputs;
+ // True if this is a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ bool is_metadata_op = false;
+
// Factory used to build OpKernels that perform symbolic execution.
Factory factory;
};
@@ -256,6 +264,10 @@ class XlaOpRegistrationBuilder {
// Mark 'input_name' as an argument whose value must be known at compile-time.
XlaOpRegistrationBuilder& CompileTimeConstInput(absl::string_view input_name);
+ // Mark this op as a "metadata" op, one that only looks at the shapes of its
+ // operands and not their values.
+ XlaOpRegistrationBuilder& IsMetadataOp();
+
std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
XlaOpRegistry::Factory factory);