aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-05-31 10:33:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-31 10:36:34 -0700
commitf50b61fffb7a65688899a625b689387653c5c798 (patch)
tree79872d8babc2de50a6f2e999b54d62c42486bc4d
parent3ff633d9797d173d65523453de589cbbcf6e32ce (diff)
Initial implementation of a few of the list-specific operators.
This introduces an abstraction for a dispatch context, which allows passing local information through to the specialized operators. PiperOrigin-RevId: 198742074
-rw-r--r--tensorflow/contrib/autograph/operators/BUILD12
-rw-r--r--tensorflow/contrib/autograph/operators/__init__.py13
-rw-r--r--tensorflow/contrib/autograph/operators/data_structures.py249
-rw-r--r--tensorflow/contrib/autograph/operators/data_structures_test.py87
-rw-r--r--tensorflow/contrib/autograph/operators/slices.py133
-rw-r--r--tensorflow/contrib/autograph/operators/slices_test.py51
6 files changed, 518 insertions, 27 deletions
diff --git a/tensorflow/contrib/autograph/operators/BUILD b/tensorflow/contrib/autograph/operators/BUILD
index 18bfec5d9c..0c6ab65505 100644
--- a/tensorflow/contrib/autograph/operators/BUILD
+++ b/tensorflow/contrib/autograph/operators/BUILD
@@ -22,7 +22,7 @@ py_library(
"__init__.py",
"control_flow.py",
"data_structures.py",
- "dispatch_context.py",
+ "slices.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
@@ -52,3 +52,13 @@ py_test(
"//tensorflow/python:client_testlib",
],
)
+
+py_test(
+ name = "slices_test",
+ srcs = ["slices_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":operators",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/contrib/autograph/operators/__init__.py b/tensorflow/contrib/autograph/operators/__init__.py
index 38b761d97d..c900fd6af2 100644
--- a/tensorflow/contrib/autograph/operators/__init__.py
+++ b/tensorflow/contrib/autograph/operators/__init__.py
@@ -28,6 +28,10 @@ closures for the body.
# - the names used in the Python docs, if the operator is a function (e.g.
# list_ and x for append, see
# https://docs.python.org/3.7/tutorial/datastructures.html)
+#
+# All operators may accept a final argument named "opts", of a type that
+# subclasses namedtuple and contains any arguments that are only required
+# for some specializations of the operator.
from __future__ import absolute_import
from __future__ import division
@@ -35,3 +39,12 @@ from __future__ import print_function
from tensorflow.contrib.autograph.operators.control_flow import for_stmt
from tensorflow.contrib.autograph.operators.control_flow import while_stmt
+from tensorflow.contrib.autograph.operators.data_structures import list_append
+from tensorflow.contrib.autograph.operators.data_structures import list_pop
+from tensorflow.contrib.autograph.operators.data_structures import list_stack
+from tensorflow.contrib.autograph.operators.data_structures import ListPopOpts
+from tensorflow.contrib.autograph.operators.data_structures import ListStackOpts
+from tensorflow.contrib.autograph.operators.data_structures import new_list
+from tensorflow.contrib.autograph.operators.slices import get_item
+from tensorflow.contrib.autograph.operators.slices import GetItemOpts
+from tensorflow.contrib.autograph.operators.slices import set_item
diff --git a/tensorflow/contrib/autograph/operators/data_structures.py b/tensorflow/contrib/autograph/operators/data_structures.py
index c862306baa..06d8727b0f 100644
--- a/tensorflow/contrib/autograph/operators/data_structures.py
+++ b/tensorflow/contrib/autograph/operators/data_structures.py
@@ -18,39 +18,250 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import variables
+
+
+# TODO(mdan): Once control flow supports objects, repackage as a class.
+
+
+def new_list(iterable=None):
+ """The list constructor.
+
+ Args:
+ iterable: Optional elements to fill the list with.
+
+ Returns:
+ A list-like object. The exact return value depends on the initial elements.
+ """
+ if iterable:
+ elements = tuple(iterable)
+ else:
+ elements = ()
+
+ # TODO(mdan): Extend these criteria.
+ if any(isinstance(el, variables.Variable) for el in elements):
+ return _py_list_new(elements)
+ return _tf_tensor_list_new(elements)
-# TODO(mdan): Add support for TensorList once functional.
-# TODO(mdan): Add primitives for empty list, list with elements.
+def _tf_tensor_list_new(elements):
+ """Overload of new_list that stages a Tensor list creation."""
+ elements = tuple(ops.convert_to_tensor(el) for el in elements)
+ all_dtypes = set(el.dtype for el in elements)
+ if len(all_dtypes) == 1:
+ element_dtype = tuple(all_dtypes)[0]
+ else:
+ # Heterogeneous lists are ok.
+ element_dtype = dtypes.variant
+
+ # TODO(mdan): This may fail for elements of variable shapes.
+ all_shapes = set(tuple(el.shape.as_list()) for el in elements)
+ if len(all_shapes) == 1:
+ element_shape = array_ops.shape(elements[0])
+ else:
+ # Heterogeneous lists are ok.
+ element_shape = constant_op.constant(-1) # unknown shape, by convention
+
+ l = list_ops.empty_tensor_list(
+ element_shape=element_shape, element_dtype=element_dtype)
+ for el in elements:
+ l = list_ops.tensor_list_push_back(l, el)
+ return l
-def append(target, element):
+
+def _py_list_new(elements):
+ """Overload of new_list that creates a Python list."""
+ return list(elements)
+
+
+def list_append(list_, x):
"""The list append function.
- Note: it is unspecified where target will be mutated or not. If target is
- a TensorFlow entity, it will not be typically mutated. If target is a plain
- list, it will be. In general, if the target is mutated then the return value
+ Note: it is unspecified where list_ will be mutated or not. If list_ is
+ a TensorFlow entity, it will not be typically mutated. If list_ is a plain
+ list, it will be. In general, if the list is mutated then the return value
should point to the original entity.
Args:
- target: An entity that supports append semantics.
- element: The element to append.
+ list_: An entity that supports append semantics.
+ x: The element to append.
Returns:
- Same as target, after the append was performed.
+ Same as list_, after the append was performed.
+
+ Raises:
+ ValueError: if list_ is not of a known list-like type.
"""
- if isinstance(target, tensor_array_ops.TensorArray):
- return _tf_tensorarray_append(target, element)
+ if isinstance(list_, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_append(list_, x)
+ elif tensor_util.is_tensor(list_):
+ if list_.dtype == dtypes.variant:
+ return _tf_tensor_list_append(list_, x)
+ else:
+ raise ValueError(
+ 'tensor lists are expected to be Tensors with dtype=tf.variant,'
+ ' instead found %s' % list_)
else:
- return _py_append(target, element)
+ return _py_list_append(list_, x)
+
+
+def _tf_tensor_list_append(list_, x):
+ """Overload of list_append that stages a Tensor list write."""
+ def empty_list_of_elements_like_x():
+ tensor_x = ops.convert_to_tensor(x)
+ return list_ops.empty_tensor_list(
+ element_shape=array_ops.shape(tensor_x),
+ element_dtype=tensor_x.dtype)
+
+ list_ = control_flow_ops.cond(
+ list_ops.tensor_list_length(list_) > 0,
+ lambda: list_,
+ empty_list_of_elements_like_x,
+ )
+ return list_ops.tensor_list_push_back(list_, x)
+
+
+def _tf_tensorarray_append(list_, x):
+ """Overload of list_append that stages a TensorArray write."""
+ return list_.write(list_.size(), x)
+
+
+def _py_list_append(list_, x):
+ """Overload of list_append that executes a Python list append."""
+ # Revert to the original call.
+ list_.append(x)
+ return list_
+
+
+class ListPopOpts(
+ collections.namedtuple('ListPopOpts', ('element_dtype', 'element_shape'))):
+ pass
+
+
+def list_pop(list_, i, opts):
+ """The list pop function.
+
+ Note: it is unspecified where list_ will be mutated or not. If list_ is
+ a TensorFlow entity, it will not be typically mutated. If list_ is a plain
+ list, it will be. In general, if the list is mutated then the return value
+ should point to the original entity.
+
+ Args:
+ list_: An entity that supports pop semantics.
+ i: Optional index to pop from. May be None.
+ opts: A ListPopOpts.
+
+ Returns:
+ Tuple (x, out_list_):
+ out_list_: same as list_, after the removal was performed.
+ x: the removed element value.
+
+ Raises:
+ ValueError: if list_ is not of a known list-like type or the operation is
+ not supported for that type.
+ """
+ assert isinstance(opts, ListPopOpts)
+
+ if isinstance(list_, tensor_array_ops.TensorArray):
+ raise ValueError('TensorArray does not support item removal')
+ elif tensor_util.is_tensor(list_):
+ if list_.dtype == dtypes.variant:
+ return _tf_tensor_list_pop(list_, i, opts)
+ else:
+ raise ValueError(
+ 'tensor lists are expected to be Tensors with dtype=tf.variant,'
+ ' instead found %s' % list_)
+ else:
+ return _py_list_pop(list_, i)
+
+
+def _tf_tensor_list_pop(list_, i, opts):
+ """Overload of list_pop that stages a Tensor list pop."""
+ if i is not None:
+ raise NotImplementedError('tensor lists only support removing from the end')
+
+ if opts.element_dtype is None:
+ raise ValueError('cannot pop from a list without knowing its element '
+ 'type; use set_element_type to annotate it')
+ if opts.element_shape is None:
+ raise ValueError('cannot pop from a list without knowing its element '
+ 'shape; use set_element_type to annotate it')
+ list_out, x = list_ops.tensor_list_pop_back(
+ list_, element_dtype=opts.element_dtype)
+ x.set_shape(opts.element_shape)
+ return list_out, x
+
+
+def _py_list_pop(list_, i):
+ """Overload of list_pop that executes a Python list append."""
+ if i is None:
+ x = list_.pop()
+ else:
+ x = list_.pop(i)
+ return list_, x
+
+
+# TODO(mdan): Look into reducing duplication between all these containers.
+class ListStackOpts(
+ collections.namedtuple('ListStackOpts',
+ ('element_dtype', 'original_call'))):
+ pass
+
+
+def list_stack(list_, opts):
+ """The list stack function.
+
+ This does not have a direct correspondent in Python. The closest idiom to
+ this is tf.append or np.stack. It's different from those in the sense that it
+ accepts a Tensor list, rather than a list of tensors. It can also accept
+ TensorArray. When the target is anything else, the dispatcher will rely on
+ ctx.original_call for fallback.
+
+ Args:
+ list_: An entity that supports append semantics.
+ opts: A ListStackOpts object.
+
+ Returns:
+ The output of the stack operation, typically a Tensor.
+ """
+ assert isinstance(opts, ListStackOpts)
+
+ if isinstance(list_, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_stack(list_)
+ elif tensor_util.is_tensor(list_):
+ if list_.dtype == dtypes.variant:
+ return _tf_tensor_list_stack(list_, opts)
+ else:
+ # No-op for primitive Tensor arguments.
+ return list_
+ else:
+ return _py_list_stack(list_, opts)
+
+
+def _tf_tensorarray_stack(list_):
+ """Overload of list_stack that stages a TensorArray stack."""
+ return list_.stack()
-def _tf_tensorarray_append(target, element):
- """Overload of append that stages a TensorArray write at the last position."""
- return target.write(target.size(), element)
+def _tf_tensor_list_stack(list_, opts):
+ """Overload of list_stack that stages a Tensor list write."""
+ if opts.element_dtype is None:
+ raise ValueError('cannot stack a list without knowing its element type;'
+ ' use set_element_type to annotate it')
+ return list_ops.tensor_list_stack(list_, element_dtype=opts.element_dtype)
-def _py_append(target, element):
- """Overload of append that executes a Python list append."""
- target.append(element)
- return target
+def _py_list_stack(list_, opts):
+ """Overload of list_stack that executes a Python list append."""
+ # Revert to the original call.
+ return opts.original_call(list_)
diff --git a/tensorflow/contrib/autograph/operators/data_structures_test.py b/tensorflow/contrib/autograph/operators/data_structures_test.py
index 577d28c34d..8bbb52d6c1 100644
--- a/tensorflow/contrib/autograph/operators/data_structures_test.py
+++ b/tensorflow/contrib/autograph/operators/data_structures_test.py
@@ -19,25 +19,98 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.autograph.operators import data_structures
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import list_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
-class AppendTest(test.TestCase):
+class ListTest(test.TestCase):
- def test_tf_tensorarray(self):
+ def test_new_list_empty(self):
+ l = data_structures.new_list()
+ # Can't evaluate an empty list.
+ # TODO(mdan): sess.run should allow tf.variant maybe?
+ self.assertTrue(isinstance(l, ops.Tensor))
+
+ def test_new_list_tensor(self):
+ l = data_structures.new_list([3, 4, 5])
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4, 5])
+
+ def test_append_tensor_list(self):
+ l = data_structures.new_list()
+ x = constant_op.constant([1, 2, 3])
+ l = data_structures.list_append(l, x)
+
+ t = list_ops.tensor_list_stack(l, element_dtype=x.dtype)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [[1, 2, 3]])
+
+ def test_append_tensorarray(self):
l = tensor_array_ops.TensorArray(dtypes.int32, size=0, dynamic_size=True)
- l1 = data_structures.append(l, 1)
- l2 = data_structures.append(l1, 2)
+ l1 = data_structures.list_append(l, 1)
+ l2 = data_structures.list_append(l1, 2)
with self.test_session() as sess:
self.assertAllEqual(sess.run(l1.stack()), [1])
self.assertAllEqual(sess.run(l2.stack()), [1, 2])
- def test_python(self):
+ def test_append_python(self):
l = []
- self.assertAllEqual(data_structures.append(l, 1), [1])
- self.assertAllEqual(data_structures.append(l, 2), [1, 2])
+ self.assertAllEqual(data_structures.list_append(l, 1), [1])
+ self.assertAllEqual(data_structures.list_append(l, 2), [1, 2])
+
+ def test_pop_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+ opts = data_structures.ListPopOpts(
+ element_dtype=initial_list.dtype,
+ element_shape=(2,))
+
+ with self.assertRaises(NotImplementedError):
+ data_structures.list_pop(l, 0, opts)
+
+ with self.test_session() as sess:
+ l, x = data_structures.list_pop(l, None, opts)
+ self.assertAllEqual(sess.run(x), [3, 4])
+
+ t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+ self.assertAllEqual(sess.run(t), [[1, 2]])
+
+ def test_pop_python(self):
+ l = [1, 2, 3]
+ opts = data_structures.ListPopOpts(element_dtype=None, element_shape=())
+ self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1, 2], 3))
+ self.assertAllEqual(data_structures.list_pop(l, None, opts), ([1], 2))
+
+ def test_stack_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=initial_list.dtype, original_call=None)
+
+ with self.test_session() as sess:
+ t = data_structures.list_stack(l, opts)
+ self.assertAllEqual(sess.run(t), sess.run(initial_list))
+
+ def test_stack_fallback(self):
+
+ def dummy_function(l):
+ # Lazy person's mock: just transform the argument in a way in which we
+ # can check that this function was indeed called.
+ return [x * 2 for x in l]
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=None, original_call=dummy_function)
+
+ self.assertAllEqual(data_structures.list_stack([1, 2], opts), [2, 4])
if __name__ == '__main__':
diff --git a/tensorflow/contrib/autograph/operators/slices.py b/tensorflow/contrib/autograph/operators/slices.py
new file mode 100644
index 0000000000..04fbeb2f6e
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/slices.py
@@ -0,0 +1,133 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators specific to slicing operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import list_ops
+from tensorflow.python.ops import tensor_array_ops
+
+
+# TODO(mdan): Support extended slices.
+
+
+class GetItemOpts(collections.namedtuple('GetItemOpts', ('element_dtype',))):
+ pass
+
+
+def get_item(target, i, opts):
+ """The slice read operator (i.e. __getitem__).
+
+ Note: it is unspecified whether target will be mutated or not. In general,
+ if target is mutable (like Python lists), it will be mutated.
+
+ Args:
+ target: An entity that supports getitem semantics.
+ i: Index to read from.
+ opts: A GetItemOpts object.
+
+ Returns:
+ The read element.
+
+ Raises:
+ ValueError: if target is not of a supported type.
+ """
+ assert isinstance(opts, GetItemOpts)
+
+ if isinstance(target, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_get_item(target, i)
+ elif tensor_util.is_tensor(target):
+ if target.dtype == dtypes.variant:
+ return _tf_tensor_list_get_item(target, i, opts)
+ else:
+ return _tf_tensor_get_item(target, i)
+ else:
+ return _py_get_item(target, i)
+
+
+def _tf_tensorarray_get_item(target, i):
+ """Overload of get_item that stages a TensorArray read."""
+ return target.read(i)
+
+
+def _tf_tensor_list_get_item(target, i, opts):
+ """Overload of get_item that stages a Tensor list read."""
+ if opts.element_dtype is None:
+ raise ValueError('cannot retrieve from a list without knowing its '
+ 'element type; use set_element_type to annotate it')
+ x = list_ops.tensor_list_get_item(target, i, element_dtype=opts.element_dtype)
+ return x
+
+
+def _tf_tensor_get_item(target, i):
+ """Overload of get_item that stages a Tensor (not Tensor list) read."""
+ return target[i]
+
+
+def _py_get_item(target, i):
+ """Overload of get_item that executes a Python list modification."""
+ return target[i]
+
+
+def set_item(target, i, x):
+ """The slice write operator (i.e. __setitem__).
+
+ Note: it is unspecified whether target will be mutated or not. In general,
+ if target is mutable (like Python lists), it will be mutated.
+
+ Args:
+ target: An entity that supports setitem semantics.
+ i: Index to modify.
+ x: The new element value.
+
+ Returns:
+ Same as target, after the update was performed.
+
+ Raises:
+ ValueError: if target is not of a supported type.
+ """
+ if isinstance(target, tensor_array_ops.TensorArray):
+ return _tf_tensorarray_set_item(target, i, x)
+ elif tensor_util.is_tensor(target):
+ if target.dtype == dtypes.variant:
+ return _tf_tensor_list_set_item(target, i, x)
+ else:
+ raise ValueError(
+ 'tensor lists are expected to be Tensors with dtype=tf.variant,'
+ ' instead found %s' % target)
+ else:
+ return _py_set_item(target, i, x)
+
+
+def _tf_tensorarray_set_item(target, i, x):
+ """Overload of set_item that stages a TensorArray write."""
+ return target.write(i, x)
+
+
+def _tf_tensor_list_set_item(target, i, x):
+ """Overload of set_item that stages a Tensor list update."""
+ return list_ops.tensor_list_set_item(target, i, x)
+
+
+def _py_set_item(target, i, x):
+ """Overload of set_item that executes a Python list modification."""
+ target[i] = x
+ return target
diff --git a/tensorflow/contrib/autograph/operators/slices_test.py b/tensorflow/contrib/autograph/operators/slices_test.py
new file mode 100644
index 0000000000..d4aacb9d20
--- /dev/null
+++ b/tensorflow/contrib/autograph/operators/slices_test.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slices module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.autograph.operators import slices
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import list_ops
+from tensorflow.python.platform import test
+
+
+class SlicesTest(test.TestCase):
+
+ def test_set_item_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+ l = slices.set_item(l, 0, [5, 6])
+
+ with self.test_session() as sess:
+ t = list_ops.tensor_list_stack(l, element_dtype=initial_list.dtype)
+ self.assertAllEqual(sess.run(t), [[5, 6], [3, 4]])
+
+ def test_get_item_tensor_list(self):
+ initial_list = constant_op.constant([[1, 2], [3, 4]])
+ elem_shape = constant_op.constant([2])
+ l = list_ops.tensor_list_from_tensor(initial_list, element_shape=elem_shape)
+ t = slices.get_item(
+ l, 1, slices.GetItemOpts(element_dtype=initial_list.dtype))
+
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4])
+
+
+if __name__ == '__main__':
+ test.main()