about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
author: Sanjoy Das <sanjoy@google.com> 2018-08-23 20:11:15 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-08-23 20:18:21 -0700
commit 2bcf90f28b6a71a13468f9deb1fcbb2867fb89ee (patch)
tree a172b47f0af221ee3e0ec631191f63b5f4e1ff18 /tensorflow/compiler/jit
parent 08924ec4e285a585ebbb3e8b9f60f3f8c310c103 (diff)
[tf2xla] Re-organize information about resource ops in one place; NFC
This is a cleanup on cr/208763036. Instead of spreading information about resource ops between jit/mark_for_compilation_pass and jit/resource_operation_safety_analysis we now have tf2xla/resource_operation_table own it. PiperOrigin-RevId: 210044178
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--  tensorflow/compiler/jit/BUILD                                        3
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc                26
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis.cc      162
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis.h         6
-rw-r--r--  tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc  44
5 files changed, 48 insertions, 193 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index aa01b1fd52..df81f3c23e 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -316,12 +316,14 @@ cc_library(
hdrs = ["resource_operation_safety_analysis.h"],
deps = [
"//tensorflow/compiler/jit/graphcycles",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
],
)
@@ -381,6 +383,7 @@ cc_library(
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
+ "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 6cf04c7bdf..518c39ec15 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -86,35 +87,14 @@ bool HasResourceInput(const Node& node) {
DT_RESOURCE) != node.input_types().end();
}
-gtl::FlatSet<StringPiece>* GetNonResourceVarResourceOpSet() {
- gtl::FlatSet<StringPiece>* result = new gtl::FlatSet<StringPiece>;
-
- result->insert("StackCloseV2");
- result->insert("StackPopV2");
- result->insert("StackPushV2");
- result->insert("TensorArrayConcatV3");
- result->insert("TensorArrayGatherV3");
- result->insert("TensorArrayScatterV3");
- result->insert("TensorArrayGradV3");
- result->insert("TensorArrayCloseV3");
- result->insert("TensorArrayReadV3");
- result->insert("TensorArraySizeV3");
- result->insert("TensorArraySplitV3");
- result->insert("TensorArrayWriteV3");
-
- return result;
-}
-
// Returns true if `node` is a resource operation recognized by tf2xla that
// operates on something other than resource variables.
bool IsNonResourceVarResourceOp(const Node& node) {
// TODO(b/112837194): We can't cluster these because we only support
// snapshotting resource variables (and we can't e.g. snapshot stacks). This
// limitation may be fixable with some work.
- static gtl::FlatSet<StringPiece>* non_resource_var_resource_op =
- GetNonResourceVarResourceOpSet();
-
- return non_resource_var_resource_op->count(node.type_string());
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string());
+ return op_info && op_info->resource_kind() != XlaResourceKind::kVariable;
}
// Make sure we don't recurse infinitely on recursive functions.
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 2479e57cde..1ba4a5ef73 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -84,6 +84,8 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
@@ -94,82 +96,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Every TensorFlow operation is mapped to one of four kinds:
-enum class ResourceOpKind {
- kNone, // Has no interaction with resources.
- kRead, // Only reads from resources.
- kWrite, // Only writes to resources.
- kModify // Reads from and writes to resources.
-};
-
-StringPiece ResourceOpKindToString(ResourceOpKind op_kind) {
- switch (op_kind) {
- case ResourceOpKind::kRead:
- return "Read";
- case ResourceOpKind::kWrite:
- return "Write";
- case ResourceOpKind::kModify:
- return "Modify";
- case ResourceOpKind::kNone:
- return "None";
- }
-}
-
-// Returns a map that maps TensorFlow operation names to the corresponding
-// ResourceOpKind. We only care about XLA operations that we can cluster.
-gtl::FlatMap<StringPiece, ResourceOpKind>* GetResourceOpKindMap() {
- gtl::FlatMap<StringPiece, ResourceOpKind>* result =
- new gtl::FlatMap<StringPiece, ResourceOpKind>;
-
- result->insert({"AssignAddVariableOp", ResourceOpKind::kModify});
- result->insert({"AssignSubVariableOp", ResourceOpKind::kModify});
- result->insert({"AssignVariableOp", ResourceOpKind::kWrite});
- result->insert({"ReadVariableOp", ResourceOpKind::kRead});
- result->insert({"ResourceApplyAdaMax", ResourceOpKind::kModify});
- result->insert({"ResourceApplyAdadelta", ResourceOpKind::kModify});
- result->insert({"ResourceApplyAdagrad", ResourceOpKind::kModify});
- result->insert({"ResourceApplyAdagradDA", ResourceOpKind::kModify});
- result->insert({"ResourceApplyAdam", ResourceOpKind::kModify});
- result->insert({"ResourceApplyAddSign", ResourceOpKind::kModify});
- result->insert({"ResourceApplyCenteredRMSProp", ResourceOpKind::kModify});
- result->insert({"ResourceApplyFtrl", ResourceOpKind::kModify});
- result->insert({"ResourceApplyFtrlV2", ResourceOpKind::kModify});
- result->insert({"ResourceApplyGradientDescent", ResourceOpKind::kModify});
- result->insert({"ResourceApplyMomentum", ResourceOpKind::kModify});
- result->insert({"ResourceApplyPowerSign", ResourceOpKind::kModify});
- result->insert({"ResourceApplyProximalAdagrad", ResourceOpKind::kModify});
- result->insert(
- {"ResourceApplyProximalGradientDescent", ResourceOpKind::kModify});
- result->insert({"ResourceApplyRMSProp", ResourceOpKind::kModify});
- result->insert({"ResourceGather", ResourceOpKind::kRead});
- result->insert({"ResourceScatterAdd", ResourceOpKind::kModify});
- result->insert({"ResourceScatterDiv", ResourceOpKind::kModify});
- result->insert({"ResourceScatterMax", ResourceOpKind::kModify});
- result->insert({"ResourceScatterMin", ResourceOpKind::kModify});
- result->insert({"ResourceScatterMul", ResourceOpKind::kModify});
- result->insert({"ResourceScatterNdAdd", ResourceOpKind::kModify});
- result->insert({"ResourceScatterNdUpdate", ResourceOpKind::kModify});
- result->insert({"ResourceScatterSub", ResourceOpKind::kModify});
- result->insert({"ResourceScatterUpdate", ResourceOpKind::kModify});
- result->insert({"ResourceStridedSliceAssign", ResourceOpKind::kModify});
- result->insert({"StackCloseV2", ResourceOpKind::kRead}); // Reads shape
- result->insert({"StackPopV2", ResourceOpKind::kModify});
- result->insert({"StackPushV2", ResourceOpKind::kModify});
- result->insert({"TensorArrayConcatV3", ResourceOpKind::kRead});
- result->insert({"TensorArrayGatherV3", ResourceOpKind::kRead});
- result->insert({"TensorArrayScatterV3", ResourceOpKind::kWrite});
- result->insert({"TensorArrayGradV3", ResourceOpKind::kRead}); // Reads shape
- result->insert({"TensorArrayCloseV3", ResourceOpKind::kRead}); // Reads shape
- result->insert({"TensorArrayReadV3", ResourceOpKind::kRead});
- result->insert({"TensorArraySizeV3", ResourceOpKind::kRead});
- result->insert({"TensorArraySplitV3", ResourceOpKind::kWrite});
- result->insert({"TensorArrayWriteV3", ResourceOpKind::kWrite});
- result->insert({"VarIsInitializedOp", ResourceOpKind::kRead});
- result->insert({"VariableShape", ResourceOpKind::kRead});
-
- return result;
-}
-
// Returns true if `n` may call a function.
Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
bool* out_result) {
@@ -186,26 +112,25 @@ Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def,
return Status::OK();
}
-// Maps `n` to the ResourceOpKind corresponding to its operation.
-Status ResourceOpKindForNode(
+// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is
+// not a resource operation recognized by XLA then sets `out_resource_op_kind`
+// to nullopt.
+Status XlaResourceOpKindForNode(
const Node& n, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
- ResourceOpKind* out_resource_op_kind) {
- static const gtl::FlatMap<StringPiece, ResourceOpKind>& resource_op_kind_map =
- *GetResourceOpKindMap();
-
+ absl::optional<XlaResourceOpKind>* out_resource_op_kind) {
bool should_ignore = false;
if (resource_ops_to_ignore) {
TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore));
}
if (should_ignore) {
- *out_resource_op_kind = ResourceOpKind::kNone;
+ *out_resource_op_kind = absl::nullopt;
return Status::OK();
}
- auto it = resource_op_kind_map.find(n.type_string());
- if (it != resource_op_kind_map.end()) {
- *out_resource_op_kind = it->second;
+ const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string());
+ if (op_info) {
+ *out_resource_op_kind = op_info->kind();
return Status::OK();
}
@@ -214,8 +139,11 @@ Status ResourceOpKindForNode(
// inter-procedural analysis.
bool may_call_function;
TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function));
- *out_resource_op_kind =
- may_call_function ? ResourceOpKind::kModify : ResourceOpKind::kNone;
+ if (may_call_function) {
+ *out_resource_op_kind = XlaResourceOpKind::kReadWrite;
+ } else {
+ *out_resource_op_kind = absl::nullopt;
+ }
return Status::OK();
}
@@ -224,22 +152,22 @@ Status ResourceOpKindForNode(
// resource op kind `from` to a TensorFlow operation of resource op kind `to`
// can be represented by an XLA cluster and needs no special handling around
// auto-jit.
-bool IsEdgeSafe(ResourceOpKind from, ResourceOpKind to) {
+bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) {
// XLA clusters forces all reads to happen before all writes, which means the
// kinds of edges it can faithfully represent are: Read->Write, Read->Modify,
// Modify->Write, Read->Read, Write->Write.
//
// TODO(b/112856632): We can, in theory, support Read->Read and Write->Write
// dependencies.
- return from == ResourceOpKind::kNone || to == ResourceOpKind::kNone ||
- (from == ResourceOpKind::kRead && to == ResourceOpKind::kWrite);
+ return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite;
}
-using ResourceOp = std::pair<int, ResourceOpKind>;
+using ResourceOp = std::pair<int, XlaResourceOpKind>;
string ResourceOpToString(const ResourceOp& resource_op) {
- return strings::StrCat(resource_op.first, ": ",
- ResourceOpKindToString(resource_op.second));
+ return strings::StrCat(
+ resource_op.first, ": ",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second));
}
// A copy-on-write set used to store the set of ResourceOps reaching a node in a
@@ -332,9 +260,10 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) {
return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}");
}
-string NodeToString(const Node& n, ResourceOpKind resource_op_kind) {
- return strings::StrCat("[", n.name(), ": ", n.type_string(), "(",
- ResourceOpKindToString(resource_op_kind), ")", "]");
+string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) {
+ return strings::StrCat(
+ "[", n.name(), ": ", n.type_string(), "(",
+ XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]");
}
} // namespace
@@ -356,14 +285,17 @@ Status ComputeIncompatibleResourceOperationPairs(
const bool vlog = VLOG_IS_ON(2);
for (Node* n : rpo) {
- ResourceOpKind op_kind;
- TF_RETURN_IF_ERROR(
- ResourceOpKindForNode(*n, flib_def, resource_ops_to_ignore, &op_kind));
+ absl::optional<XlaResourceOpKind> op_kind;
+ TF_RETURN_IF_ERROR(XlaResourceOpKindForNode(
+ *n, flib_def, resource_ops_to_ignore, &op_kind));
ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()];
+ // Merge the reaching resource operations for all the incoming edges to
+ // create the set of all possible resource ops reaching `n`.
for (const Edge* e : n->in_edges()) {
if (n->IsMerge() && e->src()->IsNextIteration()) {
+ // Ignore back-edges (see file comment).
continue;
}
@@ -372,21 +304,23 @@ Status ComputeIncompatibleResourceOperationPairs(
resource_op_set->Add(incoming_op_set);
}
- for (ResourceOp incoming_op : *resource_op_set) {
- if (!IsEdgeSafe(incoming_op.second, op_kind)) {
+ // Add to the "incompatible resource ops" set if necessary.
+ if (op_kind) {
+ for (ResourceOp incoming_op : *resource_op_set) {
+ if (IsEdgeSafe(incoming_op.second, *op_kind)) {
+ continue;
+ }
+
if (vlog) {
VLOG(2) << "Unsafe edge: "
<< NodeToString(*g.FindNodeId(incoming_op.first),
incoming_op.second)
- << " -> " << NodeToString(*n, op_kind);
+ << " -> " << NodeToString(*n, *op_kind);
}
result->push_back({incoming_op.first, n->id()});
}
- }
- if (op_kind != ResourceOpKind::kNone) {
- // This check is an optimization, not necessary for correctness.
- resource_op_set->Add({n->id(), op_kind});
+ resource_op_set->Add({n->id(), *op_kind});
}
if (vlog) {
@@ -399,18 +333,4 @@ Status ComputeIncompatibleResourceOperationPairs(
return Status::OK();
}
-
-namespace resource_op_safety_analysis_internal {
-std::vector<string> GetKnownResourceOperations() {
- std::unique_ptr<gtl::FlatMap<StringPiece, ResourceOpKind>>
- resource_op_kind_map(GetResourceOpKindMap());
-
- std::vector<string> result;
- for (const auto& name_kind_map : *resource_op_kind_map) {
- result.push_back(string(name_kind_map.first));
- }
- std::sort(result.begin(), result.end());
- return result;
-}
-} // namespace resource_op_safety_analysis_internal
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
index 147039348b..ae8cfeecad 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.h
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h
@@ -68,12 +68,6 @@ Status ComputeIncompatibleResourceOperationPairs(
const Graph& g, const FunctionLibraryDefinition* flib_def,
const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore,
std::vector<std::pair<int, int>>* result);
-
-namespace resource_op_safety_analysis_internal {
-// Internal API for whitebox testing.
-std::vector<string> GetKnownResourceOperations();
-} // namespace resource_op_safety_analysis_internal
-
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
index c774bdf5ff..e54b547abc 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
-#include "absl/strings/str_join.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
@@ -34,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -536,47 +536,5 @@ TEST(ResourceOperationSafetyAnalysisTest, Loop) {
bool IsResourceArgDef(const OpDef::ArgDef& arg_def) {
return arg_def.type() == DT_RESOURCE;
}
-
-bool HasResourceInputOrOutput(const OpDef& op_def) {
- return std::any_of(op_def.input_arg().begin(), op_def.input_arg().end(),
- IsResourceArgDef) ||
- std::any_of(op_def.output_arg().begin(), op_def.output_arg().end(),
- IsResourceArgDef);
-}
-
-TEST(ResourceOperationSafetyAnalysisTest, HaveAllResourceOps) {
- gtl::FlatMap<string, bool> known_resource_ops;
- for (const string& known_resource_op :
- resource_op_safety_analysis_internal::GetKnownResourceOperations()) {
- ASSERT_TRUE(known_resource_ops.insert({known_resource_op, false}).second);
- }
-
- std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps();
- for (const string& xla_op_name : xla_op_names) {
- if (xla_op_name == "StackV2" || xla_op_name == "TensorArrayV3") {
- continue;
- }
-
- const OpDef* op_def;
- TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def));
- if (HasResourceInputOrOutput(*op_def)) {
- EXPECT_EQ(known_resource_ops.count(xla_op_name), 1)
- << "Unknown resource op " << xla_op_name;
- known_resource_ops[xla_op_name] = true;
- }
- }
-
- std::vector<string> unnecessary_resource_ops;
- for (const auto& pair : known_resource_ops) {
- if (!pair.second) {
- unnecessary_resource_ops.push_back(pair.first);
- }
- }
-
- EXPECT_TRUE(unnecessary_resource_ops.empty())
- << "Stale resource ops:\n"
- << absl::StrJoin(unnecessary_resource_ops, "\n");
-}
-
} // namespace
} // namespace tensorflow