diff options
9 files changed, 341 insertions, 193 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index aa01b1fd52..df81f3c23e 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -316,12 +316,14 @@ cc_library( hdrs = ["resource_operation_safety_analysis.h"], deps = [ "//tensorflow/compiler/jit/graphcycles", + "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/core:framework", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", ], ) @@ -381,6 +383,7 @@ cc_library( "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:xla_ops", "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 6cf04c7bdf..518c39ec15 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_cluster_util.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/function.h" @@ -86,35 +87,14 @@ bool HasResourceInput(const Node& node) { DT_RESOURCE) != node.input_types().end(); } -gtl::FlatSet<StringPiece>* GetNonResourceVarResourceOpSet() { - gtl::FlatSet<StringPiece>* result = new gtl::FlatSet<StringPiece>; - - result->insert("StackCloseV2"); - result->insert("StackPopV2"); - result->insert("StackPushV2"); - result->insert("TensorArrayConcatV3"); - result->insert("TensorArrayGatherV3"); - result->insert("TensorArrayScatterV3"); - result->insert("TensorArrayGradV3"); - result->insert("TensorArrayCloseV3"); - result->insert("TensorArrayReadV3"); - result->insert("TensorArraySizeV3"); - result->insert("TensorArraySplitV3"); - result->insert("TensorArrayWriteV3"); - - return result; -} - // Returns true if `node` is a resource operation recognized by tf2xla that // operates on something other than resource variables. bool IsNonResourceVarResourceOp(const Node& node) { // TODO(b/112837194): We can't cluster these because we only support // snapshotting resource variables (and we can't e.g. snapshot stacks). This // limitation may be fixable with some work. - static gtl::FlatSet<StringPiece>* non_resource_var_resource_op = - GetNonResourceVarResourceOpSet(); - - return non_resource_var_resource_op->count(node.type_string()); + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(node.type_string()); + return op_info && op_info->resource_kind() != XlaResourceKind::kVariable; } // Make sure we don't recurse infinitely on recursive functions. 
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc index 2479e57cde..1ba4a5ef73 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc @@ -84,6 +84,8 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_join.h" +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/tensor_id.h" @@ -94,82 +96,6 @@ limitations under the License. namespace tensorflow { namespace { -// Every TensorFlow operation is mapped to one of four kinds: -enum class ResourceOpKind { - kNone, // Has no interaction with resources. - kRead, // Only reads from resources. - kWrite, // Only writes to resources. - kModify // Reads from and writes to resources. -}; - -StringPiece ResourceOpKindToString(ResourceOpKind op_kind) { - switch (op_kind) { - case ResourceOpKind::kRead: - return "Read"; - case ResourceOpKind::kWrite: - return "Write"; - case ResourceOpKind::kModify: - return "Modify"; - case ResourceOpKind::kNone: - return "None"; - } -} - -// Returns a map that maps TensorFlow operation names to the corresponding -// ResourceOpKind. We only care about XLA operations that we can cluster. 
-gtl::FlatMap<StringPiece, ResourceOpKind>* GetResourceOpKindMap() { - gtl::FlatMap<StringPiece, ResourceOpKind>* result = - new gtl::FlatMap<StringPiece, ResourceOpKind>; - - result->insert({"AssignAddVariableOp", ResourceOpKind::kModify}); - result->insert({"AssignSubVariableOp", ResourceOpKind::kModify}); - result->insert({"AssignVariableOp", ResourceOpKind::kWrite}); - result->insert({"ReadVariableOp", ResourceOpKind::kRead}); - result->insert({"ResourceApplyAdaMax", ResourceOpKind::kModify}); - result->insert({"ResourceApplyAdadelta", ResourceOpKind::kModify}); - result->insert({"ResourceApplyAdagrad", ResourceOpKind::kModify}); - result->insert({"ResourceApplyAdagradDA", ResourceOpKind::kModify}); - result->insert({"ResourceApplyAdam", ResourceOpKind::kModify}); - result->insert({"ResourceApplyAddSign", ResourceOpKind::kModify}); - result->insert({"ResourceApplyCenteredRMSProp", ResourceOpKind::kModify}); - result->insert({"ResourceApplyFtrl", ResourceOpKind::kModify}); - result->insert({"ResourceApplyFtrlV2", ResourceOpKind::kModify}); - result->insert({"ResourceApplyGradientDescent", ResourceOpKind::kModify}); - result->insert({"ResourceApplyMomentum", ResourceOpKind::kModify}); - result->insert({"ResourceApplyPowerSign", ResourceOpKind::kModify}); - result->insert({"ResourceApplyProximalAdagrad", ResourceOpKind::kModify}); - result->insert( - {"ResourceApplyProximalGradientDescent", ResourceOpKind::kModify}); - result->insert({"ResourceApplyRMSProp", ResourceOpKind::kModify}); - result->insert({"ResourceGather", ResourceOpKind::kRead}); - result->insert({"ResourceScatterAdd", ResourceOpKind::kModify}); - result->insert({"ResourceScatterDiv", ResourceOpKind::kModify}); - result->insert({"ResourceScatterMax", ResourceOpKind::kModify}); - result->insert({"ResourceScatterMin", ResourceOpKind::kModify}); - result->insert({"ResourceScatterMul", ResourceOpKind::kModify}); - result->insert({"ResourceScatterNdAdd", ResourceOpKind::kModify}); - 
result->insert({"ResourceScatterNdUpdate", ResourceOpKind::kModify}); - result->insert({"ResourceScatterSub", ResourceOpKind::kModify}); - result->insert({"ResourceScatterUpdate", ResourceOpKind::kModify}); - result->insert({"ResourceStridedSliceAssign", ResourceOpKind::kModify}); - result->insert({"StackCloseV2", ResourceOpKind::kRead}); // Reads shape - result->insert({"StackPopV2", ResourceOpKind::kModify}); - result->insert({"StackPushV2", ResourceOpKind::kModify}); - result->insert({"TensorArrayConcatV3", ResourceOpKind::kRead}); - result->insert({"TensorArrayGatherV3", ResourceOpKind::kRead}); - result->insert({"TensorArrayScatterV3", ResourceOpKind::kWrite}); - result->insert({"TensorArrayGradV3", ResourceOpKind::kRead}); // Reads shape - result->insert({"TensorArrayCloseV3", ResourceOpKind::kRead}); // Reads shape - result->insert({"TensorArrayReadV3", ResourceOpKind::kRead}); - result->insert({"TensorArraySizeV3", ResourceOpKind::kRead}); - result->insert({"TensorArraySplitV3", ResourceOpKind::kWrite}); - result->insert({"TensorArrayWriteV3", ResourceOpKind::kWrite}); - result->insert({"VarIsInitializedOp", ResourceOpKind::kRead}); - result->insert({"VariableShape", ResourceOpKind::kRead}); - - return result; -} - // Returns true if `n` may call a function. Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, bool* out_result) { @@ -186,26 +112,25 @@ Status MayCallFunction(const Node& n, const FunctionLibraryDefinition* flib_def, return Status::OK(); } -// Maps `n` to the ResourceOpKind corresponding to its operation. -Status ResourceOpKindForNode( +// Maps `n` to the XlaResourceOpKind corresponding to its operation. If `n` is +// not a resource operation recognized by XLA then sets `out_resource_op_kind` +// to nullopt. 
+Status XlaResourceOpKindForNode( const Node& n, const FunctionLibraryDefinition* flib_def, const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, - ResourceOpKind* out_resource_op_kind) { - static const gtl::FlatMap<StringPiece, ResourceOpKind>& resource_op_kind_map = - *GetResourceOpKindMap(); - + absl::optional<XlaResourceOpKind>* out_resource_op_kind) { bool should_ignore = false; if (resource_ops_to_ignore) { TF_RETURN_IF_ERROR(resource_ops_to_ignore(n, &should_ignore)); } if (should_ignore) { - *out_resource_op_kind = ResourceOpKind::kNone; + *out_resource_op_kind = absl::nullopt; return Status::OK(); } - auto it = resource_op_kind_map.find(n.type_string()); - if (it != resource_op_kind_map.end()) { - *out_resource_op_kind = it->second; + const XlaResourceOpInfo* op_info = GetResourceOpInfoForOp(n.type_string()); + if (op_info) { + *out_resource_op_kind = op_info->kind(); return Status::OK(); } @@ -214,8 +139,11 @@ Status ResourceOpKindForNode( // inter-procedural analysis. bool may_call_function; TF_RETURN_IF_ERROR(MayCallFunction(n, flib_def, &may_call_function)); - *out_resource_op_kind = - may_call_function ? ResourceOpKind::kModify : ResourceOpKind::kNone; + if (may_call_function) { + *out_resource_op_kind = XlaResourceOpKind::kReadWrite; + } else { + *out_resource_op_kind = absl::nullopt; + } return Status::OK(); } @@ -224,22 +152,22 @@ Status ResourceOpKindForNode( // resource op kind `from` to a TensorFlow operation of resource op kind `to` // can be represented by an XLA cluster and needs no special handling around // auto-jit. -bool IsEdgeSafe(ResourceOpKind from, ResourceOpKind to) { +bool IsEdgeSafe(XlaResourceOpKind from, XlaResourceOpKind to) { // XLA clusters forces all reads to happen before all writes, which means the // kinds of edges it can faithfully represent are: Read->Write, Read->Modify, // Modify->Write, Read->Read, Write->Write. 
// // TODO(b/112856632): We can, in theory, support Read->Read and Write->Write // dependencies. - return from == ResourceOpKind::kNone || to == ResourceOpKind::kNone || - (from == ResourceOpKind::kRead && to == ResourceOpKind::kWrite); + return from == XlaResourceOpKind::kRead && to == XlaResourceOpKind::kWrite; } -using ResourceOp = std::pair<int, ResourceOpKind>; +using ResourceOp = std::pair<int, XlaResourceOpKind>; string ResourceOpToString(const ResourceOp& resource_op) { - return strings::StrCat(resource_op.first, ": ", - ResourceOpKindToString(resource_op.second)); + return strings::StrCat( + resource_op.first, ": ", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op.second)); } // A copy-on-write set used to store the set of ResourceOps reaching a node in a @@ -332,9 +260,10 @@ string ResourceOpSetToString(const ResourceOpSet& resource_op_set) { return strings::StrCat("{", absl::StrJoin(elements_debug_string, ","), "}"); } -string NodeToString(const Node& n, ResourceOpKind resource_op_kind) { - return strings::StrCat("[", n.name(), ": ", n.type_string(), "(", - ResourceOpKindToString(resource_op_kind), ")", "]"); +string NodeToString(const Node& n, XlaResourceOpKind resource_op_kind) { + return strings::StrCat( + "[", n.name(), ": ", n.type_string(), "(", + XlaResourceOpInfo::XlaResourceOpKindToString(resource_op_kind), ")", "]"); } } // namespace @@ -356,14 +285,17 @@ Status ComputeIncompatibleResourceOperationPairs( const bool vlog = VLOG_IS_ON(2); for (Node* n : rpo) { - ResourceOpKind op_kind; - TF_RETURN_IF_ERROR( - ResourceOpKindForNode(*n, flib_def, resource_ops_to_ignore, &op_kind)); + absl::optional<XlaResourceOpKind> op_kind; + TF_RETURN_IF_ERROR(XlaResourceOpKindForNode( + *n, flib_def, resource_ops_to_ignore, &op_kind)); ResourceOpSet* resource_op_set = &resource_op_set_for_node[n->id()]; + // Merge the reaching resource operations for all the incoming edges to + // create the set of all possible resource ops reaching `n`. 
for (const Edge* e : n->in_edges()) { if (n->IsMerge() && e->src()->IsNextIteration()) { + // Ignore back-edges (see file comment). continue; } @@ -372,21 +304,23 @@ Status ComputeIncompatibleResourceOperationPairs( resource_op_set->Add(incoming_op_set); } - for (ResourceOp incoming_op : *resource_op_set) { - if (!IsEdgeSafe(incoming_op.second, op_kind)) { + // Add to the "incompatible resource ops" set if necessary. + if (op_kind) { + for (ResourceOp incoming_op : *resource_op_set) { + if (IsEdgeSafe(incoming_op.second, *op_kind)) { + continue; + } + if (vlog) { VLOG(2) << "Unsafe edge: " << NodeToString(*g.FindNodeId(incoming_op.first), incoming_op.second) - << " -> " << NodeToString(*n, op_kind); + << " -> " << NodeToString(*n, *op_kind); } result->push_back({incoming_op.first, n->id()}); } - } - if (op_kind != ResourceOpKind::kNone) { - // This check is an optimization, not necessary for correctness. - resource_op_set->Add({n->id(), op_kind}); + resource_op_set->Add({n->id(), *op_kind}); } if (vlog) { @@ -399,18 +333,4 @@ Status ComputeIncompatibleResourceOperationPairs( return Status::OK(); } - -namespace resource_op_safety_analysis_internal { -std::vector<string> GetKnownResourceOperations() { - std::unique_ptr<gtl::FlatMap<StringPiece, ResourceOpKind>> - resource_op_kind_map(GetResourceOpKindMap()); - - std::vector<string> result; - for (const auto& name_kind_map : *resource_op_kind_map) { - result.push_back(string(name_kind_map.first)); - } - std::sort(result.begin(), result.end()); - return result; -} -} // namespace resource_op_safety_analysis_internal } // namespace tensorflow diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.h b/tensorflow/compiler/jit/resource_operation_safety_analysis.h index 147039348b..ae8cfeecad 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis.h +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.h @@ -68,12 +68,6 @@ Status ComputeIncompatibleResourceOperationPairs( const 
Graph& g, const FunctionLibraryDefinition* flib_def, const std::function<Status(const Node&, bool*)>& resource_ops_to_ignore, std::vector<std::pair<int, int>>* result); - -namespace resource_op_safety_analysis_internal { -// Internal API for whitebox testing. -std::vector<string> GetKnownResourceOperations(); -} // namespace resource_op_safety_analysis_internal - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_RESOURCE_OPERATION_SAFETY_ANALYSIS_H_ diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc index c774bdf5ff..e54b547abc 100644 --- a/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc +++ b/tensorflow/compiler/jit/resource_operation_safety_analysis_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/jit/resource_operation_safety_analysis.h" -#include "absl/strings/str_join.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/control_flow_ops_internal.h" @@ -34,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -536,47 +536,5 @@ TEST(ResourceOperationSafetyAnalysisTest, Loop) { bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { return arg_def.type() == DT_RESOURCE; } - -bool HasResourceInputOrOutput(const OpDef& op_def) { - return std::any_of(op_def.input_arg().begin(), op_def.input_arg().end(), - IsResourceArgDef) || - std::any_of(op_def.output_arg().begin(), op_def.output_arg().end(), - IsResourceArgDef); -} - -TEST(ResourceOperationSafetyAnalysisTest, HaveAllResourceOps) { - gtl::FlatMap<string, bool> known_resource_ops; - for (const string& known_resource_op : - resource_op_safety_analysis_internal::GetKnownResourceOperations()) { - ASSERT_TRUE(known_resource_ops.insert({known_resource_op, false}).second); - } - - std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); - for (const string& xla_op_name : xla_op_names) { - if (xla_op_name == "StackV2" || xla_op_name == "TensorArrayV3") { - continue; - } - - const OpDef* op_def; - TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); - if (HasResourceInputOrOutput(*op_def)) { - EXPECT_EQ(known_resource_ops.count(xla_op_name), 1) - << "Unknown resource op " << xla_op_name; - known_resource_ops[xla_op_name] = true; - } - } - - std::vector<string> unnecessary_resource_ops; - for (const auto& pair : known_resource_ops) { - if (!pair.second) { - unnecessary_resource_ops.push_back(pair.first); - } - } - - EXPECT_TRUE(unnecessary_resource_ops.empty()) - << "Stale resource ops:\n" - << absl::StrJoin(unnecessary_resource_ops, "\n"); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index b67e717f82..92e577bb7b 100644 --- 
a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -601,3 +601,30 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "resource_operation_table", + srcs = ["resource_operation_table.cc"], + hdrs = ["resource_operation_table.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/algorithm:container", + ], +) + +tf_cc_test( + name = "resource_operation_table_test", + srcs = ["resource_operation_table_test.cc"], + deps = [ + ":resource_operation_table", + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc new file mode 100644 index 0000000000..32ba6df2e6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" +#include "absl/algorithm/container.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace tensorflow { +/*static*/ StringPiece XlaResourceOpInfo::XlaResourceOpKindToString( + XlaResourceOpKind op_kind) { + switch (op_kind) { + case XlaResourceOpKind::kRead: + return "Read"; + case XlaResourceOpKind::kWrite: + return "Write"; + case XlaResourceOpKind::kReadWrite: + return "Modify"; + } +} + +static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* CreateResourceOpInfoMap() { + gtl::FlatMap<StringPiece, XlaResourceOpInfo>* result = + new gtl::FlatMap<StringPiece, XlaResourceOpInfo>; + + auto add = [&](StringPiece op, XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) { + auto insert_result = + result->insert({op, XlaResourceOpInfo(op_kind, resource_kind)}); + CHECK(insert_result.second); + }; + + auto kRead = XlaResourceOpKind::kRead; + auto kWrite = XlaResourceOpKind::kWrite; + auto kReadWrite = XlaResourceOpKind::kReadWrite; + + auto kVariable = XlaResourceKind::kVariable; + auto kStack = XlaResourceKind::kStack; + auto kTensorArray = XlaResourceKind::kTensorArray; + + // clang-format off + add("AssignAddVariableOp" , kReadWrite, kVariable); + add("AssignSubVariableOp" , kReadWrite, kVariable); + add("AssignVariableOp" , kWrite, kVariable); + add("ReadVariableOp" , kRead, kVariable); + add("ResourceApplyAdaMax" , kReadWrite, kVariable); + add("ResourceApplyAdadelta" , kReadWrite, kVariable); + add("ResourceApplyAdagrad" , kReadWrite, kVariable); + add("ResourceApplyAdagradDA" , kReadWrite, kVariable); + add("ResourceApplyAdam" , kReadWrite, kVariable); + add("ResourceApplyAddSign" , kReadWrite, kVariable); + add("ResourceApplyCenteredRMSProp" , kReadWrite, kVariable); + add("ResourceApplyFtrl" , kReadWrite, kVariable); + add("ResourceApplyFtrlV2" , kReadWrite, kVariable); + 
add("ResourceApplyGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyMomentum" , kReadWrite, kVariable); + add("ResourceApplyPowerSign" , kReadWrite, kVariable); + add("ResourceApplyProximalAdagrad" , kReadWrite, kVariable); + add("ResourceApplyProximalGradientDescent" , kReadWrite, kVariable); + add("ResourceApplyRMSProp" , kReadWrite, kVariable); + add("ResourceGather" , kRead, kVariable); + add("ResourceScatterAdd" , kReadWrite, kVariable); + add("ResourceScatterDiv" , kReadWrite, kVariable); + add("ResourceScatterMax" , kReadWrite, kVariable); + add("ResourceScatterMin" , kReadWrite, kVariable); + add("ResourceScatterMul" , kReadWrite, kVariable); + add("ResourceScatterNdAdd" , kReadWrite, kVariable); + add("ResourceScatterNdUpdate" , kReadWrite, kVariable); + add("ResourceScatterSub" , kReadWrite, kVariable); + add("ResourceScatterUpdate" , kReadWrite, kVariable); + add("ResourceStridedSliceAssign" , kReadWrite, kVariable); + add("VarIsInitializedOp" , kRead, kVariable); + add("VariableShape" , kRead, kVariable); + + add("StackV2" , kWrite, kStack); + add("StackCloseV2" , kRead, kStack); + add("StackPopV2" , kReadWrite, kStack); + add("StackPushV2" , kReadWrite, kStack); + + add("TensorArrayV3" , kWrite, kTensorArray); + add("TensorArrayConcatV3" , kRead, kTensorArray); + add("TensorArrayGatherV3" , kRead, kTensorArray); + add("TensorArrayScatterV3" , kWrite, kTensorArray); + add("TensorArrayGradV3" , kRead, kTensorArray); + add("TensorArrayCloseV3" , kRead, kTensorArray); + add("TensorArrayReadV3" , kRead, kTensorArray); + add("TensorArraySizeV3" , kRead, kTensorArray); + add("TensorArraySplitV3" , kWrite, kTensorArray); + add("TensorArrayWriteV3" , kWrite, kTensorArray); + // clang-format on + + return result; +} + +static const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& +GetStaticResourceOpInfoMap() { + static gtl::FlatMap<StringPiece, XlaResourceOpInfo>* op_info_map = + CreateResourceOpInfoMap(); + return *op_info_map; +} + +const 
XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op) { + const gtl::FlatMap<StringPiece, XlaResourceOpInfo>& op_infos = + GetStaticResourceOpInfoMap(); + auto it = op_infos.find(op); + return it == op_infos.end() ? nullptr : &it->second; +} + +namespace resource_op_table_internal { +std::vector<StringPiece> GetKnownResourceOps() { + std::vector<StringPiece> result; + for (const auto& p : GetStaticResourceOpInfoMap()) { + result.push_back(p.first); + } + absl::c_sort(result); + return result; +} +} // namespace resource_op_table_internal +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.h b/tensorflow/compiler/tf2xla/resource_operation_table.h new file mode 100644 index 0000000000..eb3b98334b --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table.h @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ + +#include <string> + +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" + +// Exposes information about the resource operations supported by tf2xla in a +// structured form. + +namespace tensorflow { +enum class XlaResourceOpKind { + kRead, // Only reads from resources. 
+ kWrite, // Only writes to resources. + kReadWrite // Reads from and writes to resources. +}; + +enum class XlaResourceKind { + kVariable, // Operates on resource variables. + kStack, // Operates on stacks. + kTensorArray // Operates on tensor arrays. +}; + +class XlaResourceOpInfo { + public: + explicit XlaResourceOpInfo(XlaResourceOpKind op_kind, + XlaResourceKind resource_kind) + : op_kind_(op_kind), resource_kind_(resource_kind) {} + + XlaResourceOpKind kind() const { return op_kind_; } + XlaResourceKind resource_kind() const { return resource_kind_; } + + static StringPiece XlaResourceOpKindToString(XlaResourceOpKind op_kind); + + private: + XlaResourceOpKind op_kind_; + XlaResourceKind resource_kind_; +}; + +// Returns a XlaResourceOpInfo describing `op` if it is a resource operation +// supported by tf2xla, otherwise returns null (i.e. if this returns null then +// `op` is either not a resource operation or is unsupported by XLA). +const XlaResourceOpInfo* GetResourceOpInfoForOp(StringPiece op); + +namespace resource_op_table_internal { +// NB! Implementation detail exposed for unit testing, do not use. +// +// Returns the set of resource operations known by this module. +std::vector<StringPiece> GetKnownResourceOps(); +} // namespace resource_op_table_internal + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_OPERATION_TABLE_H_ diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc new file mode 100644 index 0000000000..0343f80de9 --- /dev/null +++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/resource_operation_table.h" + +#include "absl/algorithm/container.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +bool IsResourceArgDef(const OpDef::ArgDef& arg_def) { + return arg_def.type() == DT_RESOURCE; +} + +bool HasResourceInputOrOutput(const OpDef& op_def) { + return absl::c_any_of(op_def.input_arg(), IsResourceArgDef) || + absl::c_any_of(op_def.output_arg(), IsResourceArgDef); +} + +TEST(ResourceOperationTableTest, HaveAllResourceOps) { + gtl::FlatMap<string, bool> known_resource_ops; + for (StringPiece known_resource_op : + resource_op_table_internal::GetKnownResourceOps()) { + ASSERT_TRUE( + known_resource_ops.insert({string(known_resource_op), false}).second); + } + + std::vector<string> xla_op_names = XlaOpRegistry::GetAllRegisteredOps(); + for (const string& xla_op_name : xla_op_names) { + const OpDef* op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef(xla_op_name, &op_def)); + if (HasResourceInputOrOutput(*op_def)) { + EXPECT_EQ(known_resource_ops.count(xla_op_name), 1) + << "Unknown resource op " << xla_op_name; + known_resource_ops[xla_op_name] = true; + } + } + + std::vector<string> unnecessary_resource_ops; + for (const auto& pair : known_resource_ops) { + if (!pair.second) { + unnecessary_resource_ops.push_back(pair.first); + } + } + + 
EXPECT_TRUE(unnecessary_resource_ops.empty()) + << "Stale resource ops:\n" + << absl::StrJoin(unnecessary_resource_ops, "\n"); +} +} // namespace +} // namespace tensorflow