diff options
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r-- | tensorflow/core/grappler/costs/graph_properties.cc | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 06e91af2c2..31c1043ae6 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/grappler/costs/graph_properties.h" +#include <queue> +#include <unordered_map> +#include <unordered_set> #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -31,6 +34,76 @@ Status GraphProperties::InferStatically() { Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); + // List the resources and the nodes using them + std::unordered_map<const Node*, std::unordered_set<const Node*>> resources; + for (const Node* const node : graph.nodes()) { + for (int i = 0; i < node->num_inputs(); ++i) { + if (node->input_type(i) == DataType::DT_RESOURCE) { + const Node* resource; + TF_CHECK_OK(node->input_node(i, &resource)); + resources[resource].insert(node); + } + } + } + + // If we found a resource, try to propagate the shapes through it. + bool done = true; + do { + std::queue<const Node*> new_shapes; + for (const auto& resource_data : resources) { + const Node* qnode = resource_data.first; + StringPiece type(qnode->type_string()); + if (!type.ends_with("QueueV2")) { + continue; + } + auto qctx = shape_refiner.GetContext(qnode); + if (!qctx) { + continue; + } + DataType queue_type = qctx->output_handle_dtype(0); + shape_inference::ShapeHandle queue_shp = qctx->output_handle_shape(0); + if (qctx->FullyDefined(queue_shp) && queue_type != DT_INVALID) { + continue; + } + + for (const auto& node : resource_data.second) { + auto ctx = shape_refiner.GetContext(node); + if (!ctx) { + continue; + } + if (node->type_string().find("Enqueue") != std::string::npos) { + if (ctx->num_inputs() == 2) { + const DataType dtype = node->input_type(1); + if (queue_type == DT_INVALID) { + queue_type = dtype; + } else { + CHECK_EQ(queue_type, dtype); + } + shape_inference::ShapeHandle shp = ctx->input(1); + TF_RETURN_IF_ERROR(qctx->Merge(queue_shp, shp, &queue_shp)); + } + } + } + if (qctx->set_output_handle_dtype(0, queue_type) || + qctx->set_output_handle_shape(0, queue_shp)) { + new_shapes.push(qnode); + } + } + // Propagate the shapes in the transitive fan-out of the queue. + done = new_shapes.empty(); + while (!new_shapes.empty()) { + const Node* n = new_shapes.front(); + new_shapes.pop(); + for (const Node* fanout : n->out_nodes()) { + bool updated = false; + TF_RETURN_IF_ERROR(shape_refiner.UpdateNode(fanout, &updated)); + if (updated) { + new_shapes.push(fanout); + } + } + } + } while (!done); + for (const Node* const node : graph.nodes()) { VLOG(1) << "<Node> " << node->name(); auto ctx = shape_refiner.GetContext(node); |