aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/graph_properties.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc171
1 files changed, 157 insertions, 14 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 0c02876ac5..231c7c63be 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -422,11 +423,106 @@ class SymbolicShapeRefiner {
return it->second.inference_context.get();
}
- // Forward the shapes from the function's fanin to the function body,
- // then call PropagateShapes.
- // Returns an error if 'node' is not a function node.
- Status UpdateFunction(const NodeDef* node, bool* refined) {
- return UpdateNode(node, refined);
+ // Forward the shapes from the function input nodes to
+ // the argument nodes (which are Placeholder nodes), then
+ // perform shape inference on the function body.
+ //
+ // Propagate shape information of final function body node
+ // to function node `node`.
+ //
+ // In the event of an error, UpdateNode will simply set `node`'s
+ // output shape to be Unknown.
+ Status UpdateFunction(const NodeDef* node) {
+ auto it = fun_to_grappler_function_item_.find(node->op());
+ if (it == fun_to_grappler_function_item_.end()) {
+ return errors::InvalidArgument(
+ node->op(), " was not previously added to SymbolicShapeRefiner.");
+ }
+
+ GrapplerFunctionItem& grappler_function_item = it->second;
+ GraphView gv(&grappler_function_item.graph);
+
+ // Forward shapes from function input nodes to argument nodes.
+ for (int i = 0; i < grappler_function_item.inputs().size(); ++i) {
+ auto& fun_input = grappler_function_item.input(i);
+ if (fun_input.placeholders.size() > 1) {
+ // TODO(jmdecker): Handle case with multiple input placeholders
+ return errors::Unimplemented(
+ "Input arguments with multiple placeholders are not yet "
+ "supported.");
+ }
+ NodeDef* fun_node = gv.GetNode(fun_input.input_name);
+ const string& input = node->input(i);
+ const string& node_name = NodeName(input);
+
+ if (IsControlInput(input)) {
+ return errors::FailedPrecondition(
+ "Function inputs should not contain control nodes.");
+ }
+
+ NodeDef* input_node = graph_.GetNode(node_name);
+ if (input_node == nullptr) {
+ return errors::FailedPrecondition(node_name,
+ " was not found in the graph.");
+ }
+
+ InferenceContext* input_inference_context = GetContext(input_node);
+ if (input_inference_context == nullptr) {
+ return errors::FailedPrecondition(
+ "Inference context has not been created for ", node_name);
+ }
+
+ int output_port_num = NodePosition(input);
+ AttrValue attr_output_shape;
+ TensorShapeProto proto;
+ const auto& handle = input_inference_context->output(output_port_num);
+ input_inference_context->ShapeHandleToProto(handle, &proto);
+ *attr_output_shape.mutable_shape() = proto;
+ (*fun_node->mutable_attr())["shape"] = attr_output_shape;
+ }
+
+ // Perform inference on function body.
+ GraphProperties gp(grappler_function_item);
+ TF_RETURN_IF_ERROR(gp.InferStatically(true));
+
+ // Add return nodes for output shapes.
+ auto ic = GetContext(node);
+ int output = 0;
+ for (auto const& out_arg : grappler_function_item.outputs()) {
+ if (out_arg.output_tensors.size() > 1) {
+ // TODO(jmdecker): Handle case of multiple output tensors
+ return errors::Unimplemented(
+ "Output arguments with multiple output tensors are not yet "
+ "supported.");
+ }
+
+ // It is guaranteed that output_tensors does not contain any control
+ // inputs, so port_id >= 0.
+ string out_tensor = out_arg.output_tensors[0];
+ int port_id;
+ string node_name = ParseNodeName(out_tensor, &port_id);
+
+ const NodeDef* retnode = gv.GetNode(node_name);
+ if (retnode == nullptr) {
+ return errors::FailedPrecondition("Unable to find return node ",
+ node_name, " for ", node->name());
+ }
+
+ auto output_properties = gp.GetOutputProperties(retnode->name());
+ if (port_id >= output_properties.size()) {
+ return errors::InvalidArgument(
+ out_tensor, " has invalid position ", port_id,
+ " (output_properties.size() = ", output_properties.size(), ").");
+ }
+ auto const& outprop = output_properties[port_id];
+ const TensorShapeProto& shape = outprop.shape();
+ ShapeHandle out;
+ TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
+ ic->set_output(output, out);
+ output++;
+ }
+
+ return Status::OK();
}
Status UpdateNode(const NodeDef* node, bool* refined) {
@@ -436,6 +532,7 @@ class SymbolicShapeRefiner {
node_context = CHECK_NOTNULL(GetNodeContext(node));
*refined = true;
}
+
// Check if the shapes of the nodes in the fan-in of this node have changed,
// and if they have, update the node input shapes.
InferenceContext* inference_context = node_context->inference_context.get();
@@ -455,7 +552,8 @@ class SymbolicShapeRefiner {
if (c == nullptr) {
return errors::FailedPrecondition(
"Input ", dst_input, " ('", input->name(), "') for '",
- node->name(), "' was not previously added to ShapeRefiner.");
+ node->name(),
+ "' was not previously added to SymbolicShapeRefiner.");
}
if (IsConstant(*input)) {
@@ -565,6 +663,21 @@ class SymbolicShapeRefiner {
node_context->inference_context->set_input_tensors_as_shapes(
input_tensors_as_shapes);
+ // Properly handle function nodes.
+ if (node_context->op_data && node_context->op_data->is_function_op) {
+ // TODO(jmdecker): Detect if the input shapes have changed for this
+ // function. Note that when we hit a function call node, refined will be
+ // true, as the updates to the call node will have changed, even if it's
+ // the same function being called twice with the same input shapes.
+ // Example: simple_function.pbtxt
+ if (UpdateFunction(node).ok()) {
+ return Status::OK();
+ } else {
+ VLOG(1) << "UpdateFunction failed for " << node->op()
+ << ". Defaulting to ShapeUnknown.";
+ }
+ }
+
// Update the shapes of the outputs.
return InferShapes(*node, node_context);
}
@@ -681,7 +794,39 @@ class SymbolicShapeRefiner {
return true;
}
- Status AddFunction(const NodeDef* node) { return Status::OK(); }
+ Status AddFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
+ if (it != fun_to_grappler_function_item_.end()) {
+ return Status::OK();
+ }
+
+ const FunctionDef* function_def =
+ CHECK_NOTNULL(function_library_.Find(function_node->op()));
+
+ GrapplerFunctionItem grappler_function_item;
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
+ *function_def, function_library_, &grappler_function_item));
+
+ if (grappler_function_item.inputs().size() > function_node->input_size()) {
+ return errors::FailedPrecondition(
+ "Function input size should be smaller than node input size.");
+ }
+
+ for (int i = grappler_function_item.inputs().size();
+ i < function_node->input_size(); ++i) {
+ const string& input = function_node->input(i);
+ if (!IsControlInput(input)) {
+ return errors::FailedPrecondition(
+ "Found regular input (", input,
+ ") instead of control nodes for node ", function_node->name());
+ }
+ }
+
+ fun_to_grappler_function_item_[function_def->signature().name()] =
+ grappler_function_item;
+
+ return Status::OK();
+ }
Status AddNode(const NodeDef* node) {
NodeContext& node_ctx = node_to_context_[node];
@@ -911,6 +1056,8 @@ class SymbolicShapeRefiner {
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
+ std::unordered_map<string, GrapplerFunctionItem>
+ fun_to_grappler_function_item_;
FunctionLibraryDefinition function_library_;
const std::unordered_map<string, std::unordered_set<int>>& fed_ports_;
};
@@ -1082,13 +1229,9 @@ Status GraphProperties::UpdateShapes(
// Set shapes and types of Queue ops, if needed.
TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
} else {
- auto c = shape_refiner->GetNodeContext(n);
- if (c && c->op_data && c->op_data->is_function_op) {
- TF_RETURN_IF_ERROR(shape_refiner->UpdateFunction(n, new_shapes));
- } else {
- // Rely on regular TF shape refinement for all the other nodes.
- TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
- }
+ // Rely on regular TF shape refinement for all the other nodes.
+ // UpdateNode calls UpdateFunction if a function node is detected.
+ TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
}
return Status::OK();
}