aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot
diff options
context:
space:
mode:
authorGravatar Jacques Pienaar <jpienaar@google.com>2018-05-07 16:59:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 18:00:16 -0700
commitdb63348bf14d911f2eebeb418a0b570b65b64f92 (patch)
tree4746004e23a92518a07c8de57acffd918025e0d6 /tensorflow/compiler/aot
parent3964bdeef88cb9f7824bbfc8ca4f44c7a4bd4dbd (diff)
Add test with tf.cond.
PiperOrigin-RevId: 195745718
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--tensorflow/compiler/aot/tests/BUILD14
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py29
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt20
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc26
4 files changed, 79 insertions, 10 deletions
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index 222e26810a..fd2cf2b67d 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -15,6 +15,7 @@ test_suite(
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
":test_graph_tfassert_eq_test",
+ ":test_graph_tfcond_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@@ -55,6 +56,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
"test_graph_tfassert_eq.pb",
+ "test_graph_tfcond.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@@ -119,6 +121,17 @@ tf_library(
)
tf_library(
+ name = "test_graph_tfcond",
+ testonly = 1,
+ config = "test_graph_tfcond.config.pbtxt",
+ cpp_class = "CondComp",
+ graph = "test_graph_tfcond.pb",
+ tags = [
+ "manual",
+ ],
+)
+
+tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
@@ -194,6 +207,7 @@ tf_cc_test(
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
":test_graph_tfassert_eq",
+ ":test_graph_tfcond",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 67767f55da..9ec7df163b 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -78,6 +78,22 @@ def tfadd_with_ckpt_saver(out_dir):
f.write(saver.as_saver_def().SerializeToString())
+def tfassert_eq(_):
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ control_flow_ops.Assert(
+ math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+ math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
+def tfcond(_):
+ p = array_ops.placeholder(dtypes.bool, name='p_hold')
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ z = control_flow_ops.cond(p, lambda: x, lambda: y)
+ array_ops.identity(z, name='result')
+
+
def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
@@ -126,14 +142,6 @@ def tfsplits(_):
array_ops.identity(y, name='result')
-def tfassert_eq(_):
- x = array_ops.placeholder(dtypes.int32, name='x_hold')
- y = array_ops.placeholder(dtypes.int32, name='y_hold')
- control_flow_ops.Assert(
- math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
- math_ops.add(x, math_ops.negative(y), name='x_y_diff')
-
-
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -148,12 +156,13 @@ def main(_):
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
+ write_graph(tfassert_eq, FLAGS.out_dir)
+ write_graph(tfcond, FLAGS.out_dir)
+ write_graph(tffunction, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
- write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
- write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
new file mode 100644
index 0000000000..94a01ad4ab
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfcond.config.pbtxt
@@ -0,0 +1,20 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "p_hold" }
+ shape {}
+}
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "result" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 27ba42b31f..309a991fc1 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfcond.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -150,6 +151,31 @@ TEST(TFCompileTest, AddWithCkptSaver) {
EXPECT_EQ(add_const.result0_data(), add_const.results()[0]);
}
+TEST(TFCompileTest, Cond) {
+ CondComp cond;
+ EXPECT_EQ(cond.arg0_data(), cond.args()[0]);
+ EXPECT_EQ(cond.arg1_data(), cond.args()[1]);
+ EXPECT_EQ(cond.arg2_data(), cond.args()[2]);
+ cond.arg1() = 10;
+ cond.arg2() = 20;
+ {
+ cond.arg0() = true;
+ const int32 expected_result = cond.arg1();
+ EXPECT_TRUE(cond.Run());
+ EXPECT_EQ(cond.result0(), expected_result);
+ EXPECT_EQ(cond.result0_data()[0], expected_result);
+ EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+ }
+ {
+ cond.arg0() = false;
+ const int32 expected_result = cond.arg2();
+ EXPECT_TRUE(cond.Run());
+ EXPECT_EQ(cond.result0(), expected_result);
+ EXPECT_EQ(cond.result0_data()[0], expected_result);
+ EXPECT_EQ(cond.result0_data(), cond.results()[0]);
+ }
+}
+
TEST(TFCompileTest, Gather) {
GatherComp gather;
EXPECT_EQ(gather.arg0_data(), gather.args()[0]);