diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties.cc | 171 |
1 files changed, 157 insertions, 14 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 0c02876ac5..231c7c63be 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -422,11 +423,106 @@ class SymbolicShapeRefiner { return it->second.inference_context.get(); } - // Forward the shapes from the function's fanin to the function body, - // then call PropagateShapes. - // Returns an error if 'node' is not a function node. - Status UpdateFunction(const NodeDef* node, bool* refined) { - return UpdateNode(node, refined); + // Forward the shapes from the function input nodes to + // the argument nodes (which are Placeholder nodes), then + // perform shape inference on the function body. + // + // Propagate shape information of final function body node + // to function node `node`. + // + // In the event of an error, UpdateNode will simply set `node`'s + // output shape to be Unknown. + Status UpdateFunction(const NodeDef* node) { + auto it = fun_to_grappler_function_item_.find(node->op()); + if (it == fun_to_grappler_function_item_.end()) { + return errors::InvalidArgument( + node->op(), " was not previously added to SymbolicShapeRefiner."); + } + + GrapplerFunctionItem& grappler_function_item = it->second; + GraphView gv(&grappler_function_item.graph); + + // Forward shapes from function input nodes to argument nodes. + for (int i = 0; i < grappler_function_item.inputs().size(); ++i) { + auto& fun_input = grappler_function_item.input(i); + if (fun_input.placeholders.size() > 1) { + // TODO(jmdecker): Handle case with multiple input placeholders + return errors::Unimplemented( + "Input arguments with multiple placeholders are not yet " + "supported."); + } + NodeDef* fun_node = gv.GetNode(fun_input.input_name); + const string& input = node->input(i); + const string& node_name = NodeName(input); + + if (IsControlInput(input)) { + return errors::FailedPrecondition( + "Function inputs should not contain control nodes."); + } + + NodeDef* input_node = graph_.GetNode(node_name); + if (input_node == nullptr) { + return errors::FailedPrecondition(node_name, + " was not found in the graph."); + } + + InferenceContext* input_inference_context = GetContext(input_node); + if (input_inference_context == nullptr) { + return errors::FailedPrecondition( + "Inference context has not been created for ", node_name); + } + + int output_port_num = NodePosition(input); + AttrValue attr_output_shape; + TensorShapeProto proto; + const auto& handle = input_inference_context->output(output_port_num); + input_inference_context->ShapeHandleToProto(handle, &proto); + *attr_output_shape.mutable_shape() = proto; + (*fun_node->mutable_attr())["shape"] = attr_output_shape; + } + + // Perform inference on function body. + GraphProperties gp(grappler_function_item); + TF_RETURN_IF_ERROR(gp.InferStatically(true)); + + // Add return nodes for output shapes. + auto ic = GetContext(node); + int output = 0; + for (auto const& out_arg : grappler_function_item.outputs()) { + if (out_arg.output_tensors.size() > 1) { + // TODO(jmdecker): Handle case of multiple output tensors + return errors::Unimplemented( + "Output arguments with multiple output tensors are not yet " + "supported."); + } + + // It is guaranteed that output_tensors does not contain any control + // inputs, so port_id >= 0. + string out_tensor = out_arg.output_tensors[0]; + int port_id; + string node_name = ParseNodeName(out_tensor, &port_id); + + const NodeDef* retnode = gv.GetNode(node_name); + if (retnode == nullptr) { + return errors::FailedPrecondition("Unable to find return node ", + node_name, " for ", node->name()); + } + + auto output_properties = gp.GetOutputProperties(retnode->name()); + if (port_id >= output_properties.size()) { + return errors::InvalidArgument( + out_tensor, " has invalid position ", port_id, + " (output_properties.size() = ", output_properties.size(), ")."); + } + auto const& outprop = output_properties[port_id]; + const TensorShapeProto& shape = outprop.shape(); + ShapeHandle out; + TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out)); + ic->set_output(output, out); + output++; + } + + return Status::OK(); } Status UpdateNode(const NodeDef* node, bool* refined) { @@ -436,6 +532,7 @@ class SymbolicShapeRefiner { node_context = CHECK_NOTNULL(GetNodeContext(node)); *refined = true; } + // Check if the shapes of the nodes in the fan-in of this node have changed, // and if they have, update the node input shapes. InferenceContext* inference_context = node_context->inference_context.get(); @@ -455,7 +552,8 @@ class SymbolicShapeRefiner { if (c == nullptr) { return errors::FailedPrecondition( "Input ", dst_input, " ('", input->name(), "') for '", - node->name(), "' was not previously added to ShapeRefiner."); + node->name(), + "' was not previously added to SymbolicShapeRefiner."); } if (IsConstant(*input)) { @@ -565,6 +663,21 @@ class SymbolicShapeRefiner { node_context->inference_context->set_input_tensors_as_shapes( input_tensors_as_shapes); + // Properly handle function nodes. + if (node_context->op_data && node_context->op_data->is_function_op) { + // TODO(jmdecker): Detect if the input shapes have changed for this + // function. Note that when we hit a function call node, refined will be + // true, as the updates to the call node will have changed, even if it's + // the same function being called twice with the same input shapes. + // Example: simple_function.pbtxt + if (UpdateFunction(node).ok()) { + return Status::OK(); + } else { + VLOG(1) << "UpdateFunction failed for " << node->op() + << ". Defaulting to ShapeUnknown."; + } + } + // Update the shapes of the outputs. return InferShapes(*node, node_context); } @@ -681,7 +794,39 @@ class SymbolicShapeRefiner { return true; } - Status AddFunction(const NodeDef* node) { return Status::OK(); } + Status AddFunction(const NodeDef* function_node) { + auto it = fun_to_grappler_function_item_.find(function_node->op()); + if (it != fun_to_grappler_function_item_.end()) { + return Status::OK(); + } + + const FunctionDef* function_def = + CHECK_NOTNULL(function_library_.Find(function_node->op())); + + GrapplerFunctionItem grappler_function_item; + TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( + *function_def, function_library_, &grappler_function_item)); + + if (grappler_function_item.inputs().size() > function_node->input_size()) { + return errors::FailedPrecondition( + "Function input size should be smaller than node input size."); + } + + for (int i = grappler_function_item.inputs().size(); + i < function_node->input_size(); ++i) { + const string& input = function_node->input(i); + if (!IsControlInput(input)) { + return errors::FailedPrecondition( + "Found regular input (", input, + ") instead of control nodes for node ", function_node->name()); + } + } + + fun_to_grappler_function_item_[function_def->signature().name()] = + grappler_function_item; + + return Status::OK(); + } Status AddNode(const NodeDef* node) { NodeContext& node_ctx = node_to_context_[node]; @@ -911,6 +1056,8 @@ class SymbolicShapeRefiner { std::unordered_map<const NodeDef*, NodeContext> node_to_context_; std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_; std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_; + std::unordered_map<string, GrapplerFunctionItem> + fun_to_grappler_function_item_; FunctionLibraryDefinition function_library_; const std::unordered_map<string, std::unordered_set<int>>& fed_ports_; }; @@ -1082,13 +1229,9 @@ Status GraphProperties::UpdateShapes( // Set shapes and types of Queue ops, if needed. TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes)); } else { - auto c = shape_refiner->GetNodeContext(n); - if (c && c->op_data && c->op_data->is_function_op) { - TF_RETURN_IF_ERROR(shape_refiner->UpdateFunction(n, new_shapes)); - } else { - // Rely on regular TF shape refinement for all the other nodes. - TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); - } + // Rely on regular TF shape refinement for all the other nodes. + // UpdateNode calls UpdateFunction if a function node is detected. + TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes)); } return Status::OK(); } |