aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/costs/graph_properties.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/costs/graph_properties.cc')
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc73
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 06e91af2c2..31c1043ae6 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -31,6 +34,76 @@ Status GraphProperties::InferStatically() {
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
+ // List the resources and the nodes using them
+ std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
+ for (const Node* const node : graph.nodes()) {
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ if (node->input_type(i) == DataType::DT_RESOURCE) {
+ const Node* resource;
+ TF_CHECK_OK(node->input_node(i, &resource));
+ resources[resource].insert(node);
+ }
+ }
+ }
+
+ // If we found a resource, try to propagate the shapes through it.
+ bool done = true;
+ do {
+ std::queue<const Node*> new_shapes;
+ for (const auto& resource_data : resources) {
+ const Node* qnode = resource_data.first;
+ StringPiece type(qnode->type_string());
+ if (!type.ends_with("QueueV2")) {
+ continue;
+ }
+ auto qctx = shape_refiner.GetContext(qnode);
+ if (!qctx) {
+ continue;
+ }
+ DataType queue_type = qctx->output_handle_dtype(0);
+ shape_inference::ShapeHandle queue_shp = qctx->output_handle_shape(0);
+ if (qctx->FullyDefined(queue_shp) && queue_type != DT_INVALID) {
+ continue;
+ }
+
+ for (const auto& node : resource_data.second) {
+ auto ctx = shape_refiner.GetContext(node);
+ if (!ctx) {
+ continue;
+ }
+ if (node->type_string().find("Enqueue") != std::string::npos) {
+ if (ctx->num_inputs() == 2) {
+ const DataType dtype = node->input_type(1);
+ if (queue_type == DT_INVALID) {
+ queue_type = dtype;
+ } else {
+ CHECK_EQ(queue_type, dtype);
+ }
+ shape_inference::ShapeHandle shp = ctx->input(1);
+ TF_RETURN_IF_ERROR(qctx->Merge(queue_shp, shp, &queue_shp));
+ }
+ }
+ }
+ if (qctx->set_output_handle_dtype(0, queue_type) ||
+ qctx->set_output_handle_shape(0, queue_shp)) {
+ new_shapes.push(qnode);
+ }
+ }
+ // Propagate the shapes in the transitive fan-out of the queue.
+ done = new_shapes.empty();
+ while (!new_shapes.empty()) {
+ const Node* n = new_shapes.front();
+ new_shapes.pop();
+ for (const Node* fanout : n->out_nodes()) {
+ bool updated = false;
+ TF_RETURN_IF_ERROR(shape_refiner.UpdateNode(fanout, &updated));
+ if (updated) {
+ new_shapes.push(fanout);
+ }
+ }
+ }
+ } while (!done);
+
for (const Node* const node : graph.nodes()) {
VLOG(1) << "<Node> " << node->name();
auto ctx = shape_refiner.GetContext(node);