aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot/tests/make_test_graphs.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/aot/tests/make_test_graphs.py')
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df163b..de135d7a23 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -142,6 +143,12 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@ def main(_):
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':