diff options
-rw-r--r-- | tensorflow/core/grappler/grappler_item_builder.cc | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 8f7333f1db..1ad3cbb4cb 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -134,14 +134,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( new_item->id = id; new_item->graph = meta_graph.graph_def(); - // Optimize the graph (function inlining, l1 optimizations, etc). - Status optimize_status = - OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg); - if (!optimize_status.ok()) { - LOG(ERROR) << "Function optimization failed: " << optimize_status; - return nullptr; - } - // Attempt to detect the fetch node(s). if (meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); @@ -250,6 +242,10 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto; } + // Erase the recorded result of any previous shape inference to start again + // from scratch. + node.mutable_attr()->erase("_output_shapes"); + // Delete user specified placement if requested. if (cfg.ignore_user_placement) { node.clear_device(); @@ -329,6 +325,14 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( } } + // Optimize the graph (function inlining, l1 optimizations, etc). + Status optimize_status = + OptimizeGraph(new_item->graph, &new_item->graph, cfg); + if (!optimize_status.ok()) { + LOG(ERROR) << "Function optimization failed: " << optimize_status; + return nullptr; + } + return new_item; } |