aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/tests/BUILD12
-rw-r--r--tensorflow/compiler/tests/placeholder_test.py48
-rw-r--r--tensorflow/compiler/tf2xla/kernels/identity_op.cc1
3 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index ac2441cea0..0c72093256 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -923,3 +923,15 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
+
+tf_xla_py_test(
+ name = "placeholder_test",
+ size = "small",
+ srcs = ["placeholder_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
new file mode 100644
index 0000000000..5e6d1313bd
--- /dev/null
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for xla handling of placeholder_with_default."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+class PlaceholderTest(XLATestCase):
+
+ def test_placeholder_with_default_default(self):
+ with self.test_session() as sess, self.test_scope():
+ v = resource_variable_ops.ResourceVariable(4.0)
+ ph = array_ops.placeholder_with_default(v, shape=[])
+ out = ph * 2
+ sess.run(variables.variables_initializer([v]))
+ self.assertEqual(8.0, sess.run(out))
+
+ def test_placeholder_with_default_fed(self):
+ with self.test_session() as sess, self.test_scope():
+ v = resource_variable_ops.ResourceVariable(4.0)
+ ph = array_ops.placeholder_with_default(v, shape=[])
+ out = ph * 2
+ sess.run(variables.variables_initializer([v]))
+ self.assertEqual(2.0, sess.run(out, {ph: 1.0}))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
index 39af662b63..e72200bfbc 100644
--- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -38,6 +38,7 @@ class IdentityOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp);
REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp);
+REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
REGISTER_XLA_OP(Name("Snapshot"), IdentityOp);