aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 11:04:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 11:07:52 -0700
commitc5b14b334e89b9bcb0fd0199481318b8fdd65762 (patch)
tree8ad0aac1eb041c06b2040f9cf29ff990106c7c8f /tensorflow/compiler/jit
parent8a752ecd583846aa5b3157c4d9c2c7c654beb6fb (diff)
Bug fix: consult graph's op registry to look up ops.
This is needed when the graph contains custom call ops. These functions are found only in the graph's registry and not the default one. PiperOrigin-RevId: 212297305
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc2
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc47
2 files changed, 48 insertions, 1 deletions
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 44caf0be52..e6cc6e52ae 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -443,7 +443,7 @@ Status FindCompilationCandidates(
!registration->requires_compilation) {
const OpDef* op_def;
TF_RETURN_IF_ERROR(
- OpRegistry::Global()->LookUpOpDef(node->type_string(), &op_def));
+ graph.op_registry()->LookUpOpDef(node->type_string(), &op_def));
if (op_def->is_stateful()) {
// We need to be able to constant fold the nodes in
// compile_time_const_nodes given constant inputs (required by XLA) and
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 9473ac0a4c..c59770a4c8 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/array_ops.h"
@@ -847,5 +848,51 @@ TEST(XlaCompilationTest, RandomShape) {
EXPECT_EQ(clusters["shape"], "");
}
+TEST(XlaCompilationTest, RandomShapeWithFunc) {
+ Scope root = Scope::DisabledShapeInferenceScope().ExitOnError();
+
+ FunctionDefLibrary flib_def;
+ FunctionDef func = FunctionDefHelper::Create(
+ /*function_name=*/"Stateful_func", /*in_def=*/{},
+ /*out_def=*/{"out: int32"},
+ /*attr_def*/
+ {}, /*node_def=*/
+ {FunctionDefHelper::Const("shape_shape", 2),
+ FunctionDefHelper::Const("minval", 1),
+ FunctionDefHelper::Const("maxval", 20),
+ {{"shape"},
+ "RandomUniformInt",
+ {"shape_shape:output:0", "minval:output:0", "maxval:output:0"},
+ {{"Tout", DataType::DT_INT32}, {"T", DataType::DT_INT32}}}},
+ /*ret_def=*/{{"out", "shape:output:0"}});
+
+ func.mutable_signature()->set_is_stateful(true);
+ *flib_def.add_function() = std::move(func);
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ NodeDef call_node;
+ call_node.set_name("fn_call");
+ call_node.set_op("Stateful_func");
+ Status status;
+ Node* call = root.graph()->AddNode(call_node, &status);
+ TF_ASSERT_OK(status);
+
+ Output shape = Output(call, 0);
+ Output reshape_input =
+ ops::Placeholder(root.WithOpName("reshape_input"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({500, 500})));
+ Output reshape =
+ ops::Reshape(root.WithOpName("reshape"), reshape_input, shape);
+
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+ auto fld = absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(),
+ flib_def);
+ TF_ASSERT_OK(
+ MarkForCompilationPassTestHelper::MarkForCompilation(&graph, fld.get()));
+
+ std::unordered_map<string, string> clusters = GetClusters(*graph);
+ EXPECT_EQ(clusters["fn_call"], "");
+}
+
} // namespace
} // namespace tensorflow