aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 13:04:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 13:04:31 -0700
commitc898e63d07fc63315be98f0772736e5d7f2fb44c (patch)
treef250fcb1b35714d92afb371bf322f978ddcdfa6e /tensorflow/core/kernels
parent1084594657a5d139102ac794f84d1427a710e39a (diff)
parent937ad7c27f0d289067c935543d282e5ac5a310b1 (diff)
Merge pull request #22286 from Intel-tensorflow:nhasabni/unit-test-fixes
PiperOrigin-RevId: 214821528
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index fc1c9003aa..fdb4c84c46 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -97,7 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ // We need to pass global op_registry as default_registry when creating
+ // graph. So that graph optimization passes can lookup all possible ops
+ // by name.
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ graph.get()->AddFunctionLibrary(global_flib.ToProto()));
CopyGraph(*fbody->graph, graph.get());
OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
@@ -250,9 +256,11 @@ class PartitionedCallOp : public AsyncOpKernel {
VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
<< partitions.size() << " shards.";
- const FunctionLibraryDefinition* flib_def = &graph->flib_def();
for (const auto& partition : partitions) {
- std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ subgraph.get()->AddFunctionLibrary(global_flib.ToProto()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = true;