aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-14 20:39:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-14 20:44:02 -0700
commit9037e241de1e64044ff55ab539ccc1fb013c178a (patch)
treef7b8bda19a5efdd57f99ce9cd7b0bf6fed211628 /tensorflow/core/grappler/utils.cc
parent357cd4b8b2f960520fc57b6cfbf41117a2a20fc7 (diff)
Enable Add/AddN tree rewrite for symbolically equal shapes.
1) Rewrite a tree of Add/AddN ops with a single AddN, if all shapes are symbolically equal 2) Lookup shape properties using GraphProperties instead of direct access to Node attributes PiperOrigin-RevId: 189131726
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r--tensorflow/core/grappler/utils.cc26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index eb1f882ff1..829bfe9e31 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) {
tensor->flat<T>()(0) = static_cast<T>(value);
return true;
}
+
+// Is 'node' an operator that consumes only the shape of its input, not the
+// data itself?
+// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency.
+// TODO(ezhulenev): what about Identity passing tensor to Shape consumer?
+bool IsShapeConsumer(const NodeDef& node) {
+ const string& op = node.op();
+ return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size";
+}
+
} // namespace
NodeMap::NodeMap(GraphDef* graph) {
@@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) {
return num_outputs;
}
+int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) {
+ int num_data_outputs = 0;
+ for (const NodeDef* output : node_map.GetOutputs(node.name())) {
+ if (IsShapeConsumer(*output)) continue;
+
+ for (int i = 0; i < output->input_size(); ++i) {
+ const string& input = output->input(i);
+ if (!IsControlInput(input) && NodeName(input) == node.name()) {
+ ++num_data_outputs;
+ break;
+ }
+ }
+ }
+ return num_data_outputs;
+}
+
// Returns the data type in attribute `attr_name` of `node`. If that attribute
// doesn't exist, returns DT_INVALID.
DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) {