aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/graph_properties.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc50
1 files changed, 38 insertions, 12 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 6710ff9df3..d24e7e8ee4 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -429,18 +429,22 @@ class SymbolicShapeRefiner {
// perform shape inference on the function body.
//
// Propagate shape information of final function body node
- // to function node `node`.
+ // to function node `function_node`.
//
- // In the event of an error, UpdateNode will simply set `node`'s
+ // In the event of an error, UpdateNode will simply set `function_node`'s
// output shape to be Unknown.
- Status UpdateFunction(const NodeDef* node) {
- auto it = fun_to_grappler_function_item_.find(node->op());
+ Status UpdateFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it == fun_to_grappler_function_item_.end()) {
return errors::InvalidArgument(
- node->op(), " was not previously added to SymbolicShapeRefiner.");
+ function_node->op(),
+ " was not previously added to SymbolicShapeRefiner.");
}
- GrapplerFunctionItem& grappler_function_item = it->second;
+ // Copy (not reference) so that changes we make here (e.g., replacing
+ // Placeholder with Const) don't affect one in
+ // fun_to_grappler_function_item_.
+ GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
@@ -453,7 +457,7 @@ class SymbolicShapeRefiner {
"supported.");
}
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
- const string& input = node->input(i);
+ const string& input = function_node->input(i);
const string& node_name = NodeName(input);
if (IsControlInput(input)) {
@@ -478,16 +482,35 @@ class SymbolicShapeRefiner {
TensorShapeProto proto;
const auto& handle = input_inference_context->output(output_port_num);
input_inference_context->ShapeHandleToProto(handle, &proto);
+ // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
+ for (int i = 0; i < proto.dim_size(); i++) {
+ if (proto.dim(i).size() < -1) {
+ proto.mutable_dim(i)->set_size(-1);
+ }
+ }
*attr_output_shape.mutable_shape() = proto;
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
+ // Replace input Placeholders with Consts, if values are known. Note that
+ // we don't check exceptions here as it's done in the above loop.
+ for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
+ const string& input = function_node->input(i);
+ const string& node_name = NodeName(input);
+ NodeDef* input_node = graph_.GetNode(node_name);
+ // TODO(dyoon): also use Const when output_tensors_as_shape is available.
+ if (IsConstant(*input_node)) {
+ TF_CHECK_OK(
+ ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ }
+ }
+
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(node);
+ auto ic = GetContext(function_node);
int output = 0;
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
@@ -505,8 +528,9 @@ class SymbolicShapeRefiner {
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
- return errors::FailedPrecondition("Unable to find return node ",
- node_name, " for ", node->name());
+ return errors::FailedPrecondition(
+ "Unable to find return function_node ", node_name, " for ",
+ function_node->name());
}
auto output_properties = gp.GetOutputProperties(retnode->name());
@@ -671,11 +695,13 @@ class SymbolicShapeRefiner {
// true, as the updates to the call node will have changed, even if it's
// the same function being called twice with the same input shapes.
// Example: simple_function.pbtxt
- if (UpdateFunction(node).ok()) {
+ auto s = UpdateFunction(node);
+ if (s.ok()) {
return Status::OK();
} else {
VLOG(1) << "UpdateFunction failed for " << node->op()
- << ". Defaulting to ShapeUnknown.";
+ << ". Defaulting to ShapeUnknown.\n"
+ << s.ToString();
}
}