-rw-r--r--  tensorflow/core/grappler/costs/BUILD                          |  1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc            | 74
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h             | 37
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc       | 82
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc           |  2
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc   |  2
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc       |  2
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc       |  2
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc       |  2
-rw-r--r--  tensorflow/core/grappler/optimizers/static_schedule.cc        |  4
-rw-r--r--  tensorflow/python/grappler/item.i                             |  2
-rw-r--r--  tensorflow/python/grappler/model_analyzer.cc                  |  2
12 files changed, 164 insertions, 48 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index f02cb51038..f1edbbb602 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index dd389de636..fb7e20fca0 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
@@ -316,7 +317,11 @@ class SymbolicShapeRefiner {
shape_inference::ShapeHandle shape) {
return shape_refiner_->SetShape(node, output_port, shape);
}
-
+ Status SetUnknownShape(const Node* node, int output_port) {
+ shape_inference::ShapeHandle shape =
+ GetUnknownOutputShape(node, output_port);
+ return shape_refiner_->SetShape(node, output_port, shape);
+ }
struct ShapeId {
const Node* node;
int port_id;
@@ -646,6 +651,23 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
+Status GraphProperties::OverwriteFedPorts(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* node, TopoQueue* new_shapes) const {
+ auto it = fed_ports.find(node->name());
+ Status status;
+ if (it != fed_ports.end()) {
+ // It is possible to feed node output ports with tensors of any shape: as a
+ // result, the shape of a fed port is completely unknown.
+ for (const int output_port : it->second) {
+ status.Update(shape_refiner->SetUnknownShape(node, output_port));
+ }
+ new_shapes->push(node);
+ }
+ return status;
+}
+
// Manually propagate the input shape for Enter nodes and update any Merge node
// outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
@@ -673,9 +695,10 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
-Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
- bool relax, const Node* n,
- TopoQueue* new_shapes) {
+Status GraphProperties::UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner, bool relax,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* n, TopoQueue* new_shapes) const {
if (n->IsEnter()) {
// The Enter shape function always forwards an UnknownShape, so do the right
// thing here.
@@ -695,7 +718,9 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
}
}
}
- return Status::OK();
+ // Nodes can be fed with any shape. The TensorFlow shape inference code can't
+ // handle this properly, so overwrite its behavior here.
+ return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes);
}
// Propagates the shapes in the transitive fan-out of <new_shapes>.
@@ -703,6 +728,7 @@ Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
incorrect shape functions. The algorithm should converge in at most
@@ -728,8 +754,8 @@ Status GraphProperties::PropagateShapes(
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
const Node* fanout = e->dst();
- TF_RETURN_IF_ERROR(
- UpdateShapes(shape_refiner, relax, fanout, new_shapes));
+ TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports,
+ fanout, new_shapes));
}
}
}
@@ -803,7 +829,7 @@ Status GraphProperties::UpdateResource(
return Status::OK();
}
-Status GraphProperties::InferStatically() {
+Status GraphProperties::InferStatically(bool assume_valid_feeds) {
Graph graph(OpRegistry::Global());
FunctionLibraryDefinition function_library(graph.op_registry(),
item_.graph.library());
@@ -820,11 +846,21 @@ Status GraphProperties::InferStatically() {
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
+ std::unordered_map<string, std::unordered_set<int>> fed_ports;
+ if (!assume_valid_feeds) {
+ for (const auto& feed : item_.feed) {
+ int port_index = 0;
+ string node_name = ParseNodeName(feed.first, &port_index);
+ fed_ports[node_name].insert(port_index);
+ }
+ }
+
// List the resources and the nodes using them. Also collect the Enter and
// Merge nodes.
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
std::unordered_set<const Node*> enter_nodes;
std::unordered_set<const Node*> merge_nodes;
+ std::unordered_set<const Node*> fed_nodes;
int num_loops = 0;
for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) {
@@ -841,6 +877,9 @@ Status GraphProperties::InferStatically() {
} else if (node->IsNextIteration()) {
++num_loops;
}
+ if (fed_ports.find(node->name()) != fed_ports.end()) {
+ fed_nodes.insert(node);
+ }
}
SymbolicShapeRefiner refiner(&shape_refiner);
@@ -855,15 +894,22 @@ Status GraphProperties::InferStatically() {
// Force the propagation of shapes of Enter nodes manually (the Enter shape
// function always forwards an UnknownShape).
for (const Node* node : enter_nodes) {
- TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
}
// Seed the propagation of shapes through merge nodes.
for (const Node* node : merge_nodes) {
- TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
+ }
+ // Also seed the propagation of shapes in the fanout of fed nodes.
+ for (const Node* node : fed_nodes) {
+ TF_RETURN_IF_ERROR(
+ OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes));
}
// Propagate shapes normally.
- TF_RETURN_IF_ERROR(
- PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
+ TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources,
+ fed_ports, num_loops));
}
// Track shapes globally across the graph.
@@ -874,6 +920,10 @@ Status GraphProperties::InferStatically() {
if (!node_ctx) {
continue;
}
+ // Skip any information that comes from fed nodes.
+ if (fed_ports.find(node->name()) != fed_ports.end()) {
+ continue;
+ }
for (const auto& merged_shapes : node_ctx->MergedShapes()) {
if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second)
.ok()) {
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 95bc5044d0..6fc53a7f2e 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -34,12 +34,19 @@ class TopoQueue;
// nodes, and potentially a set of nodes to feed.
class GraphProperties {
public:
- // Factory method for creating a GrapplerShapes from a MetaGraphDef.
- // Returns nullptr if the given meta_graph cannot be converted.
explicit GraphProperties(const GrapplerItem& item) : item_(item) {}
- Status InferStatically();
+ // Infer the shapes through abstract interpretation. Feed information can be
+ // incorrect, so it should be discarded to ensure correctness of the analysis.
+ // However, it can help infer shapes in the fanout of fed nodes (even though
+ // the correctness of these shapes can't be guaranteed), so in some cases
+ // (such as simulation or scheduling) it makes sense to keep these shapes.
+ Status InferStatically(bool assume_valid_feeds);
+ // Infer the shapes by running the graph on the specified cluster and
+ // recording the shapes of the processed tensors.
Status InferDynamically(Cluster* cluster);
+ // Extract the properties from a cost graph. For testing only, since there is
+ // no way to ensure that the cost graph matches the item.
Status InferFromCostGraph(const CostGraphDef& cost_graph);
// Stores `item_.graph` with the inferred output shapes to `output_graph_def`.
@@ -65,12 +72,6 @@ class GraphProperties {
OpInfo::TensorProperties*);
private:
- // Inputs
- GrapplerItem item_;
- std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
- std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
- const std::vector<OpInfo::TensorProperties> missing_properties_;
-
// Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
static Status MergeEnqueueShapesAndTypes(
@@ -99,17 +100,31 @@ class GraphProperties {
static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner,
const Node* node, bool relax,
TopoQueue* new_shapes);
+ // Process a node that is used to feed the model.
+ Status OverwriteFedPorts(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* node, TopoQueue* new_shapes) const;
// Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'.
- static Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax,
- const Node* n, TopoQueue* new_shapes);
+ Status UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner, bool relax,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* n, TopoQueue* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached.
Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
int num_loops) const;
+
+ // Data members
+ GrapplerItem item_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
+ const std::vector<OpInfo::TensorProperties> missing_properties_;
};
} // end namespace grappler
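(Not part of the patch: a minimal caller-side sketch of the updated API. It assumes a populated GrapplerItem named "item" and a node name "my_node", both hypothetical; the flag value mirrors the call sites updated below, where false requests the conservative analysis that discards feed shapes, and true trusts the feeds as the virtual scheduler and static schedule estimators do.)

#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/errors.h"

// Hypothetical helper, not part of this change: runs the conservative analysis
// and returns the inferred output properties of a single node.
tensorflow::Status InferOutputsConservatively(
    const tensorflow::grappler::GrapplerItem& item, const std::string& node_name,
    std::vector<tensorflow::OpInfo::TensorProperties>* props) {
  tensorflow::grappler::GraphProperties properties(item);
  // false: discard the shapes of fed ports so the result stays valid no matter
  // what tensors are actually fed.
  TF_RETURN_IF_ERROR(properties.InferStatically(/*assume_valid_feeds=*/false));
  *props = properties.GetOutputProperties(node_name);
  return tensorflow::Status::OK();
}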
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index c11af5777a..ad8e768f1f 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -73,7 +73,7 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
CHECK(fake_input.NextItem(&item));
GraphProperties properties(item);
- Status s = properties.InferStatically();
+ Status s = properties.InferStatically(true);
TF_CHECK_OK(s);
for (const auto& node : item.graph.node()) {
@@ -179,7 +179,7 @@ TEST_F(GraphPropertiesTest, Variables) {
{
GraphProperties static_properties(item);
- TF_CHECK_OK(static_properties.InferStatically());
+ TF_CHECK_OK(static_properties.InferStatically(false));
const auto props = static_properties.GetOutputProperties("Var");
EXPECT_EQ(1, props.size());
@@ -219,7 +219,7 @@ TEST_F(GraphPropertiesTest, VarHandles) {
.Finalize(item.graph.add_node()));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("VarRead");
EXPECT_EQ(1, props.size());
@@ -286,7 +286,7 @@ TEST_F(GraphPropertiesTest, Queues) {
TF_CHECK_OK(root.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props1 = properties.GetOutputProperties("Dequeue1");
ASSERT_EQ(1, props1.size());
@@ -335,7 +335,7 @@ TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
"merge_without_loops.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
@@ -377,7 +377,7 @@ TEST_F(GraphPropertiesTest, WhileLoop) {
"while_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -435,7 +435,7 @@ TEST_F(GraphPropertiesTest, NestedLoop) {
"nested_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -498,7 +498,7 @@ TEST_F(GraphPropertiesTest, LoopsAndQueues) {
"loops_and_queues.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -556,7 +556,7 @@ TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
"loops_and_resource_vars.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -608,7 +608,7 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) {
"queues_and_loops.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -657,7 +657,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
item.fetch.push_back("init_restore");
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto restore_props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& restore_prop = restore_props[0];
@@ -704,7 +704,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
item.fetch.push_back("init2");
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& prop = props[0];
@@ -732,7 +732,7 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
"simple_function.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1");
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
@@ -766,7 +766,7 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
EXPECT_EQ(2, shape_a.dim_size());
@@ -822,7 +822,7 @@ TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
GraphProperties properties(item);
// This function should return OK, since it doesn't validate the colocation
// constraints internally.
- TF_EXPECT_OK(properties.InferStatically());
+ TF_EXPECT_OK(properties.InferStatically(false));
}
TEST_F(GraphPropertiesTest, ShapeTracking) {
@@ -842,7 +842,7 @@ TEST_F(GraphPropertiesTest, ShapeTracking) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
@@ -851,6 +851,56 @@ TEST_F(GraphPropertiesTest, ShapeTracking) {
EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString());
}
+TEST_F(GraphPropertiesTest, FedNodes) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+ item.feed.emplace_back("AddN", Tensor());
+
+ {
+ // Conservative shape analysis: the shape of fed ports should be unknown
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(false);
+ TF_CHECK_OK(s);
+ for (const auto& node : item.graph.node()) {
+ if (node.name() == "AddN") {
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ EXPECT_EQ(DT_FLOAT, out_props[0].dtype());
+ EXPECT_TRUE(out_props[0].shape().unknown_rank());
+ }
+ }
+ }
+ {
+ // Optimistic shape analysis: the shape of fed ports should be derived from
+ // the shape of the fanin.
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(true);
+ TF_CHECK_OK(s);
+ for (const auto& node : item.graph.node()) {
+ if (node.name() == "AddN") {
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString());
+ }
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
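(For illustration, also not part of the patch: how a feed entry such as the test's {"AddN", Tensor()} is turned into the fed_ports map that InferStatically builds. The standalone helper below simply restates the patch's parsing loop under an assumed name; "AddN" and "AddN:0" both resolve to node "AddN", output port 0.)

#include <string>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"

// Hypothetical helper restating the fed-port collection added to
// GraphProperties::InferStatically.
std::unordered_map<std::string, std::unordered_set<int>> CollectFedPorts(
    const tensorflow::grappler::GrapplerItem& item) {
  std::unordered_map<std::string, std::unordered_set<int>> fed_ports;
  for (const auto& feed : item.feed) {
    int port_index = 0;
    // "AddN" and "AddN:0" both parse to node "AddN", output port 0.
    const std::string node_name =
        tensorflow::grappler::ParseNodeName(feed.first, &port_index);
    fed_ports[node_name].insert(port_index);
  }
  return fed_ports;
}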
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index e5e1ee3292..6640de668d 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -122,7 +122,7 @@ Status VirtualScheduler::Init() {
// Construct graph properties.
Status status;
if (use_static_shapes_) {
- status = graph_properties_.InferStatically();
+ status = graph_properties_.InferStatically(true);
} else {
status = graph_properties_.InferDynamically(cluster_);
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 1e39c610a4..930d122234 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -1067,7 +1067,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
if (opt_level_ == RewriterConfig::AGGRESSIVE) {
graph_properties_.reset(new GraphProperties(item));
// Shapes are only needed in aggressive mode.
- TF_RETURN_IF_ERROR(graph_properties_->InferStatically());
+ TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
TF_RETURN_IF_ERROR(
graph_properties_->AnnotateOutputShapes(optimized_graph_));
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index c77b2badf4..33a9dddba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -1163,7 +1163,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
Status s = errors::Unknown(
"The graph properties are needed but were not initialized");
if (needs_shapes) {
- s = properties.InferStatically();
+ s = properties.InferStatically(false);
}
if (!has_feed && s.ok()) {
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index d5563e9d4c..1b8046b787 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -1620,7 +1620,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
virtual_placer_.reset(new VirtualPlacer(cluster));
nodes_to_preserve_ = item.NodesToPreserve();
GraphProperties graph_properties(item);
- auto status = graph_properties.InferStatically();
+ auto status = graph_properties.InferStatically(false);
if (!status.ok()) {
*output = item.graph;
return status;
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 7c44ce15c6..a2a2680c4f 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -716,7 +716,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
{
// Estimate the size of the data to swap for each node.
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
for (auto& swap : nodes_to_swap) {
const NodeDef* node = swap.first;
std::vector<OpInfo::TensorProperties> props =
diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc
index 6ce6deef2c..450e853407 100644
--- a/tensorflow/core/grappler/optimizers/static_schedule.cc
+++ b/tensorflow/core/grappler/optimizers/static_schedule.cc
@@ -86,7 +86,7 @@ Status EstimateEarliestExecutionTimes(
name_map.clear();
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
@@ -154,7 +154,7 @@ Status EstimateRequiredTimes(
}
}
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i
index 7dd79f7c82..8f72a425c3 100644
--- a/tensorflow/python/grappler/item.i
+++ b/tensorflow/python/grappler/item.i
@@ -120,7 +120,7 @@ static PyObject* TF_GetOpProperties(GItem item) {
Py_RETURN_NONE;
}
tensorflow::grappler::GraphProperties properties(*item);
- tensorflow::Status status = properties.InferStatically();
+ tensorflow::Status status = properties.InferStatically(false);
if (!status.ok()) {
Py_RETURN_NONE;
}
diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc
index 7d365c3be9..da5b03234e 100644
--- a/tensorflow/python/grappler/model_analyzer.cc
+++ b/tensorflow/python/grappler/model_analyzer.cc
@@ -27,7 +27,7 @@ ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {}
Status ModelAnalyzer::GenerateReport(std::ostream& os) {
GraphProperties properties(item_);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
for (const auto& node : item_.MainOpsFanin()) {
PrintNodeInfo(node, properties, os);