diff options
author | Alexandre Passos <apassos@google.com> | 2018-08-03 11:43:58 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-03 11:48:05 -0700 |
commit | dd45704d092dac87575b8ce39013f91f4f213dc0 (patch) | |
tree | fa603f0dd9386890638fe3c5cb834a66c1b06888 /tensorflow | |
parent | 4e4171bb6fcc08c00dbc6f6ae2dcdc502add6931 (diff) |
Unbreaks tests broken after the defun while loop change.
Do not add placeholders to the function body as XLA cannot compile them.
PiperOrigin-RevId: 207299427
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/tests/eager_test.py | 2 | ||||
-rw-r--r-- | tensorflow/python/eager/function.py | 7 |
2 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py index 116ac472b8..6ead15da13 100644 --- a/tensorflow/compiler/tests/eager_test.py +++ b/tensorflow/compiler/tests/eager_test.py @@ -384,7 +384,6 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertAllEqual([[20., 40.], [90., 120.]], d.numpy()) def testDefunInGradientTape(self): - self.skipTest('b/112172115: Broken by loop support in defun') with self.test_scope(): v0 = resource_variable_ops.ResourceVariable(5.0) @@ -402,7 +401,6 @@ class EagerFunctionTest(xla_test.XLATestCase): self.assertEqual(30, dy.numpy()) def testSliceInDefun(self): - self.skipTest('b/112172115: Broken by loop support in defun') with self.test_scope(): @function.defun diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 8e8c028f60..51ebcd65b3 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -511,8 +511,13 @@ class GraphModeFunction(object): extra_placeholders = [] forward_name = _forward_name(self._func_name) + # Note: we cannot have placeholder ops in the graph or the TPU compilation + # pass fails. + placeholder_ops = set([y.op for y in self._input_placeholders]) + function_ops = [x for x in self._graph.get_operations() + if x not in placeholder_ops] self._forward_fdef = _EagerDefinedFunction( - forward_name, self._graph, self._graph.get_operations(), + forward_name, self._graph, function_ops, self._input_placeholders, filtered_outputs + list(extra_inputs), self._attrs) all_inputs = self._out_grad_placeholders + list(extra_placeholders) |