aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot
diff options
context:
space:
mode:
authorGravatar Adrian Kuegel <akuegel@google.com>2018-09-20 00:41:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 00:45:35 -0700
commit31c0857f6b5d79f4a7b16ee4af85f0bde8b5f5da (patch)
treea7cb98f6db58c513fd8c689600053745131303ef /tensorflow/compiler/aot
parent2ea398b12ed18b6c51e09f363021c6aa306c5179 (diff)
Add AOT test case for XlaSort.
The only tensorflow op that uses XlaSort is nn.top_k, so we add a test case using nn.top_k. PiperOrigin-RevId: 213763591
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--tensorflow/compiler/aot/tests/BUILD15
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py8
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt13
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc25
-rw-r--r--tensorflow/compiler/aot/tfcompile.bzl1
5 files changed, 62 insertions, 0 deletions
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 7a0932d44d..10fa33ab5e 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -25,6 +25,7 @@ test_suite(
":test_graph_tfmatmul_test",
":test_graph_tfmatmulandadd_test",
":test_graph_tfsplits_test",
+ ":test_graph_tftop_k_test",
":tfcompile_test",
],
)
@@ -42,6 +43,7 @@ py_binary(
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
"//tensorflow/python:platform",
"//tensorflow/python:session",
"//tensorflow/python:training",
@@ -66,6 +68,7 @@ genrule(
"test_graph_tfmatmul.pb",
"test_graph_tfmatmulandadd.pb",
"test_graph_tfsplits.pb",
+ "test_graph_tftop_k.pb",
],
# Set CUDA_VISIBLE_DEVICES='' to prevent the code we launch from using any
# GPUs which might be present. This is important because builds may run
@@ -208,6 +211,17 @@ tf_library(
],
)
+tf_library(
+ name = "test_graph_tftop_k",
+ testonly = 1,
+ config = "test_graph_tftop_k.config.pbtxt",
+ cpp_class = "TopKComp",
+ graph = "test_graph_tftop_k.pb",
+ tags = [
+ "manual",
+ ],
+)
+
tf_cc_test(
name = "tfcompile_test",
srcs = ["tfcompile_test.cc"],
@@ -226,6 +240,7 @@ tf_cc_test(
":test_graph_tfmatmulandadd",
":test_graph_tfmatmulandadd_with_profiling",
":test_graph_tfsplits",
+ ":test_graph_tftop_k",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 9ec7df163b..de135d7a23 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -31,6 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.training import saver as saver_lib
@@ -142,6 +143,12 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tftop_k(_):
+ x = array_ops.placeholder(dtypes.int32, shape=[5], name='x')
+ output = nn_ops.top_k(x, 2, name='values')
+ array_ops.identity(output[1], name='indices')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -163,6 +170,7 @@ def main(_):
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tftop_k, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
new file mode 100644
index 0000000000..6b4ac2d7cb
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tftop_k.config.pbtxt
@@ -0,0 +1,13 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x" }
+ shape {
+ dim { size: 5 }
+ }
+}
+fetch {
+ id { node_name: "values" }
+}
+fetch {
+ id { node_name: "indices" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 7ac90fb8a9..f10852c785 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h"
#include "tensorflow/compiler/xla/service/hlo_profile_printer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -448,6 +449,30 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, TopK) {
+ Eigen::ThreadPool tp(1);
+ Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
+
+ TopKComp fn;
+
+ fn.set_thread_pool(&device);
+ // x = [4, 1, 4, 4, 3]
+ fn.arg0(0) = 4;
+ fn.arg0(1) = 1;
+ fn.arg0(2) = 4;
+ fn.arg0(3) = 4;
+ fn.arg0(4) = 3;
+
+ EXPECT_TRUE(fn.Run());
+ EXPECT_EQ(fn.error_msg(), "");
+ const int32 expected_values[] = {4, 4};
+ const int32 expected_indices[] = {0, 2};
+ EXPECT_EQ(expected_values[0], fn.result0(0));
+ EXPECT_EQ(expected_values[1], fn.result0(1));
+ EXPECT_EQ(expected_indices[0], fn.result1(0));
+ EXPECT_EQ(expected_indices[1], fn.result1(1));
+}
+
TEST(TFCompileTest, AssertEqAndReturnDiff) {
// Assert is converted into a no-op in XLA, so there is no failure even if the
// two args are different.
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl
index 792b7fe14a..859c84bb91 100644
--- a/tensorflow/compiler/aot/tfcompile.bzl
+++ b/tensorflow/compiler/aot/tfcompile.bzl
@@ -273,6 +273,7 @@ def tf_library(
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
+ "//tensorflow/compiler/xla/service/cpu:runtime_key_value_sort",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul",