diff options
author | Sanjoy Das <sanjoy@google.com> | 2018-09-27 14:01:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 14:06:03 -0700 |
commit | 5220e565b7cc32a5f757896c76c7d57c33bcd323 (patch) | |
tree | b1e6f265e1c6630caa57370db9bd2df0ddec4da0 /tensorflow/compiler | |
parent | 4cedc8b6e738b7a188c9c091cf667bacafae44b7 (diff) |
Don't use tensorflow::Edge after freeing it
Even with this bug we were accidentally doing the right thing (so the test case
doesn't actually fail without the fix): deleting an Edge sets its input and
output indices to kControlSlot-1 so we'd normally expect to fail when there is a
control edge out of the TF cluster (because a control edge would be recognized
as a data edge). But AddEdge(x, -1, y, -1) seems to do the right thing for both
control and data edges.
PiperOrigin-RevId: 214831204
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/jit/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/jit/build_xla_ops_pass.cc | 11 | ||||
-rw-r--r-- | tensorflow/compiler/jit/build_xla_ops_pass_test.cc | 112 |
3 files changed, 116 insertions, 9 deletions
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 4e184729ef..5bf4af1014 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -478,6 +478,7 @@ tf_cc_test( name = "compilation_passes_test", size = "small", srcs = [ + "build_xla_ops_pass_test.cc", "encapsulate_subgraphs_pass_test.cc", "encapsulate_xla_computations_pass_test.cc", "mark_for_compilation_pass_test.cc", @@ -486,6 +487,7 @@ tf_cc_test( deps = [ ":common", ":compilation_passes", + ":node_matchers", ":xla_cluster_util", ":xla_gpu_device", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc index 13a518d0e8..9e3fd93cda 100644 --- a/tensorflow/compiler/jit/build_xla_ops_pass.cc +++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc @@ -112,16 +112,9 @@ static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) { std::vector<const Edge*> out_edges(old_node->out_edges().begin(), old_node->out_edges().end()); for (const Edge* edge : out_edges) { - Node* dst = edge->dst(); - int src_output = edge->src_output(); - int dst_input = edge->dst_input(); + // TODO(sanjoy): This does not update NodeDef inputs. + g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input()); g->RemoveEdge(edge); - - if (edge->IsControlEdge()) { - g->AddControlEdge(new_node, dst); - } else { - g->AddEdge(new_node, src_output, dst, dst_input); - } } } diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc new file mode 100644 index 0000000000..b7cb4506b9 --- /dev/null +++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/jit/build_xla_ops_pass.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" +#include "tensorflow/compiler/jit/node_matchers.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using ::tensorflow::testing::FindNodeByName; +using ::tensorflow::testing::matchers::CtrlDeps; +using ::tensorflow::testing::matchers::NodeWith; +using ::tensorflow::testing::matchers::Op; + +Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) { + auto graph = absl::make_unique<Graph>(OpRegistry::Global()); + TF_RETURN_IF_ERROR(s.ToGraph(graph.get())); + + // Assign all nodes to the CPU device. + static const char* kCpuDevice = "/job:localhost/replica:0/task:0/cpu:0"; + for (Node* n : graph->nodes()) { + if (n->assigned_device_name().empty()) { + n->set_assigned_device_name(kCpuDevice); + } + } + + GraphOptimizationPassOptions opt_options; + opt_options.graph = &graph; + BuildXlaOpsPass pass; + TF_RETURN_IF_ERROR(pass.Run(opt_options)); + *result = std::move(graph); + return Status::OK(); +} + +Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name, + const string& node_name, Node** result) { + NodeDef call_node; + call_node.set_name(node_name); + call_node.set_op(callee_name); + AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node); + AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node); + AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node); + Status s; + *result = graph->AddNode(call_node, &s); + return s; +} + +Node* MakeWrite(const Scope& scope, const string& id) { + Output var_handle = + ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({})); + Output value_to_write = + ops::Const(scope.WithOpName("ValueToAssign" + id), 1.0f); + ops::AssignVariableOp assign_op(scope.WithOpName("Assignee" + id), var_handle, + value_to_write); + return assign_op.operation.node(); +} + +FunctionDefLibrary CreateFunctionDefLibWithConstFunction(const string& name) { + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/name, /*in_def=*/{}, /*out_def=*/{"out: float"}, + /*attr_def*/ + {}, /*node_def=*/{FunctionDefHelper::Const("one", 1.0f)}, + /*ret_def=*/{{"out", "out:output:0"}}); + *flib_def.add_function() = std::move(func); + return flib_def; +} + +TEST(BuildXlaOps, ControlDepsPreserved) { + Scope root = Scope::NewRootScope().ExitOnError(); + + FunctionDefLibrary flib_def = + CreateFunctionDefLibWithConstFunction("cluster_0"); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + Node* call; + TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call)); + Node* write_op = MakeWrite(root, "write"); + root.graph()->AddControlEdge(call, write_op); + + std::unique_ptr<Graph> graph; + TF_ASSERT_OK(BuildXlaOps(root, &graph)); + + Node* write_op_new = FindNodeByName(graph.get(), write_op->name()); + ASSERT_NE(write_op_new, nullptr); + EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun"))))); +} + +} // namespace +} // namespace tensorflow |