diff options
author | 2018-09-27 13:04:31 -0700 | |
---|---|---|
committer | 2018-09-27 13:04:31 -0700 | |
commit | c898e63d07fc63315be98f0772736e5d7f2fb44c (patch) | |
tree | f250fcb1b35714d92afb371bf322f978ddcdfa6e /tensorflow/core/kernels | |
parent | 1084594657a5d139102ac794f84d1427a710e39a (diff) | |
parent | 937ad7c27f0d289067c935543d282e5ac5a310b1 (diff) |
Merge pull request #22286 from Intel-tensorflow:nhasabni/unit-test-fixes
PiperOrigin-RevId: 214821528
Diffstat (limited to 'tensorflow/core/kernels')
-rw-r--r-- | tensorflow/core/kernels/partitioned_function_ops.cc | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index fc1c9003aa..fdb4c84c46 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -97,7 +97,13 @@ class PartitionedCallOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, fbody != nullptr, errors::Internal("Could not find handle ", handle), done); + // We need to pass global op_registry as default_registry when creating + // graph. So that graph optimization passes can lookup all possible ops + // by name. auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def()); + FunctionLibraryDefinition global_flib(OpRegistry::Global(), {}); + TF_CHECK_OK( + graph.get()->AddFunctionLibrary(global_flib.ToProto())); CopyGraph(*fbody->graph, graph.get()); OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done); @@ -250,9 +256,11 @@ class PartitionedCallOp : public AsyncOpKernel { VLOG(3) << "Partitioned function '" << func_.name() << "', yielding " << partitions.size() << " shards."; - const FunctionLibraryDefinition* flib_def = &graph->flib_def(); for (const auto& partition : partitions) { - std::unique_ptr<Graph> subgraph(new Graph(flib_def)); + std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def())); + FunctionLibraryDefinition global_flib(OpRegistry::Global(), {}); + TF_CHECK_OK( + subgraph.get()->AddFunctionLibrary(global_flib.ToProto())); GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = true; |