aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc20
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 8f7333f1db..1ad3cbb4cb 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -134,14 +134,6 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
new_item->id = id;
new_item->graph = meta_graph.graph_def();
- // Optimize the graph (function inlining, l1 optimizations, etc).
- Status optimize_status =
- OptimizeGraph(meta_graph.graph_def(), &new_item->graph, cfg);
- if (!optimize_status.ok()) {
- LOG(ERROR) << "Function optimization failed: " << optimize_status;
- return nullptr;
- }
-
// Attempt to detect the fetch node(s).
if (meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
@@ -250,6 +242,10 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
*(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
}
+ // Erase the recorded result of any previous shape inference to start again
+ // from scratch.
+ node.mutable_attr()->erase("_output_shapes");
+
// Delete user specified placement if requested.
if (cfg.ignore_user_placement) {
node.clear_device();
@@ -329,6 +325,14 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
}
+ // Optimize the graph (function inlining, l1 optimizations, etc).
+ Status optimize_status =
+ OptimizeGraph(new_item->graph, &new_item->graph, cfg);
+ if (!optimize_status.ok()) {
+ LOG(ERROR) << "Function optimization failed: " << optimize_status;
+ return nullptr;
+ }
+
return new_item;
}