aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/utils.cc
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2017-09-21 14:52:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-21 14:55:13 -0700
commit57498a86c11dfc98dda84dc7318a3c84c85c6791 (patch)
treeaf27c5e5c9a264ca86ad25faeb32948c89c618df /tensorflow/core/grappler/utils.cc
parent847aa2fec14e7cdde140a3e5fdb0c3229caf9426 (diff)
Fold fetch nodes.
PiperOrigin-RevId: 169604180
Diffstat (limited to 'tensorflow/core/grappler/utils.cc')
-rw-r--r--tensorflow/core/grappler/utils.cc22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index c8830e9b3c..63145b4e07 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include <memory>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
@@ -220,5 +223,24 @@ string AsControlDependency(const string& node) {
return strings::StrCat("^", node);
}
+int NumOutputs(const NodeDef& node) {
+ int num_outputs = 0;
+ const OpDef* op_def = nullptr;
+ auto status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
+ if (status.ok()) {
+ for (const auto& output : op_def->output_arg()) {
+ if (!output.type_list_attr().empty()) {
+ num_outputs +=
+ node.attr().at(output.type_list_attr()).list().type_size();
+ } else if (!output.number_attr().empty()) {
+ num_outputs += node.attr().at(output.number_attr()).i();
+ } else {
+ num_outputs++;
+ }
+ }
+ }
+ return num_outputs;
+}
+
} // end namespace grappler
} // end namespace tensorflow