diff options
author | 2017-09-21 14:52:28 -0700 | |
---|---|---|
committer | 2017-09-21 14:55:13 -0700 | |
commit | 57498a86c11dfc98dda84dc7318a3c84c85c6791 (patch) | |
tree | af27c5e5c9a264ca86ad25faeb32948c89c618df /tensorflow/core/grappler/utils.cc | |
parent | 847aa2fec14e7cdde140a3e5fdb0c3229caf9426 (diff) |
Fold fetch nodes.
PiperOrigin-RevId: 169604180
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r-- | tensorflow/core/grappler/utils.cc | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index c8830e9b3c..63145b4e07 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -15,6 +15,9 @@ limitations under the License. #include <memory> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" @@ -220,5 +223,24 @@ string AsControlDependency(const string& node) { return strings::StrCat("^", node); } +int NumOutputs(const NodeDef& node) { + int num_outputs = 0; + const OpDef* op_def = nullptr; + auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + if (status.ok()) { + for (const auto& output : op_def->output_arg()) { + if (!output.type_list_attr().empty()) { + num_outputs += + node.attr().at(output.type_list_attr()).list().type_size(); + } else if (!output.number_attr().empty()) { + num_outputs += node.attr().at(output.number_attr()).i(); + } else { + num_outputs++; + } + } + } + return num_outputs; +} + } // end namespace grappler } // end namespace tensorflow |