diff options
author | 2018-09-10 11:04:38 -0700 | |
---|---|---|
committer | 2018-09-10 11:07:52 -0700 | |
commit | c5b14b334e89b9bcb0fd0199481318b8fdd65762 (patch) | |
tree | 8ad0aac1eb041c06b2040f9cf29ff990106c7c8f /tensorflow/compiler/jit | |
parent | 8a752ecd583846aa5b3157c4d9c2c7c654beb6fb (diff) |
Bug fix: consult graph's op registry to look up ops.
This is needed when the graph contains custom call ops. Such ops are registered only in the graph's own op registry, not in the default global registry, so lookups must go through the graph's registry.
PiperOrigin-RevId: 212297305
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 47 |
2 files changed, 48 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 44caf0be52..e6cc6e52ae 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -443,7 +443,7 @@ Status FindCompilationCandidates( !registration->requires_compilation) { const OpDef* op_def; TF_RETURN_IF_ERROR( - OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def)); + graph.op_registry()->LookUpOpDef(node->type_string(), &op_def)); if (op_def->is_stateful()) { // We need to be able to constant fold the nodes in // compile_time_const_nodes given constant inputs (required by XLA) and diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 9473ac0a4c..c59770a4c8 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/array_ops.h" @@ -847,5 +848,51 @@ TEST(XlaCompilationTest, RandomShape) { EXPECT_EQ(clusters["shape"], ""); } +TEST(XlaCompilationTest, RandomShapeWithFunc) { + Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); + + FunctionDefLibrary flib_def; + FunctionDef func = FunctionDefHelper::Create( + /*function_name=*/"Stateful_func", /*in_def=*/{}, + /*out_def=*/{"out: int32"}, + /*attr_def*/ + {}, /*node_def=*/ + {FunctionDefHelper::Const("shape_shape", 2), + FunctionDefHelper::Const("minval", 1), + FunctionDefHelper::Const("maxval", 20), + {{"shape"}, + "RandomUniformInt", + {"shape_shape:output:0", "minval:output:0", "maxval:output:0"}, + {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}}, + /*ret_def=*/{{"out", "shape:output:0"}}); + + func.mutable_signature()->set_is_stateful(true); + *flib_def.add_function() = std::move(func); + TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def)); + NodeDef call_node; + call_node.set_name("fn_call"); + call_node.set_op("Stateful_func"); + Status status; + Node* call = root.graph()->AddNode(call_node, &status); + TF_ASSERT_OK(status); + + Output shape = Output(call, 0); + Output reshape_input = + ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT, + ops::Placeholder::Shape(TensorShape({500, 500}))); + Output reshape = + ops::Reshape(root.WithOpName("reshape"), reshape_input, shape); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(root.ToGraph(graph.get())); + auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), + flib_def); + TF_ASSERT_OK( + MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get())); + + std::unordered_map<string, string> clusters = GetClusters(*graph); + EXPECT_EQ(clusters["fn_call"], ""); +} + } // 
namespace } // namespace tensorflow |