aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/aot
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-12 17:07:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-12 17:15:33 -0700
commitfffd3ca4fcf1f54f97a7be6f225fe183ad82b0ea (patch)
treec4a0d0de21cd133c3586140e92280aa5d0b5d50c /tensorflow/compiler/aot
parent3755128f3a83fea84c5a90d71d5b684157a99ac7 (diff)
Move dummy AssertOp and CheckNumericsOp to //third_party/tensorflow/compiler/tf2xla/kernels.
Enable type DT_STRING for AssertOp and ConstOp, in order to make dummy Assert compile with a const string (assert message) as its input. PiperOrigin-RevId: 192695938
Diffstat (limited to 'tensorflow/compiler/aot')
-rw-r--r--tensorflow/compiler/aot/BUILD1
-rw-r--r--tensorflow/compiler/aot/tests/BUILD15
-rw-r--r--tensorflow/compiler/aot/tests/make_test_graphs.py10
-rw-r--r--tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt16
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc18
5 files changed, 60 insertions, 0 deletions
diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD
index fa03b1f3c2..19e6bf68e7 100644
--- a/tensorflow/compiler/aot/BUILD
+++ b/tensorflow/compiler/aot/BUILD
@@ -60,6 +60,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:tf2xla_util",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops",
+ "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD
index b053dad1b5..bb73cb19c5 100644
--- a/tensorflow/compiler/aot/tests/BUILD
+++ b/tensorflow/compiler/aot/tests/BUILD
@@ -14,6 +14,7 @@ test_suite(
":test_graph_tfadd_test",
":test_graph_tfadd_with_ckpt_saver_test",
":test_graph_tfadd_with_ckpt_test",
+ ":test_graph_tfassert_eq_test",
":test_graph_tffunction_test",
":test_graph_tfgather_test",
":test_graph_tfmatmul_test",
@@ -33,6 +34,7 @@ py_binary(
"//tensorflow/python", # TODO(b/34059704): remove when fixed
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
+ "//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform",
@@ -52,6 +54,7 @@ genrule(
"test_graph_tfadd_with_ckpt_saver.ckpt",
"test_graph_tfadd_with_ckpt_saver.pb",
"test_graph_tfadd_with_ckpt_saver.saver",
+ "test_graph_tfassert_eq.pb",
"test_graph_tffunction.pb",
"test_graph_tfgather.pb",
"test_graph_tfmatmul.pb",
@@ -105,6 +108,17 @@ tf_library(
)
tf_library(
+ name = "test_graph_tfassert_eq",
+ testonly = 1,
+ config = "test_graph_tfassert_eq.config.pbtxt",
+ cpp_class = "AssertComp",
+ graph = "test_graph_tfassert_eq.pb",
+ tags = [
+ "manual",
+ ],
+)
+
+tf_library(
name = "test_graph_tffunction",
testonly = 1,
config = "test_graph_tffunction.config.pbtxt",
@@ -170,6 +184,7 @@ tf_cc_test(
":test_graph_tfadd",
":test_graph_tfadd_with_ckpt",
":test_graph_tfadd_with_ckpt_saver",
+ ":test_graph_tfassert_eq",
":test_graph_tffunction",
":test_graph_tfgather",
":test_graph_tfmatmul",
diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py
index 89c7cd4507..67767f55da 100644
--- a/tensorflow/compiler/aot/tests/make_test_graphs.py
+++ b/tensorflow/compiler/aot/tests/make_test_graphs.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
@@ -125,6 +126,14 @@ def tfsplits(_):
array_ops.identity(y, name='result')
+def tfassert_eq(_):
+ x = array_ops.placeholder(dtypes.int32, name='x_hold')
+ y = array_ops.placeholder(dtypes.int32, name='y_hold')
+ control_flow_ops.Assert(
+ math_ops.equal(x, y), ['Expected x == y.'], name='assert_eq')
+ math_ops.add(x, math_ops.negative(y), name='x_y_diff')
+
+
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
@@ -144,6 +153,7 @@ def main(_):
write_graph(tfmatmulandadd, FLAGS.out_dir)
write_graph(tffunction, FLAGS.out_dir)
write_graph(tfsplits, FLAGS.out_dir)
+ write_graph(tfassert_eq, FLAGS.out_dir)
if __name__ == '__main__':
diff --git a/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt
new file mode 100644
index 0000000000..8732d1709e
--- /dev/null
+++ b/tensorflow/compiler/aot/tests/test_graph_tfassert_eq.config.pbtxt
@@ -0,0 +1,16 @@
+# Text form of tensorflow.tf2xla.Config proto.
+feed {
+ id { node_name: "x_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+feed {
+ id { node_name: "y_hold" }
+ shape {
+ dim { size: 1 }
+ }
+}
+fetch {
+ id { node_name: "x_y_diff" }
+}
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 413efd9cea..67dbd643bf 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/aot/tests/test_graph_tfadd.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt_saver.h"
+#include "tensorflow/compiler/aot/tests/test_graph_tfassert_eq.h"
#include "tensorflow/compiler/aot/tests/test_graph_tffunction.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h"
#include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h"
@@ -413,6 +414,23 @@ TEST(TFCompileTest, Splits) {
EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4);
}
+TEST(TFCompileTest, AssertEqAndReturnDiff) {
+ // Assert is converted into a no-op in XLA, so there is no failure even if the
+ // two args are different.
+ AssertComp assert;
+ EXPECT_EQ(assert.arg0_data(), assert.args()[0]);
+ EXPECT_EQ(assert.arg1_data(), assert.args()[1]);
+
+ assert.arg0() = 2;
+ assert.arg1() = 1;
+ const int32 expected_result = assert.arg0() - assert.arg1();
+ EXPECT_TRUE(assert.Run());
+ EXPECT_EQ(assert.error_msg(), "");
+ EXPECT_EQ(assert.result0(), expected_result);
+ EXPECT_EQ(assert.result0_data()[0], expected_result);
+ EXPECT_EQ(assert.result0_data(), assert.results()[0]);
+}
+
TEST(TFCompileTest, LookupNameIndex) {
// add doesn't have any names defined in its config.
AddComp add;