aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler/graph_view.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/grappler/graph_view.h')
-rw-r--r--tensorflow/core/grappler/graph_view.h10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..ec946ca3b5 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public: