author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-01-10 14:47:39 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-01-10 15:08:16 -0800
commit    7ff06bc92905013e99df68d18d24de5114c5d32b (patch)
tree      2a2151ec347d766a9845daaabf30c1f2889ef492
parent    e3af7585c04c3c6c875ba6def1c3039c67ae8be4 (diff)
Remove all remaining tf.pack/tf.unpack references and remove the tf.pack and tf.unpack ops.

Change: 144130931
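For readers migrating code, a minimal sketch of the rename this commit completes (values are illustrative, taken from the removed docstring's example; written against the 1.x-era API current at the time):

```python
import tensorflow as tf  # TF 1.x-era API, matching this commit

x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])

# Before this commit: tf.pack([x, y, z]) / tf.unpack(s),
# deprecated since 2016-12-14 and removed here.
s = tf.stack([x, y, z])   # shape (3, 2): [[1, 4], [2, 5], [3, 6]]
u = tf.unstack(s)         # list of three rank-1 tensors; the inverse of stack

with tf.Session() as sess:
    print(sess.run(s))
    print(sess.run(u))
```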
-rw-r--r--  tensorflow/compiler/tests/unary_ops_test.py                                                                       |  2
-rw-r--r--  tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py                                                          |  8
-rw-r--r--  tensorflow/python/kernel_tests/BUILD                                                                              |  8
-rw-r--r--  tensorflow/python/kernel_tests/stack_op_test.py (renamed from tensorflow/python/kernel_tests/pack_op_test.py)     | 39
-rw-r--r--  tensorflow/python/kernel_tests/unstack_op_test.py (renamed from tensorflow/python/kernel_tests/unpack_op_test.py) | 49
-rw-r--r--  tensorflow/python/ops/array_ops.py                                                                                | 95
6 files changed, 23 insertions(+), 178 deletions(-)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 33e0424e60..35f4b1071a 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -324,7 +324,7 @@ class UnaryOpsTest(XLATestCase):
def testUnpack(self):
self._testUnary(
- array_ops.unpack,
+ array_ops.unstack,
np.array([[1., 2.], [3., 4.], [5., 6.]], dtype=np.float32),
expected=[
np.array([1., 2.], dtype=np.float32),
diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py
index 93f44f074a..356bfda30e 100644
--- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py
+++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py
@@ -55,7 +55,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None):
predictions.append(nn.softmax(pred))
xent = math_ops.add_n(xent_list, name="sequence_loss/xent")
loss = math_ops.reduce_sum(xent, name="sequence_loss")
- return array_ops_.pack(predictions, axis=1), loss
+ return array_ops_.stack(predictions, axis=1), loss
def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None):
@@ -75,11 +75,11 @@ def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None):
Encoder input from x, and decoder inputs and outputs from y.
"""
with ops.name_scope(name, "seq2seq_inputs", [x, y]):
- in_x = array_ops_.unpack(x, axis=1)
- y = array_ops_.unpack(y, axis=1)
+ in_x = array_ops_.unstack(x, axis=1)
+ y = array_ops_.unstack(y, axis=1)
if not sentinel:
# Set to zeros of shape of y[0], using x for batch size.
- sentinel_shape = array_ops_.pack(
+ sentinel_shape = array_ops_.stack(
[array_ops_.shape(x)[0], y[0].get_shape()[1]])
sentinel = array_ops_.zeros(sentinel_shape)
sentinel.set_shape(y[0].get_shape())
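For context, the axis=1 pattern used in seq2seq_inputs above looks like this in isolation (a minimal sketch; the batch/step/feature sizes are hypothetical):

```python
import tensorflow as tf  # TF 1.x-era API

# Hypothetical batch: 4 sequences, 5 time steps, 3 features each.
y = tf.placeholder(tf.float32, shape=[4, 5, 3])

steps = tf.unstack(y, axis=1)      # list of 5 tensors, each of shape (4, 3)
assert len(steps) == 5

y_again = tf.stack(steps, axis=1)  # stack(..., axis=1) restores the (4, 5, 3) layout
```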
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index a51cb3fe5d..dd9e748035 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1417,9 +1417,9 @@ cuda_py_test(
)
cuda_py_test(
- name = "pack_op_test",
+ name = "stack_op_test",
size = "small",
- srcs = ["pack_op_test.py"],
+ srcs = ["stack_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -1880,9 +1880,9 @@ cuda_py_test(
)
cuda_py_test(
- name = "unpack_op_test",
+ name = "unstack_op_test",
size = "small",
- srcs = ["unpack_op_test.py"],
+ srcs = ["unstack_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
diff --git a/tensorflow/python/kernel_tests/pack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py
index e22e7f1895..afc0c38cac 100644
--- a/tensorflow/python/kernel_tests/pack_op_test.py
+++ b/tensorflow/python/kernel_tests/stack_op_test.py
@@ -39,7 +39,7 @@ def np_split_squeeze(array, axis):
]
-class PackOpTest(test.TestCase):
+class StackOpTest(test.TestCase):
def testSimple(self):
np.random.seed(7)
@@ -50,9 +50,6 @@ class PackOpTest(test.TestCase):
# TODO(irving): Remove list() once we handle maps correctly
xs = list(map(constant_op.constant, data))
# Pack back into a single tensorflow tensor
- c = array_ops.pack(xs)
- self.assertAllEqual(c.eval(), data)
-
c = array_ops.stack(xs)
self.assertAllEqual(c.eval(), data)
@@ -65,7 +62,7 @@ class PackOpTest(test.TestCase):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
data = np.random.randn(*shape).astype(np.float32)
# Pack back into a single tensorflow tensor directly using np array
- c = array_ops.pack(data)
+ c = array_ops.stack(data)
# This is implemented via a Const:
self.assertEqual(c.op.type, "Const")
self.assertAllEqual(c.eval(), data)
@@ -76,10 +73,6 @@ class PackOpTest(test.TestCase):
# Python lists also work for 1-D case:
if len(shape) == 1:
data_list = list(data)
- cl = array_ops.pack(data_list)
- self.assertEqual(cl.op.type, "Const")
- self.assertAllEqual(cl.eval(), data)
-
cl = array_ops.stack(data_list)
self.assertEqual(cl.op.type, "Const")
self.assertAllEqual(cl.eval(), data)
@@ -87,11 +80,8 @@ class PackOpTest(test.TestCase):
cl = array_ops.parallel_stack(data_list)
self.assertAllEqual(cl.eval(), data)
- # Verify that shape induction works with shapes produced via const pack
+ # Verify that shape induction works with shapes produced via const stack
a = constant_op.constant([1, 2, 3, 4, 5, 6])
- b = array_ops.reshape(a, array_ops.pack([2, 3]))
- self.assertAllEqual(b.get_shape(), [2, 3])
-
b = array_ops.reshape(a, array_ops.stack([2, 3]))
self.assertAllEqual(b.get_shape(), [2, 3])
@@ -103,10 +93,6 @@ class PackOpTest(test.TestCase):
with self.test_session(use_gpu=True):
# TODO(irving): Remove list() once we handle maps correctly
xs = list(map(constant_op.constant, data))
- c = array_ops.pack(xs)
- err = gradient_checker.compute_gradient_error(xs, shapes, c, shape)
- self.assertLess(err, 1e-6)
-
c = array_ops.stack(xs)
err = gradient_checker.compute_gradient_error(xs, shapes, c, shape)
self.assertLess(err, 1e-6)
@@ -121,22 +107,15 @@ class PackOpTest(test.TestCase):
with self.test_session(use_gpu=True):
# TODO(irving): Remove list() once we handle maps correctly
xs = list(map(constant_op.constant, data))
- c = array_ops.pack(xs, axis=1)
- err = gradient_checker.compute_gradient_error(xs, shapes, c, out_shape)
- self.assertLess(err, 1e-6)
-
c = array_ops.stack(xs, axis=1)
err = gradient_checker.compute_gradient_error(xs, shapes, c, out_shape)
self.assertLess(err, 1e-6)
def testZeroSize(self):
- # Verify that pack doesn't crash for zero size inputs
+ # Verify that stack doesn't crash for zero size inputs
with self.test_session(use_gpu=True):
for shape in (0,), (3, 0), (0, 3):
x = np.zeros((2,) + shape).astype(np.int32)
- p = array_ops.pack(list(x)).eval()
- self.assertAllEqual(p, x)
-
p = array_ops.stack(list(x)).eval()
self.assertAllEqual(p, x)
@@ -146,12 +125,9 @@ class PackOpTest(test.TestCase):
def testAxis0Default(self):
with self.test_session(use_gpu=True):
t = [constant_op.constant([1, 2, 3]), constant_op.constant([4, 5, 6])]
-
- packed = array_ops.pack(t).eval()
stacked = array_ops.stack(t).eval()
parallel_stacked = array_ops.parallel_stack(t).eval()
- self.assertAllEqual(packed, np.array([[1, 2, 3], [4, 5, 6]]))
self.assertAllEqual(stacked, np.array([[1, 2, 3], [4, 5, 6]]))
self.assertAllEqual(parallel_stacked, np.array([[1, 2, 3], [4, 5, 6]]))
@@ -165,7 +141,7 @@ class PackOpTest(test.TestCase):
test_arrays = np_split_squeeze(expected, j)
with self.test_session(use_gpu=True):
- actual_pack = array_ops.pack(test_arrays, axis=j)
+ actual_pack = array_ops.stack(test_arrays, axis=j)
self.assertEqual(expected.shape, actual_pack.get_shape())
actual_pack = actual_pack.eval()
@@ -173,21 +149,16 @@ class PackOpTest(test.TestCase):
self.assertEqual(expected.shape, actual_stack.get_shape())
actual_stack = actual_stack.eval()
- self.assertNDArrayNear(expected, actual_pack, 1e-6)
self.assertNDArrayNear(expected, actual_stack, 1e-6)
def testDimOutOfRange(self):
t = [constant_op.constant([1, 2, 3]), constant_op.constant([4, 5, 6])]
with self.assertRaisesRegexp(ValueError, r"axis = 2 not in \[-2, 2\)"):
- array_ops.pack(t, axis=2)
- with self.assertRaisesRegexp(ValueError, r"axis = 2 not in \[-2, 2\)"):
array_ops.stack(t, axis=2)
def testDimOutOfNegativeRange(self):
t = [constant_op.constant([1, 2, 3]), constant_op.constant([4, 5, 6])]
with self.assertRaisesRegexp(ValueError, r"axis = -3 not in \[-2, 2\)"):
- array_ops.pack(t, axis=-3)
- with self.assertRaisesRegexp(ValueError, r"axis = -3 not in \[-2, 2\)"):
array_ops.stack(t, axis=-3)
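The two axis-range tests above exercise the validation that stack() inherits from pack(); a small sketch of the valid and invalid cases (values are illustrative, mirroring the test fixtures):

```python
import tensorflow as tf  # TF 1.x-era API

t = [tf.constant([1, 2, 3]), tf.constant([4, 5, 6])]

a = tf.stack(t)           # axis=0 by default -> shape (2, 3)
b = tf.stack(t, axis=1)   # -> shape (3, 2)
c = tf.stack(t, axis=-2)  # negative axes count from the end; same as axis=0

try:
    tf.stack(t, axis=2)   # rank-1 inputs admit axis in [-2, 2)
except ValueError as e:
    print(e)              # "axis = 2 not in [-2, 2)"
```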
diff --git a/tensorflow/python/kernel_tests/unpack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py
index 9ba7f1fe5f..c2dcff978a 100644
--- a/tensorflow/python/kernel_tests/unpack_op_test.py
+++ b/tensorflow/python/kernel_tests/unstack_op_test.py
@@ -36,7 +36,7 @@ def np_split_squeeze(array, axis):
]
-class UnpackOpTest(test.TestCase):
+class UnstackOpTest(test.TestCase):
def testSimple(self):
np.random.seed(7)
@@ -46,13 +46,11 @@ class UnpackOpTest(test.TestCase):
# Convert data to a single tensorflow tensor
x = constant_op.constant(data)
# Unpack into a list of tensors
- cs_unpacked = array_ops.unpack(x, num=shape[0])
- cs_unstacked = array_ops.unpack(x, num=shape[0])
- for cs in (cs_unpacked, cs_unstacked):
- self.assertEqual(type(cs), list)
- self.assertEqual(len(cs), shape[0])
- cs = [c.eval() for c in cs]
- self.assertAllEqual(cs, data)
+ cs = array_ops.unstack(x, num=shape[0])
+ self.assertEqual(type(cs), list)
+ self.assertEqual(len(cs), shape[0])
+ cs = [c.eval() for c in cs]
+ self.assertAllEqual(cs, data)
def testGradientsAxis0(self):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
@@ -61,11 +59,6 @@ class UnpackOpTest(test.TestCase):
for i in xrange(shape[0]):
with self.test_session(use_gpu=True):
x = constant_op.constant(data)
- cs = array_ops.unpack(x, num=shape[0])
- err = gradient_checker.compute_gradient_error(x, shape, cs[i],
- shapes[i])
- self.assertLess(err, 1e-6)
-
cs = array_ops.unstack(x, num=shape[0])
err = gradient_checker.compute_gradient_error(x, shape, cs[i],
shapes[i])
@@ -79,11 +72,6 @@ class UnpackOpTest(test.TestCase):
for i in xrange(shape[1]):
with self.test_session(use_gpu=True):
x = constant_op.constant(data)
- cs = array_ops.unpack(x, num=shape[1], axis=1)
- err = gradient_checker.compute_gradient_error(x, shape, cs[i],
- out_shape)
- self.assertLess(err, 1e-6)
-
cs = array_ops.unstack(x, num=shape[1], axis=1)
err = gradient_checker.compute_gradient_error(x, shape, cs[i],
out_shape)
@@ -93,10 +81,6 @@ class UnpackOpTest(test.TestCase):
with self.test_session():
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
x = array_ops.placeholder(np.float32, shape=shape)
- cs = array_ops.unpack(x)
- self.assertEqual(type(cs), list)
- self.assertEqual(len(cs), shape[0])
-
cs = array_ops.unstack(x)
self.assertEqual(type(cs), list)
self.assertEqual(len(cs), shape[0])
@@ -105,23 +89,16 @@ class UnpackOpTest(test.TestCase):
x = array_ops.placeholder(np.float32)
with self.assertRaisesRegexp(ValueError,
r'Cannot infer num from shape <unknown>'):
- array_ops.unpack(x)
- with self.assertRaisesRegexp(ValueError,
- r'Cannot infer num from shape <unknown>'):
array_ops.unstack(x)
def testUnknownShapeOkWithNum(self):
x = array_ops.placeholder(np.float32)
- array_ops.unpack(x, num=2)
array_ops.unstack(x, num=2)
def testCannotInferNumFromNoneShape(self):
x = array_ops.placeholder(np.float32, shape=(None,))
with self.assertRaisesRegexp(ValueError,
r'Cannot infer num from shape \(\?,\)'):
- array_ops.unpack(x)
- with self.assertRaisesRegexp(ValueError,
- r'Cannot infer num from shape \(\?,\)'):
array_ops.unstack(x)
def testAgainstNumpy(self):
@@ -134,22 +111,15 @@ class UnpackOpTest(test.TestCase):
expected = np_split_squeeze(a, j)
with self.test_session() as sess:
- actual_unpack = sess.run(array_ops.unpack(a, axis=j))
actual_unstack = sess.run(array_ops.unstack(a, axis=j))
- self.assertAllEqual(expected, actual_unpack)
self.assertAllEqual(expected, actual_unstack)
def testAxis0Default(self):
with self.test_session() as sess:
a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
-
- unpacked = sess.run(array_ops.unpack(a))
unstacked = sess.run(array_ops.unstack(a))
- self.assertEqual(len(unpacked), 2)
- self.assertAllEqual(unpacked[0], [1, 2, 3])
- self.assertAllEqual(unpacked[1], [4, 5, 6])
self.assertEqual(len(unstacked), 2)
self.assertAllEqual(unstacked[0], [1, 2, 3])
self.assertAllEqual(unstacked[1], [4, 5, 6])
@@ -157,23 +127,16 @@ class UnpackOpTest(test.TestCase):
def testAxisOutOfRange(self):
a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'):
- array_ops.unpack(a, axis=2)
- with self.assertRaisesRegexp(ValueError, r'axis = 2 not in \[-2, 2\)'):
array_ops.unstack(a, axis=2)
def testAxisOutOfNegativeRange(self):
a = constant_op.constant([[1, 2, 3], [4, 5, 6]], name='a')
with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'):
- array_ops.unpack(a, axis=-3)
- with self.assertRaisesRegexp(ValueError, r'axis = -3 not in \[-2, 2\)'):
array_ops.unstack(a, axis=-3)
def testZeroLengthDim(self):
with self.test_session():
x = array_ops.zeros(shape=(0, 1, 2))
- y = array_ops.unpack(x, axis=1)[0].eval()
- self.assertEqual(y.shape, (0, 2))
-
y = array_ops.unstack(x, axis=1)[0].eval()
self.assertEqual(y.shape, (0, 2))
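The num-inference tests above capture the one way unstack() can fail at graph-construction time; a minimal sketch (the placeholder shapes are illustrative):

```python
import numpy as np
import tensorflow as tf  # TF 1.x-era API

x = tf.placeholder(np.float32)  # shape fully unknown
try:
    tf.unstack(x)               # num cannot be inferred from the shape
except ValueError as e:
    print(e)                    # "Cannot infer num from shape <unknown>"

ys = tf.unstack(x, num=2)       # an explicit num sidesteps the inference
assert len(ys) == 2
```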
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index b18d6eb491..cea1dd1821 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -57,9 +57,7 @@ or join multiple tensors together.
@@concat_v2
@@stack
@@parallel_stack
-@@pack
@@unstack
-@@unpack
@@reverse_sequence
@@reverse
@@reverse_v2
@@ -492,7 +490,7 @@ def _SliceHelper(tensor, slice_spec, var=None):
shrink_axis_mask |= (1 << index)
index += 1
- # pack possibly involves no tensors, so we must use op_scope correct graph.
+ # stack possibly involves no tensors, so we must use op_scope to pick the correct graph.
with ops.name_scope(None, "strided_slice",
[tensor] + begin + end + strides) as name:
if begin:
@@ -1039,50 +1037,6 @@ def stack(values, axis=0, name="stack"):
return gen_array_ops._pack(values, axis=axis, name=name)
-@deprecated(
- "2016-12-14",
- "This op will be removed after the deprecation date. "
- "Please switch to tf.stack().")
-def pack(values, axis=0, name="pack"):
- """Packs a list of rank-`R` tensors into one rank-`(R+1)` tensor.
-
- Packs the list of tensors in `values` into a tensor with rank one higher than
- each tensor in `values`, by packing them along the `axis` dimension.
- Given a list of length `N` of tensors of shape `(A, B, C)`;
-
- if `axis == 0` then the `output` tensor will have the shape `(N, A, B, C)`.
- if `axis == 1` then the `output` tensor will have the shape `(A, N, B, C)`.
- Etc.
-
- For example:
-
- ```prettyprint
- # 'x' is [1, 4]
- # 'y' is [2, 5]
- # 'z' is [3, 6]
- pack([x, y, z]) => [[1, 4], [2, 5], [3, 6]] # Pack along first dim.
- pack([x, y, z], axis=1) => [[1, 2, 3], [4, 5, 6]]
- ```
-
- This is the opposite of unpack. The numpy equivalent is
-
- tf.pack([x, y, z]) = np.asarray([x, y, z])
-
- Args:
- values: A list of `Tensor` objects with the same shape and type.
- axis: An `int`. The axis to pack along. Defaults to the first dimension.
- Supports negative indexes.
- name: A name for this operation (optional).
-
- Returns:
- output: A packed `Tensor` with the same type as `values`.
-
- Raises:
- ValueError: If `axis` is out of the range [-(R+1), R+1).
- """
- return stack(values, axis, name)
-
-
# pylint: disable=invalid-name
def _autopacking_helper(list_or_tuple, dtype, name):
"""Converts the given list or tuple to a tensor by packing.
@@ -1220,49 +1174,6 @@ def unstack(value, num=None, axis=0, name="unstack"):
return gen_array_ops._unpack(value, num=num, axis=axis, name=name)
-@deprecated(
- "2016-12-14",
- "This op will be removed after the deprecation date. "
- "Please switch to tf.unstack().")
-def unpack(value, num=None, axis=0, name="unpack"):
- """Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
-
- Unpacks `num` tensors from `value` by chipping it along the `axis` dimension.
- If `num` is not specified (the default), it is inferred from `value`'s shape.
- If `value.shape[axis]` is not known, `ValueError` is raised.
-
- For example, given a tensor of shape `(A, B, C, D)`;
-
- If `axis == 0` then the i'th tensor in `output` is the slice
- `value[i, :, :, :]` and each tensor in `output` will have shape `(B, C, D)`.
- (Note that the dimension unpacked along is gone, unlike `split`).
-
- If `axis == 1` then the i'th tensor in `output` is the slice
- `value[:, i, :, :]` and each tensor in `output` will have shape `(A, C, D)`.
- Etc.
-
- This is the opposite of pack. The numpy equivalent is
-
- tf.unpack(x, n) = list(x)
-
- Args:
- value: A rank `R > 0` `Tensor` to be unpacked.
- num: An `int`. The length of the dimension `axis`. Automatically inferred
- if `None` (the default).
- axis: An `int`. The axis to unpack along. Defaults to the first
- dimension. Supports negative indexes.
- name: A name for the operation (optional).
-
- Returns:
- The list of `Tensor` objects unpacked from `value`.
-
- Raises:
- ValueError: If `num` is unspecified and cannot be inferred.
- ValueError: If `axis` is out of the range [-R, R).
- """
- return unstack(value, num, axis, name)
-
-
def concat_v2(values, axis, name="concat_v2"):
"""Concatenates tensors along one dimension.
@@ -1296,7 +1207,7 @@ def concat_v2(values, axis, name="concat_v2"):
tf.shape(tf.concat_v2([t3, t4], 1)) ==> [2, 6]
```
- Note: If you are concatenating along a new axis consider using pack.
+ Note: If you are concatenating along a new axis consider using stack.
E.g.
```python
@@ -1306,7 +1217,7 @@ def concat_v2(values, axis, name="concat_v2"):
can be rewritten as
```python
- tf.pack(tensors, axis=axis)
+ tf.stack(tensors, axis=axis)
```
Args: