aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-08 05:12:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-08 05:16:35 -0800
commit51fd9d70b8ef3c11b89e5009357cfbe3abb72473 (patch)
tree75877679d2dfe33fdde9d2345821345d8d7b132d
parent4ac1fee7f13586ce6633a45bbe88592f605583e0 (diff)
Extract the iterated expression of a for loop into a variable to avoid repeated staging.
PiperOrigin-RevId: 188316160
-rw-r--r--tensorflow/contrib/py2tf/converters/builtin_functions.py2
-rw-r--r--tensorflow/contrib/py2tf/converters/for_loops.py30
-rw-r--r--tensorflow/contrib/py2tf/converters/for_loops_test.py23
-rw-r--r--tensorflow/contrib/py2tf/utils/__init__.py2
4 files changed, 44 insertions, 13 deletions
diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py
index b5aa9756da..f1129ef153 100644
--- a/tensorflow/contrib/py2tf/converters/builtin_functions.py
+++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py
@@ -51,7 +51,7 @@ class BuiltinFunctionTransformer(transformer.Base):
def visit_Call(self, node):
self.generic_visit(node)
# TODO(mdan): This won't work if the function was hidden.
- if isinstance(node.func, gast.Name) and node.func.id in ('len',):
+ if isinstance(node.func, gast.Name) and node.func.id in ('len', 'range'):
return self._convert_builtin(node)
# Print needs to be handled separately because it can be read as statement.
if isinstance(node.func, gast.Name) and node.func.id == 'print':
diff --git a/tensorflow/contrib/py2tf/converters/for_loops.py b/tensorflow/contrib/py2tf/converters/for_loops.py
index 935dade0ed..4297c1cf2a 100644
--- a/tensorflow/contrib/py2tf/converters/for_loops.py
+++ b/tensorflow/contrib/py2tf/converters/for_loops.py
@@ -37,14 +37,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
def visit_For(self, node):
self.generic_visit(node)
body_scope = anno.getanno(node, NodeAnno.BODY_SCOPE)
-
+ i_var = self.context.namer.new_symbol('i', body_scope.referenced)
+ n_var = self.context.namer.new_symbol('n', body_scope.referenced)
+ iterated_var = self.context.namer.new_symbol('iterated',
+ body_scope.referenced)
+ # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
if anno.hasanno(node, 'extra_cond'):
template = """
i = 0
- n = len(loop_iter)
+ iterated = loop_iter
+ n = len(iterated)
while i < n and extra_cond:
- # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
- target = loop_iter[i]
+ target = iterated[i]
body
i += 1
"""
@@ -53,17 +57,18 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=self.context.namer.new_symbol('i', body_scope.referenced),
- n=self.context.namer.new_symbol('n', body_scope.referenced),
+ i=i_var,
+ n=n_var,
+ iterated=iterated_var,
extra_cond=anno.getanno(node, 'extra_cond'))
else:
template = """
i = 0
- n = len(loop_iter)
+ iterated = loop_iter
+ n = len(iterated)
while i < n:
- # TODO(mdan): Use TensorListFromTensor(loop_iter) here.
- target = loop_iter[i]
- body # pylint:disable=pointless-statement
+ target = iterated[i]
+ body
i += 1
"""
repl = templates.replace(
@@ -71,8 +76,9 @@ class ForLoopCanonicalizationTransformer(transformer.Base):
loop_iter=node.iter,
target=node.target,
body=node.body,
- i=self.context.namer.new_symbol('i', body_scope.referenced),
- n=self.context.namer.new_symbol('n', body_scope.referenced))
+ i=i_var,
+ n=n_var,
+ iterated=iterated_var)
return repl
def visit_Continue(self, node):
diff --git a/tensorflow/contrib/py2tf/converters/for_loops_test.py b/tensorflow/contrib/py2tf/converters/for_loops_test.py
index 70a367d3b5..b6e3e8c8d8 100644
--- a/tensorflow/contrib/py2tf/converters/for_loops_test.py
+++ b/tensorflow/contrib/py2tf/converters/for_loops_test.py
@@ -42,6 +42,29 @@ class ControlFlowTest(converter_test_base.TestCase):
l = []
self.assertEqual(test_fn(l), result.test_fn(l))
+ def test_for_with_iterated_expression(self):
+
+ eval_count = [0]
+
+ def count_evals(x):
+ eval_count[0] += 1
+ return x
+
+ def test_fn(n):
+ s = 0
+ for e in count_evals(range(n)):
+ s += e
+ return s
+
+ node = self.parse_and_analyze(test_fn, {'count_evals': count_evals})
+ node = for_loops.transform(node, self.ctx)
+
+ with self.compiled(node) as result:
+ result.count_evals = count_evals
+ self.assertEqual(test_fn(5), result.test_fn(5))
+ # count_evals ran twice, once for test_fn and another for result.test_fn
+ self.assertEqual(eval_count[0], 2)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py
index 997c815887..4fc0121efb 100644
--- a/tensorflow/contrib/py2tf/utils/__init__.py
+++ b/tensorflow/contrib/py2tf/utils/__init__.py
@@ -20,11 +20,13 @@ from __future__ import print_function
from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin
from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_range
from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
+from tensorflow.contrib.py2tf.utils.tensor_list import dynamic_list_append
from tensorflow.contrib.py2tf.utils.testing import fake_tf
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
from tensorflow.contrib.py2tf.utils.type_hints import set_element_type