aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-05-01 19:05:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 19:08:50 -0700
commitb50f6325143486eb82b5654f8794f0771b54dd4d (patch)
treea7d8af4150d82f6c594fa91e669e97eb0fdcfb89
parentc0f1080188c5c6955cfa3b3c086ac262b1e5ec02 (diff)
Minor refactor: establish some operator naming conventions and apply them, so that the interface is a bit more consistent.
PiperOrigin-RevId: 195034691
-rw-r--r--tensorflow/contrib/autograph/converters/break_statements.py4
-rw-r--r--tensorflow/contrib/autograph/converters/control_flow.py24
-rw-r--r--tensorflow/contrib/autograph/operators/__init__.py16
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py105
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow_test.py30
5 files changed, 99 insertions, 80 deletions
diff --git a/tensorflow/contrib/autograph/converters/break_statements.py b/tensorflow/contrib/autograph/converters/break_statements.py
index 91de82f0a7..1be1c96dd3 100644
--- a/tensorflow/contrib/autograph/converters/break_statements.py
+++ b/tensorflow/contrib/autograph/converters/break_statements.py
@@ -114,9 +114,9 @@ class BreakStatementTransformer(transformer.Base):
template,
var_name=break_var,
for_stmt=node)
- extra_cond = templates.replace_as_expression(
+ extra_test = templates.replace_as_expression(
'not var_name', var_name=break_var)
- anno.setanno(node[1], 'extra_cond', extra_cond)
+ anno.setanno(node[1], 'extra_test', extra_test)
return node
diff --git a/tensorflow/contrib/autograph/converters/control_flow.py b/tensorflow/contrib/autograph/converters/control_flow.py
index 2e26cdb3d9..935a2786db 100644
--- a/tensorflow/contrib/autograph/converters/control_flow.py
+++ b/tensorflow/contrib/autograph/converters/control_flow.py
@@ -207,7 +207,7 @@ class ControlFlowTransformer(transformer.Base):
def body_name(state_ssf):
body
return state_ssf,
- state_ast_tuple = ag__.while_loop(
+ state_ast_tuple = ag__.while_stmt(
test_name, body_name, (state,), (extra_deps,))
"""
node = templates.replace(
@@ -252,31 +252,31 @@ class ControlFlowTransformer(transformer.Base):
state_ast_tuple = gast.Tuple([n.ast() for n in state], None)
node_body = ast_util.rename_symbols(node.body, ssf_map)
- if anno.hasanno(node, 'extra_cond'):
- extra_cond = anno.getanno(node, 'extra_cond')
- extra_cond = ast_util.rename_symbols(extra_cond, ssf_map)
+ if anno.hasanno(node, 'extra_test'):
+ extra_test = anno.getanno(node, 'extra_test')
+ extra_test = ast_util.rename_symbols(extra_test, ssf_map)
else:
- extra_cond = parser.parse_expression('True')
+ extra_test = parser.parse_expression('True')
template = """
- def extra_cond_name(state_ssf):
- return extra_cond_expr
+ def extra_test_name(state_ssf):
+ return extra_test_expr
def body_name(iterate, state_ssf):
body
return state_ssf,
- state_ast_tuple = ag__.for_loop(
- iterated, extra_cond_name, body_name, (state,))
+ state_ast_tuple = ag__.for_stmt(
+ iter_, extra_test_name, body_name, (state,))
"""
node = templates.replace(
template,
state=state,
state_ssf=state_ssf,
state_ast_tuple=state_ast_tuple,
- iterated=node.iter,
+ iter_=node.iter,
iterate=node.target,
- extra_cond_name=self.context.namer.new_symbol('extra_cond',
+ extra_test_name=self.context.namer.new_symbol('extra_test',
all_referenced),
- extra_cond_expr=extra_cond,
+ extra_test_expr=extra_test,
body_name=self.context.namer.new_symbol('loop_body', all_referenced),
body=node_body)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 04b4734551..38b761d97d 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -19,11 +19,19 @@ conditionals and loops, implemented in functional form, using for example
closures for the body.
"""
+# Naming conventions:
+# * operator names match the name usually used for the respective Python
+# idiom; examples: for_stmt, list_append
+# * operator arguments match either of:
+# - the corresponding Python AST attribute (e.g. the condition of an if
+# statement is called test) if the operator represents an AST construct
+# - the names used in the Python docs, if the operator is a function (e.g.
+# list_ and x for append, see
+# https://docs.python.org/3.7/tutorial/datastructures.html)
+
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# TODO(mdan): Add a container for implementation-specific toggles (throughout).
-
-from tensorflow.contrib.autograph.operators.control_flow import for_loop
-from tensorflow.contrib.autograph.operators.control_flow import while_loop
+from tensorflow.contrib.autograph.operators.control_flow import for_stmt
+from tensorflow.contrib.autograph.operators.control_flow import while_stmt
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
index d9d8b0d593..9f7202821f 100644
--- a/tensorflow/contrib/autograph/operators/control_flow.py
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -25,44 +25,55 @@ from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_math_ops
-# TODO(mdan): Rename _loop to _stmt to follow Python nomenclature.
-# TODO(mdan): Rename arguments to match the AST names.
-
-def for_loop(iterated, extra_cond, loop_body, init_state):
+def for_stmt(iter_, extra_test, body, init_state):
"""Functional form of a for statement.
- The loop operates on a so-called state, which includes all symbols that are
- variant across loop iterations, excluding the iterate. In what follows we
- refer to state as either a tuple of entities that represent an actual state,
- or a list of arguments of the corresponding types.
+ The loop operates on a state, which includes all symbols that are
+ variant across loop iterations, excluding the iterate as well as the
+ variables local to the loop.
+
+ For example, given the loop below that calculates the geometric and
+ arithmetic means or some numbers:
+
+ geo_mean = 1
+ arith_mean = 0
+ for i in range(n):
+ a = numbers[i]
+ geo_mean *= a
+ arith_mean += a
+
+ The state is represented by the variables geo_mean and arith_mean. The
+ argument for initial_state may contain the tuple (1, 0), the body will
+ include the arguments geo_mean and arith_mean and will return a tuple
+ representing the new values for geo_mean and respectively arith_mean.
Args:
- iterated: The entity being iterated over.
- extra_cond: Callable with the state as arguments, and boolean return type.
+ iter_: The entity being iterated over.
+ extra_test: Callable with the state as arguments, and boolean return type.
An additionnal loop condition.
- loop_body: Callable with the iterate and the state as arguments, and
+ body: Callable with the iterate and the state as arguments, and
state as return type. The actual loop body.
init_state: Tuple containing the initial state.
Returns:
Tuple containing the final state.
"""
- if tensor_util.is_tensor(iterated):
- return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
- elif isinstance(iterated, dataset_ops.Dataset):
- return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+ if tensor_util.is_tensor(iter_):
+ return _known_len_for_stmt(iter_, extra_test, body, init_state)
+ elif isinstance(iter_, dataset_ops.Dataset):
+ return _dataset_for_stmt(iter_, extra_test, body, init_state)
else:
- return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+ return _py_for_stmt(iter_, extra_test, body, init_state)
-def _py_for_loop(iterated, extra_cond, loop_body, init_state):
- """Overload of for_loop that executes a Python for loop."""
+def _py_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that executes a Python for loop."""
state = init_state
- for iterate in iterated:
- if not extra_cond(*state):
+ for target in iter_:
+ if not extra_test(*state):
break
- state = loop_body(iterate, *state)
+ state = body(target, *state)
# TODO(mdan): Remove this special case.
if len(state) == 1:
@@ -70,23 +81,23 @@ def _py_for_loop(iterated, extra_cond, loop_body, init_state):
return state
-def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
- """Overload of for_loop that iterates over objects that define a length."""
- n = builtins.dynamic_len(iterated)
+def _known_len_for_stmt(iter_, extra_test, body, init_state):
+ """Overload of for_stmt that iterates over objects that define a length."""
+ n = builtins.dynamic_len(iter_)
def while_body(iterate_index, *state):
- iterate = iterated[iterate_index]
- new_state = loop_body(iterate, *state)
+ iterate = iter_[iterate_index]
+ new_state = body(iterate, *state)
return (iterate_index + 1,) + new_state
def while_cond(iterate_index, *state):
- return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+ return gen_math_ops.logical_and(iterate_index < n, extra_test(*state))
- results = while_loop(
+ results = while_stmt(
while_cond,
while_body,
init_state=(0,) + init_state,
- extra_deps=(iterated,),
+ extra_deps=(iter_,),
opts=dict(maximum_iterations=n))
# Dropping the iteration index because it's not syntactically visible.
results = results[1:]
@@ -97,8 +108,8 @@ def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
return results
-def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
- """Overload of for_loop that iterates over TF Datasets."""
+def _dataset_for_stmt(ds, extra_test, body, init_state):
+ """Overload of for_stmt that iterates over TF Datasets."""
# Because Datsets only expose get_next, in the style of Python iterators,
# we are forced to unpack the loop as:
#
@@ -117,15 +128,15 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
epoch_number, iterate = iterator.get_next()
def while_body(epoch_number, iterate, *state):
- new_state = loop_body(iterate, *state)
+ new_state = body(iterate, *state)
epoch_number, iterate = iterator.get_next()
return (epoch_number, iterate) + new_state
def while_cond(epoch_number, iterate, *state):
del iterate
- return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+ return gen_math_ops.logical_and(epoch_number < 1, extra_test(*state))
- results = while_loop(
+ results = while_stmt(
while_cond,
while_body,
init_state=(epoch_number, iterate) + init_state,
@@ -140,7 +151,7 @@ def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
return results
-def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
+def while_stmt(test, body, init_state, extra_deps, opts=None):
"""Functional form of a while statement.
The loop operates on a so-called state, which includes all symbols that are
@@ -149,13 +160,13 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
of the corresponding types.
Args:
- loop_cond: Callable with the state as arguments, and boolean return type.
+ test: Callable with the state as arguments, and boolean return type.
The loop condition.
- loop_body: Callable with the state as arguments, and state as return type.
+ body: Callable with the state as arguments, and state as return type.
The actual loop body.
init_state: Tuple containing the initial state.
extra_deps: Tuple containing additional entities on which the loop may
- depend, such as loop invariants referenced by loop_cond. Used
+ depend, such as loop invariants referenced by test. Used
exclusively for dispatch control.
opts: Optional dict of extra loop parameters.
@@ -166,24 +177,24 @@ def while_loop(loop_cond, loop_body, init_state, extra_deps, opts=None):
# That could be somethins as simple as a collection of dispatch rules, with
# some prioritization.
if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
- return _tf_while_loop(loop_cond, loop_body, init_state, opts)
+ return _tf_while_stmt(test, body, init_state, opts)
else:
- return _py_while_loop(loop_cond, loop_body, init_state, opts)
+ return _py_while_stmt(test, body, init_state, opts)
-def _tf_while_loop(loop_cond, loop_body, init_state, opts):
- """Overload of while_loop that stages a TF while_loop."""
+def _tf_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that stages a TF while_stmt."""
if opts is None:
opts = {}
- return control_flow_ops.while_loop(loop_cond, loop_body, init_state, **opts)
+ return control_flow_ops.while_loop(test, body, init_state, **opts)
-def _py_while_loop(loop_cond, loop_body, init_state, opts):
- """Overload of while_loop that executes a Python while loop."""
+def _py_while_stmt(test, body, init_state, opts):
+ """Overload of while_stmt that executes a Python while loop."""
del opts
state = init_state
- while loop_cond(*state):
- state = loop_body(*state)
+ while test(*state):
+ state = body(*state)
return state
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
index a0cd0bfa82..b14d7edba3 100644
--- a/tensorflow/contrib/autograph/operators/control_flow_test.py
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -29,28 +29,28 @@ from tensorflow.python.platform import test
class ForLoopTest(test.TestCase):
def test_tensor(self):
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
constant_op.constant([1, 2, 3, 4]),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
with self.test_session() as sess:
self.assertEqual((10,), sess.run(s))
def test_python(self):
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
range(5),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
self.assertEqual(10, s)
def test_dataset(self):
to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
- s = control_flow.for_loop(
+ s = control_flow.for_stmt(
dataset_ops.Dataset.range(5).map(to_int32),
- extra_cond=lambda s: True,
- loop_body=lambda i, s: (s + i,),
+ extra_test=lambda s: True,
+ body=lambda i, s: (s + i,),
init_state=(0,))
with self.test_session() as sess:
self.assertEqual((10,), sess.run(s))
@@ -60,9 +60,9 @@ class WhileLoopTest(test.TestCase):
def test_tensor(self):
n = constant_op.constant(5)
- results = control_flow.while_loop(
- loop_cond=lambda i, s: i < n,
- loop_body=lambda i, s: (i + 1, s + i,),
+ results = control_flow.while_stmt(
+ test=lambda i, s: i < n,
+ body=lambda i, s: (i + 1, s + i,),
init_state=(0, 0),
extra_deps=(n,))
with self.test_session() as sess:
@@ -70,9 +70,9 @@ class WhileLoopTest(test.TestCase):
def test_python(self):
n = 5
- results = control_flow.while_loop(
- loop_cond=lambda i, s: i < n,
- loop_body=lambda i, s: (i + 1, s + i),
+ results = control_flow.while_stmt(
+ test=lambda i, s: i < n,
+ body=lambda i, s: (i + 1, s + i),
init_state=(0, 0),
extra_deps=(n,))
self.assertEqual((5, 10), results)