diff options
12 files changed, 164 insertions, 48 deletions
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index f02cb51038..f1edbbb602 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -50,6 +50,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", ], ) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index dd389de636..fb7e20fca0 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/grappler/costs/utils.h" +#include "tensorflow/core/grappler/utils.h" namespace tensorflow { namespace grappler { @@ -316,7 +317,11 @@ class SymbolicShapeRefiner { shape_inference::ShapeHandle shape) { return shape_refiner_->SetShape(node, output_port, shape); } - + Status SetUnknownShape(const Node* node, int output_port) { + shape_inference::ShapeHandle shape = + GetUnknownOutputShape(node, output_port); + return shape_refiner_->SetShape(node, output_port, shape); + } struct ShapeId { const Node* node; int port_id; @@ -646,6 +651,23 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner, return Status::OK(); } +Status GraphProperties::OverwriteFedPorts( + SymbolicShapeRefiner* shape_refiner, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* node, TopoQueue* new_shapes) const { + auto it = fed_ports.find(node->name()); + Status status; + if (it != fed_ports.end()) { + // It is possible to feed node output ports with tensors of any shape: as a + // result, the shape of a fed port is completely unknown. + for (const int output_port : it->second) { + status.Update(shape_refiner->SetUnknownShape(node, output_port)); + } + new_shapes->push(node); + } + return status; +} + // Manually propagate the input shape for Enter nodes and update any Merge node // outputs. Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, @@ -673,9 +695,10 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner, return Status::OK(); } -Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner, - bool relax, const Node* n, - TopoQueue* new_shapes) { +Status GraphProperties::UpdateShapes( + SymbolicShapeRefiner* shape_refiner, bool relax, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* n, TopoQueue* new_shapes) const { if (n->IsEnter()) { // The Enter shape function always forwards an UnknownShape, so do the right // thing here. @@ -695,7 +718,9 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner, } } } - return Status::OK(); + // Nodes can be fed with any shape. The TensorFlow shape inference code can't + // handle this properly, so overwrite its behavior here. + return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes); } // Propagates the shapes in the transitive fan-out of <new_shapes>. @@ -703,6 +728,7 @@ Status GraphProperties::PropagateShapes( SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes, const std::unordered_map<const Node*, std::unordered_set<const Node*>>& resources, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, int num_loops) const { // Limit the number of iterations to prevent infinite loops in the presence of // incorrect shape functions. The algoritm should converge in at most @@ -728,8 +754,8 @@ Status GraphProperties::PropagateShapes( for (const Edge* e : n->out_edges()) { if (!e->IsControlEdge()) { const Node* fanout = e->dst(); - TF_RETURN_IF_ERROR( - UpdateShapes(shape_refiner, relax, fanout, new_shapes)); + TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports, + fanout, new_shapes)); } } } @@ -803,7 +829,7 @@ Status GraphProperties::UpdateResource( return Status::OK(); } -Status GraphProperties::InferStatically() { +Status GraphProperties::InferStatically(bool assume_valid_feeds) { Graph graph(OpRegistry::Global()); FunctionLibraryDefinition function_library(graph.op_registry(), item_.graph.library()); @@ -820,11 +846,21 @@ Status GraphProperties::InferStatically() { Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); + std::unordered_map<string, std::unordered_set<int>> fed_ports; + if (!assume_valid_feeds) { + for (const auto& feed : item_.feed) { + int port_index = 0; + string node_name = ParseNodeName(feed.first, &port_index); + fed_ports[node_name].insert(port_index); + } + } + // List the resources and the nodes using them. Also collect the Enter and // Merge nodes. std::unordered_map<const Node*, std::unordered_set<const Node*>> resources; std::unordered_set<const Node*> enter_nodes; std::unordered_set<const Node*> merge_nodes; + std::unordered_set<const Node*> fed_nodes; int num_loops = 0; for (const Node* const node : graph.nodes()) { for (int i = 0; i < node->num_inputs(); ++i) { @@ -841,6 +877,9 @@ Status GraphProperties::InferStatically() { } else if (node->IsNextIteration()) { ++num_loops; } + if (fed_ports.find(node->name()) != fed_ports.end()) { + fed_nodes.insert(node); + } } SymbolicShapeRefiner refiner(&shape_refiner); @@ -855,15 +894,22 @@ Status GraphProperties::InferStatically() { // Force the propagation of shapes of Enter nodes manually (the Enter shape // function always forwards an UnknownShape). for (const Node* node : enter_nodes) { - TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes)); + TF_RETURN_IF_ERROR( + UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes)); } // Seed the propagation of shapes through merge nodes. for (const Node* node : merge_nodes) { - TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes)); + TF_RETURN_IF_ERROR( + UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes)); + } + // Also seed the propagation of shapes in the fanout of fed nodes. + for (const Node* node : fed_nodes) { + TF_RETURN_IF_ERROR( + OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes)); } // Propagate shapes normally. - TF_RETURN_IF_ERROR( - PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops)); + TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources, + fed_ports, num_loops)); } // Track shapes globally across the graph. @@ -874,6 +920,10 @@ Status GraphProperties::InferStatically() { if (!node_ctx) { continue; } + // Skip any information that comes from fed nodes. + if (fed_ports.find(node->name()) != fed_ports.end()) { + continue; + } for (const auto& merged_shapes : node_ctx->MergedShapes()) { if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second) .ok()) { diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index 95bc5044d0..6fc53a7f2e 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -34,12 +34,19 @@ class TopoQueue; // nodes, and potentially a set of nodes to feed. class GraphProperties { public: - // Factory method for creating a GrapplerShapes from a MetaGraphDef. - // Returns nullptr if the given meta_graph cannot be converted. explicit GraphProperties(const GrapplerItem& item) : item_(item) {} - Status InferStatically(); + // Infer the shapes through abstract interpretation. Feed information can be + // incorrect so it should be discarded to ensure correctness of the analysis. + // However, it can help infer shapes in the fanout of fed nodes (even though + // the correctness of these shapes can't be guaranteed), so in some cases + // (such as simulation or scheduling) it makes sense of keep these shapes. + Status InferStatically(bool assume_valid_feeds); + // Infer the shape by running the graph on the specified cluster and recording + // the shapes of the processed tensors. Status InferDynamically(Cluster* cluster); + // Extract the properties from a cost graph. For testing only since there is + // no way to ensure that the cost graph match the item. Status InferFromCostGraph(const CostGraphDef& cost_graph); // Stores `item_.graph` with the inferred output shapes to `output_graph_def`. @@ -65,12 +72,6 @@ class GraphProperties { OpInfo::TensorProperties*); private: - // Inputs - GrapplerItem item_; - std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_; - std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_; - const std::vector<OpInfo::TensorProperties> missing_properties_; - // Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into // <*queue_shapes_and_types>. static Status MergeEnqueueShapesAndTypes( @@ -99,17 +100,31 @@ class GraphProperties { static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner, const Node* node, bool relax, TopoQueue* new_shapes); + // Process a node that is used to feed the model. + Status OverwriteFedPorts( + SymbolicShapeRefiner* shape_refiner, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* node, TopoQueue* new_shapes) const; // Update the shapes for node 'n'. If output shapes for n have changed, // enqueue its fanout in 'new_shapes'. - static Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax, - const Node* n, TopoQueue* new_shapes); + Status UpdateShapes( + SymbolicShapeRefiner* shape_refiner, bool relax, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, + const Node* n, TopoQueue* new_shapes) const; // Propagate the shapes for the nodes enqueued in new_shapes and their // transitive fanout until a fixed point is reached. Status PropagateShapes( SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes, const std::unordered_map<const Node*, std::unordered_set<const Node*>>& resources, + const std::unordered_map<string, std::unordered_set<int>>& fed_ports, int num_loops) const; + + // Data members + GrapplerItem item_; + std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_; + std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_; + const std::vector<OpInfo::TensorProperties> missing_properties_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index c11af5777a..ad8e768f1f 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -73,7 +73,7 @@ TEST_F(GraphPropertiesTest, StaticProperties) { CHECK(fake_input.NextItem(&item)); GraphProperties properties(item); - Status s = properties.InferStatically(); + Status s = properties.InferStatically(true); TF_CHECK_OK(s); for (const auto& node : item.graph.node()) { @@ -179,7 +179,7 @@ TEST_F(GraphPropertiesTest, Variables) { { GraphProperties static_properties(item); - TF_CHECK_OK(static_properties.InferStatically()); + TF_CHECK_OK(static_properties.InferStatically(false)); const auto props = static_properties.GetOutputProperties("Var"); EXPECT_EQ(1, props.size()); @@ -219,7 +219,7 @@ TEST_F(GraphPropertiesTest, VarHandles) { .Finalize(item.graph.add_node())); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props = properties.GetOutputProperties("VarRead"); EXPECT_EQ(1, props.size()); @@ -286,7 +286,7 @@ TEST_F(GraphPropertiesTest, Queues) { TF_CHECK_OK(root.ToGraphDef(&item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props1 = properties.GetOutputProperties("Dequeue1"); ASSERT_EQ(1, props1.size()); @@ -335,7 +335,7 @@ TEST_F(GraphPropertiesTest, MergeWithoutLoops) { "merge_without_loops.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"}; std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]", @@ -377,7 +377,7 @@ TEST_F(GraphPropertiesTest, WhileLoop) { "while_loop.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -435,7 +435,7 @@ TEST_F(GraphPropertiesTest, NestedLoop) { "nested_loop.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -498,7 +498,7 @@ TEST_F(GraphPropertiesTest, LoopsAndQueues) { "loops_and_queues.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -556,7 +556,7 @@ TEST_F(GraphPropertiesTest, LoopsAndResourceVars) { "loops_and_resource_vars.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -608,7 +608,7 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) { "queues_and_loops.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1", "while/Exit_1"}; @@ -657,7 +657,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) { item.fetch.push_back("init_restore"); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto restore_props = properties.GetOutputProperties("restore"); const OpInfo::TensorProperties& restore_prop = restore_props[0]; @@ -704,7 +704,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { item.fetch.push_back("init2"); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props = properties.GetOutputProperties("restore"); const OpInfo::TensorProperties& prop = props[0]; @@ -732,7 +732,7 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) { "simple_function.pbtxt"); TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1"); const OpInfo::TensorProperties& prop = props[0]; EXPECT_EQ(DT_FLOAT, prop.dtype()); @@ -766,7 +766,7 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); const auto shape_c = properties.GetOutputProperties("c").at(0).shape(); EXPECT_EQ(2, shape_a.dim_size()); @@ -822,7 +822,7 @@ TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { GraphProperties properties(item); // This function should return OK, since it doesn't validate the colocation // constraints internally. - TF_EXPECT_OK(properties.InferStatically()); + TF_EXPECT_OK(properties.InferStatically(false)); } TEST_F(GraphPropertiesTest, ShapeTracking) { @@ -842,7 +842,7 @@ TEST_F(GraphPropertiesTest, ShapeTracking) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically()); + TF_CHECK_OK(properties.InferStatically(false)); const auto shape_a = properties.GetOutputProperties("a").at(0).shape(); const auto shape_b = properties.GetOutputProperties("b").at(0).shape(); const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape(); @@ -851,6 +851,56 @@ TEST_F(GraphPropertiesTest, ShapeTracking) { EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString()); } +TEST_F(GraphPropertiesTest, FedNodes) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + item.feed.emplace_back("AddN", Tensor()); + + { + // Conservative shape analysis: the shape of fed ports should be unknown + GraphProperties properties(item); + Status s = properties.InferStatically(false); + TF_CHECK_OK(s); + for (const auto& node : item.graph.node()) { + if (node.name() == "AddN") { + const auto in_props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, in_props.size()); + const OpInfo::TensorProperties& in_prop = in_props[0]; + EXPECT_EQ(DT_FLOAT, in_prop.dtype()); + EXPECT_FALSE(in_prop.shape().unknown_rank()); + EXPECT_EQ(2, in_prop.shape().dim_size()); + const auto out_props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, out_props.size()); + EXPECT_EQ(DT_FLOAT, in_prop.dtype()); + EXPECT_TRUE(in_prop.shape().unknown_rank()); + } + } + } + { + // Optimistic shape analysis: the shape of fed ports should be derived from + // the shape of the fanin. + GraphProperties properties(item); + Status s = properties.InferStatically(true); + TF_CHECK_OK(s); + for (const auto& node : item.graph.node()) { + if (node.name() == "AddN") { + const auto in_props = properties.GetInputProperties(node.name()); + EXPECT_EQ(1, in_props.size()); + const OpInfo::TensorProperties& in_prop = in_props[0]; + EXPECT_EQ(DT_FLOAT, in_prop.dtype()); + EXPECT_FALSE(in_prop.shape().unknown_rank()); + EXPECT_EQ(2, in_prop.shape().dim_size()); + const auto out_props = properties.GetOutputProperties(node.name()); + EXPECT_EQ(1, out_props.size()); + const OpInfo::TensorProperties& out_prop = out_props[0]; + EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString()); + } + } + } +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index e5e1ee3292..6640de668d 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -122,7 +122,7 @@ Status VirtualScheduler::Init() { // Construct graph properties. Status status; if (use_static_shapes_) { - status = graph_properties_.InferStatically(); + status = graph_properties_.InferStatically(true); } else { status = graph_properties_.InferDynamically(cluster_); } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 1e39c610a4..930d122234 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1067,7 +1067,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, if (opt_level_ == RewriterConfig::AGGRESSIVE) { graph_properties_.reset(new GraphProperties(item)); // Shapes are only needed in aggressive mode. - TF_RETURN_IF_ERROR(graph_properties_->InferStatically()); + TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false)); TF_RETURN_IF_ERROR( graph_properties_->AnnotateOutputShapes(optimized_graph_)); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index c77b2badf4..33a9dddba7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1163,7 +1163,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, Status s = errors::Unknown( "The graph properties are needed but were not initialized"); if (needs_shapes) { - s = properties.InferStatically(); + s = properties.InferStatically(false); } if (!has_feed && s.ok()) { diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index d5563e9d4c..1b8046b787 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1620,7 +1620,7 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, virtual_placer_.reset(new VirtualPlacer(cluster)); nodes_to_preserve_ = item.NodesToPreserve(); GraphProperties graph_properties(item); - auto status = graph_properties.InferStatically(); + auto status = graph_properties.InferStatically(false); if (!status.ok()) { *output = item.graph; return status; diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 7c44ce15c6..a2a2680c4f 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -716,7 +716,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, { // Estimate the size of the data to swap for each node. GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(true)); for (auto& swap : nodes_to_swap) { const NodeDef* node = swap.first; std::vector<OpInfo::TensorProperties> props = diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc index 6ce6deef2c..450e853407 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule.cc @@ -86,7 +86,7 @@ Status EstimateEarliestExecutionTimes( name_map.clear(); GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(true)); OpLevelCostEstimator estimator; VirtualPlacer placer(cluster); @@ -154,7 +154,7 @@ Status EstimateRequiredTimes( } } GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(true)); OpLevelCostEstimator estimator; VirtualPlacer placer(cluster); diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index 7dd79f7c82..8f72a425c3 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -120,7 +120,7 @@ static PyObject* TF_GetOpProperties(GItem item) { Py_RETURN_NONE; } tensorflow::grappler::GraphProperties properties(*item); - tensorflow::Status status = properties.InferStatically(); + tensorflow::Status status = properties.InferStatically(false); if (!status.ok()) { Py_RETURN_NONE; } diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index 7d365c3be9..da5b03234e 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -27,7 +27,7 @@ ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {} Status ModelAnalyzer::GenerateReport(std::ostream& os) { GraphProperties properties(item_); - TF_RETURN_IF_ERROR(properties.InferStatically()); + TF_RETURN_IF_ERROR(properties.InferStatically(false)); for (const auto& node : item_.MainOpsFanin()) { PrintNodeInfo(node, properties, os); |