diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties.cc | 50 |
1 files changed, 38 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6710ff9df3..d24e7e8ee4 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -429,18 +429,22 @@ class SymbolicShapeRefiner { // perform shape inference on the function body. // // Propagate shape information of final function body node - // to function node `node`. + // to function node `function_node`. // - // In the event of an error, UpdateNode will simply set `node`'s + // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. - Status UpdateFunction(const NodeDef* node) { - auto it = fun_to_grappler_function_item_.find(node->op()); + Status UpdateFunction(const NodeDef* function_node) { + auto it = fun_to_grappler_function_item_.find(function_node->op()); if (it == fun_to_grappler_function_item_.end()) { return errors::InvalidArgument( - node->op(), " was not previously added to SymbolicShapeRefiner."); + function_node->op(), + " was not previously added to SymbolicShapeRefiner."); } - GrapplerFunctionItem& grappler_function_item = it->second; + // Copy (not reference) so that changes we make here (e.g., replacing + // Placeholder with Const) don't affect one in + // fun_to_grappler_function_item_. + GrapplerFunctionItem grappler_function_item = it->second; GraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. @@ -453,7 +457,7 @@ class SymbolicShapeRefiner { "supported."); } NodeDef* fun_node = gv.GetNode(fun_input.input_name); - const string& input = node->input(i); + const string& input = function_node->input(i); const string& node_name = NodeName(input); if (IsControlInput(input)) { @@ -478,16 +482,35 @@ class SymbolicShapeRefiner { TensorShapeProto proto; const auto& handle = input_inference_context->output(output_port_num); input_inference_context->ShapeHandleToProto(handle, &proto); + // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1. + for (int i = 0; i < proto.dim_size(); i++) { + if (proto.dim(i).size() < -1) { + proto.mutable_dim(i)->set_size(-1); + } + } *attr_output_shape.mutable_shape() = proto; (*fun_node->mutable_attr())["shape"] = attr_output_shape; } + // Replace input Placeholders with Consts, if values are known. Note that + // we don't check exceptions here as it's done in the above loop. + for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { + const string& input = function_node->input(i); + const string& node_name = NodeName(input); + NodeDef* input_node = graph_.GetNode(node_name); + // TODO(dyoon): also use Const when output_tensors_as_shape is available. + if (IsConstant(*input_node)) { + TF_CHECK_OK( + ReplaceInputWithConst(*input_node, i, &grappler_function_item)); + } + } + // Perform inference on function body. GraphProperties gp(grappler_function_item); TF_RETURN_IF_ERROR(gp.InferStatically(true)); // Add return nodes for output shapes. - auto ic = GetContext(node); + auto ic = GetContext(function_node); int output = 0; for (auto const& out_arg : grappler_function_item.outputs()) { if (out_arg.output_tensors.size() > 1) { @@ -505,8 +528,9 @@ class SymbolicShapeRefiner { const NodeDef* retnode = gv.GetNode(node_name); if (retnode == nullptr) { - return errors::FailedPrecondition("Unable to find return node ", - node_name, " for ", node->name()); + return errors::FailedPrecondition( + "Unable to find return function_node ", node_name, " for ", + function_node->name()); } auto output_properties = gp.GetOutputProperties(retnode->name()); @@ -671,11 +695,13 @@ class SymbolicShapeRefiner { // true, as the updates to the call node will have changed, even if it's // the same function being called twice with the same input shapes. // Example: simple_function.pbtxt - if (UpdateFunction(node).ok()) { + auto s = UpdateFunction(node); + if (s.ok()) { return Status::OK(); } else { VLOG(1) << "UpdateFunction failed for " << node->op() - << ". Defaulting to ShapeUnknown."; + << ". Defaulting to ShapeUnknown.\n" + << s.ToString(); } } |