aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-08-03 11:43:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-03 11:48:05 -0700
commitdd45704d092dac87575b8ce39013f91f4f213dc0 (patch)
treefa603f0dd9386890638fe3c5cb834a66c1b06888 /tensorflow
parent4e4171bb6fcc08c00dbc6f6ae2dcdc502add6931 (diff)
Unbreaks tests broken after the defun while loop change.
Do not add placeholders to the function body as XLA cannot compile them. PiperOrigin-RevId: 207299427
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/tests/eager_test.py2
-rw-r--r--tensorflow/python/eager/function.py7
2 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index 116ac472b8..6ead15da13 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -384,7 +384,6 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy())
def testDefunInGradientTape(self):
- self.skipTest('b/112172115: Broken by loop support in defun')
with self.test_scope():
v0 = resource_variable_ops.ResourceVariable(5.0)
@@ -402,7 +401,6 @@ class EagerFunctionTest(xla_test.XLATestCase):
self.assertEqual(30, dy.numpy())
def testSliceInDefun(self):
- self.skipTest('b/112172115: Broken by loop support in defun')
with self.test_scope():
@function.defun
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 8e8c028f60..51ebcd65b3 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -511,8 +511,13 @@ class GraphModeFunction(object):
extra_placeholders = []
forward_name = _forward_name(self._func_name)
+ # Note: we cannot have placeholder ops in the graph or the TPU compilation
+ # pass fails.
+ placeholder_ops = set([y.op for y in self._input_placeholders])
+ function_ops = [x for x in self._graph.get_operations()
+ if x not in placeholder_ops]
self._forward_fdef = _EagerDefinedFunction(
- forward_name, self._graph, self._graph.get_operations(),
+ forward_name, self._graph, function_ops,
self._input_placeholders, filtered_outputs + list(extra_inputs),
self._attrs)
all_inputs = self._out_grad_placeholders + list(extra_placeholders)