diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-14 20:39:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-14 20:44:02 -0700 |
commit | 9037e241de1e64044ff55ab539ccc1fb013c178a (patch) | |
tree | f7b8bda19a5efdd57f99ce9cd7b0bf6fed211628 /tensorflow/core/grappler/utils.cc | |
parent | 357cd4b8b2f960520fc57b6cfbf41117a2a20fc7 (diff) |
Enable Add/AddN tree rewrite for symbolically equal shapes.
1) Rewrite a tree of Add/AddN ops with a single AddN,
if all shapes are symbolically equal
2) Lookup shape properties using GraphProperties instead
of direct access to Node attributes
PiperOrigin-RevId: 189131726
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index eb1f882ff1..829bfe9e31 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -40,6 +40,16 @@ bool SafeSetScalarTensorValue(double value, Tensor* tensor) { tensor->flat<T>()(0) = static_cast<T>(value); return true; } + +// Is 'node' an operator that consumes only the shape of its input, not the +// data itself? +// TODO(ezhulenev): move to op_types.h. Requires to break circular dependency. +// TODO(ezhulenev): what about Identity passing tensor to Shape consumer? +bool IsShapeConsumer(const NodeDef& node) { + const string& op = node.op(); + return op == "Shape" || op == "ShapeN" || op == "Rank" || op == "Size"; +} + } // namespace NodeMap::NodeMap(GraphDef* graph) { @@ -270,6 +280,22 @@ int NumNonControlOutputs(const NodeDef& node, const NodeMap& node_map) { return num_outputs; } +int NumNonControlDataOutputs(const NodeDef& node, const NodeMap& node_map) { + int num_data_outputs = 0; + for (const NodeDef* output : node_map.GetOutputs(node.name())) { + if (IsShapeConsumer(*output)) continue; + + for (int i = 0; i < output->input_size(); ++i) { + const string& input = output->input(i); + if (!IsControlInput(input) && NodeName(input) == node.name()) { + ++num_data_outputs; + break; + } + } + } + return num_data_outputs; +} + // Returns the data type in attribute `attr_name` of `node`. If that attribute // doesn't exist, returns DT_INVALID. DataType GetDataTypeFromAttr(const NodeDef& node, const string& attr_name) { |