diff options
Diffstat (limited to 'tensorflow/compiler/aot/tests/make_test_graphs.py')
-rw-r--r-- | tensorflow/compiler/aot/tests/make_test_graphs.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 9ec7df163b..de135d7a23 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app from tensorflow.python.training import saver as saver_lib @@ -142,6 +143,12 @@ def tfsplits(_): array_ops.identity(y, name='result') +def tftop_k(_): + x = array_ops.placeholder(dtypes.int32, shape=[5], name='x') + output = nn_ops.top_k(x, 2, name='values') + array_ops.identity(output[1], name='indices') + + def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() @@ -163,6 +170,7 @@ def main(_): write_graph(tfmatmul, FLAGS.out_dir) write_graph(tfmatmulandadd, FLAGS.out_dir) write_graph(tfsplits, FLAGS.out_dir) + write_graph(tftop_k, FLAGS.out_dir) if __name__ == '__main__': |