aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-05 15:25:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 15:30:04 -0700
commitea887d2d13a686990145b65e11701deae676b28b (patch)
treed173e6a4d9281764519fff5b4e34b07e98c766fb
parent47fcad59ccd32e38bc133bff25e0645838e3e9df (diff)
Add for and while loops to the list of operators. Do not use them yet.
PiperOrigin-RevId: 191807973
-rw-r--r--tensorflow/contrib/autograph/operators/BUILD17
-rw-r--r--tensorflow/contrib/autograph/operators/__init__.py5
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow.py179
-rw-r--r--tensorflow/contrib/autograph/operators/control_flow_test.py82
4 files changed, 282 insertions, 1 deletions
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 7856c253bd..4c62468575 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -2,6 +2,8 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load("//tensorflow:tensorflow.bzl", "py_test")
+
filegroup(
name = "all_files",
srcs = glob(
@@ -18,8 +20,21 @@ py_library(
name = "operators",
srcs = [
"__init__.py",
+ "control_flow.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
- deps = [],
+ deps = [
+ "//tensorflow/contrib/autograph/utils",
+ ],
+)
+
+py_test(
+ name = "control_flow_test",
+ srcs = ["control_flow_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index c3f4cab69e..04b4734551 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -22,3 +22,8 @@ closures for the body.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
+# TODO(mdan): Add a container for implementation-specific toggles (throughout).
+
+from tensorflow.contrib.autograph.operators.control_flow import for_loop
+from tensorflow.contrib.autograph.operators.control_flow import while_loop
diff --git a/tensorflow/contrib/autograph/operators/control_flow.py b/tensorflow/contrib/autograph/operators/control_flow.py
new file mode 100644
index 0000000000..5b8cb2d63c
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/control_flow.py
@@ -0,0 +1,179 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Control flow statements: loops, conditionals, etc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.utils import builtins
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_math_ops
+
+
+def for_loop(iterated, extra_cond, loop_body, init_state):
+ """Functional form of a for statement.
+
+ The loop operates on a so-called state, which includes all symbols that are
+ variant across loop iterations, excluding the iterate. In what follows we
+ refer to state as either a tuple of entities that represent an actual state,
+ or a list of arguments of the corresponding types.
+
+ Args:
+ iterated: The entity being iterated over.
+ extra_cond: Callable with the state as arguments, and boolean return type.
+ An additionnal loop condition.
+ loop_body: Callable with the iterate and the state as arguments, and
+ state as return type. The actual loop body.
+ init_state: Tuple containing the initial state.
+
+ Returns:
+ Tuple containing the final state.
+ """
+ if tensor_util.is_tensor(iterated):
+ return _known_len_for_loop(iterated, extra_cond, loop_body, init_state)
+ elif isinstance(iterated, dataset_ops.Dataset):
+ return _dataset_for_loop(iterated, extra_cond, loop_body, init_state)
+ else:
+ return _py_for_loop(iterated, extra_cond, loop_body, init_state)
+
+
+def _py_for_loop(iterated, extra_cond, loop_body, init_state):
+ """Overload of for_loop that executes a Python for loop."""
+ state = init_state
+ for iterate in iterated:
+ if not extra_cond(*state):
+ break
+ state = loop_body(iterate, *state)
+
+ # TODO(mdan): Remove this special case.
+ if len(state) == 1:
+ return state[0]
+ return state
+
+
+def _known_len_for_loop(iterated, extra_cond, loop_body, init_state):
+ """Overload of for_loop that iterates over objects that define a length."""
+ n = builtins.dynamic_len(iterated)
+
+ def while_body(iterate_index, *state):
+ iterate = iterated[iterate_index]
+ new_state = loop_body(iterate, *state)
+ return (iterate_index + 1,) + new_state
+
+ def while_cond(iterate_index, *state):
+ return gen_math_ops.logical_and(iterate_index < n, extra_cond(*state))
+
+ results = while_loop(
+ while_cond,
+ while_body,
+ init_state=(0,) + init_state,
+ extra_deps=(iterated,))
+ # Dropping the iteration index because it's not syntactically visible.
+ results = results[1:]
+
+ # TODO(mdan): Remove this special case.
+ if len(results) == 1:
+ return results[0]
+ return results
+
+
+def _dataset_for_loop(ds, extra_cond, loop_body, init_state):
+ """Overload of for_loop that iterates over TF Datasets."""
+ # Because Datsets only expose get_next, in the style of Python iterators,
+ # we are forced to unpack the loop as:
+ #
+ # epoch_number, iterate = ds.get_next()
+ # while epoch_number < 2:
+ # <body>
+ # epoch_number, iterate = ds.get_next()
+ epoch_numbers = dataset_ops.Dataset.range(2)
+ def tag_with(ds, tag):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(tag).repeat(), ds))
+ ds_with_epoch = epoch_numbers.flat_map(lambda i: tag_with(ds, i))
+
+ iterator = ds_with_epoch.make_initializable_iterator()
+ with ops.control_dependencies((iterator.initializer,)):
+ epoch_number, iterate = iterator.get_next()
+
+ def while_body(epoch_number, iterate, *state):
+ new_state = loop_body(iterate, *state)
+ epoch_number, iterate = iterator.get_next()
+ return (epoch_number, iterate) + new_state
+
+ def while_cond(epoch_number, iterate, *state):
+ del iterate
+ return gen_math_ops.logical_and(epoch_number < 1, extra_cond(*state))
+
+ results = while_loop(
+ while_cond,
+ while_body,
+ init_state=(epoch_number, iterate) + init_state,
+ extra_deps=())
+ # Dropping the epoch number and iterate because they are not not syntactically
+ # visible.
+ results = results[2:]
+
+ # TODO(mdan): Remove this special case.
+ if len(results) == 1:
+ return results[0]
+ return results
+
+
+def while_loop(loop_cond, loop_body, init_state, extra_deps):
+ """Functional form of a while statement.
+
+ The loop operates on a so-called state, which includes all symbols that are
+ variant across loop iterations. In what follows we refer to state as either
+ a tuple of entities that represent an actual state, or a list of arguments
+ of the corresponding types.
+
+ Args:
+ loop_cond: Callable with the state as arguments, and boolean return type.
+ The loop condition.
+ loop_body: Callable with the state as arguments, and state as return type.
+ The actual loop body.
+ init_state: Tuple containing the initial state.
+ extra_deps: Tuple containing additional entities on which the loop may
+ depend, such as loop invariants referenced by loop_cond. Used
+ exclusively for dispatch control.
+
+ Returns:
+ Tuple containing the final state.
+ """
+ # TODO(mdan): Consider adding a generic mechanism for dynamic dispatch.
+ # That could be somethins as simple as a collection of dispatch rules, with
+ # some prioritization.
+ if any(tensor_util.is_tensor(v) for v in init_state + extra_deps):
+ return _tf_while_loop(loop_cond, loop_body, init_state)
+ else:
+ return _py_while_loop(loop_cond, loop_body, init_state)
+
+
+def _tf_while_loop(loop_cond, loop_body, init_state):
+ """Overload of while_loop that stages a TF while_loop."""
+ return control_flow_ops.while_loop(loop_cond, loop_body, init_state)
+
+
+def _py_while_loop(loop_cond, loop_body, init_state):
+ """Overload of while_loop that executes a Python while loop."""
+ state = init_state
+ while loop_cond(*state):
+ state = loop_body(*state)
+ return state
diff --git a/tensorflow/contrib/autograph/operators/control_flow_test.py b/tensorflow/contrib/autograph/operators/control_flow_test.py
new file mode 100644
index 0000000000..9112b1627f
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/control_flow_test.py
@@ -0,0 +1,82 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph import operators
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ForLoopTest(test.TestCase):
+
+ def test_tensor(self):
+ s = operators.for_loop(
+ constant_op.constant([1, 2, 3, 4]),
+ extra_cond=lambda s: True,
+ loop_body=lambda i, s: (s + i,),
+ init_state=(0,))
+ with self.test_session() as sess:
+ self.assertEqual((10,), sess.run(s))
+
+ def test_python(self):
+ s = operators.for_loop(
+ range(5),
+ extra_cond=lambda s: True,
+ loop_body=lambda i, s: (s + i,),
+ init_state=(0,))
+ self.assertEqual(10, s)
+
+ def test_dataset(self):
+ to_int32 = lambda i: math_ops.cast(i, dtypes.int32)
+ s = operators.for_loop(
+ dataset_ops.Dataset.range(5).map(to_int32),
+ extra_cond=lambda s: True,
+ loop_body=lambda i, s: (s + i,),
+ init_state=(0,))
+ with self.test_session() as sess:
+ self.assertEqual((10,), sess.run(s))
+
+
+class WhileLoopTest(test.TestCase):
+
+ def test_tensor(self):
+ n = constant_op.constant(5)
+ results = operators.while_loop(
+ loop_cond=lambda i, s: i < n,
+ loop_body=lambda i, s: (i + 1, s + i,),
+ init_state=(0, 0),
+ extra_deps=(n,))
+ with self.test_session() as sess:
+ self.assertEqual((5, 10), sess.run(results))
+
+ def test_python(self):
+ n = 5
+ results = operators.while_loop(
+ loop_cond=lambda i, s: i < n,
+ loop_body=lambda i, s: (i + 1, s + i),
+ init_state=(0, 0),
+ extra_deps=(n,))
+ self.assertEqual((5, 10), results)
+
+
+if __name__ == '__main__':
+ test.main()