aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-14 23:45:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-15 00:04:59 -0800
commitc5dc750ba9fab7e7f1f05ee0e0cdb04ae96e0e32 (patch)
tree937edf17553f8d1f24abaf683dc83b10e7e730f4
parent3bb102941e638617894facca6859b444154f8c2b (diff)
Switch array_ops.pack/unpack to array_ops.stack/unstack. Also switch a few remaining references to tf.pack/unpack to tf.stack/unstack.
Change: 142108785
-rw-r--r--tensorflow/contrib/distributions/python/ops/distribution_util.py7
-rw-r--r--tensorflow/contrib/distributions/python/ops/mixture.py2
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn.py4
-rw-r--r--tensorflow/contrib/integrate/__init__.py4
-rw-r--r--tensorflow/contrib/integrate/python/ops/odes_test.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/embedding_ops.py2
-rw-r--r--tensorflow/contrib/layers/python/layers/feature_column.py4
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py2
-rw-r--r--tensorflow/contrib/layers/python/ops/sparse_ops.py2
-rw-r--r--tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py2
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_composition.py7
-rw-r--r--tensorflow/contrib/metrics/python/ops/metric_ops.py4
-rw-r--r--tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py4
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py6
-rw-r--r--tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py4
-rw-r--r--tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py2
-rw-r--r--tensorflow/contrib/tensor_forest/python/tensor_forest.py6
-rw-r--r--tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py3
-rw-r--r--tensorflow/python/layers/convolutional.py2
-rw-r--r--tensorflow/python/layers/core.py2
-rw-r--r--tensorflow/python/ops/array_grad.py17
-rw-r--r--tensorflow/python/ops/clip_ops.py2
-rw-r--r--tensorflow/python/ops/confusion_matrix.py4
-rw-r--r--tensorflow/python/ops/embedding_ops.py2
-rw-r--r--tensorflow/python/ops/gradients_impl.py2
-rw-r--r--tensorflow/python/ops/gradients_test.py14
-rw-r--r--tensorflow/python/ops/image_ops_impl.py27
-rw-r--r--tensorflow/python/ops/image_ops_test.py4
-rw-r--r--tensorflow/python/ops/metrics.py2
-rw-r--r--tensorflow/python/ops/nn_grad.py2
-rw-r--r--tensorflow/python/ops/nn_impl.py6
-rw-r--r--tensorflow/python/ops/nn_ops.py4
-rw-r--r--tensorflow/python/ops/resources.py4
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py9
-rw-r--r--tensorflow/python/ops/sparse_ops.py4
-rw-r--r--tensorflow/python/ops/variables.py7
36 files changed, 94 insertions, 89 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py
index 0b00d16e60..671aa6a513 100644
--- a/tensorflow/contrib/distributions/python/ops/distribution_util.py
+++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py
@@ -549,9 +549,10 @@ def fill_lower_triangular(x, validate_args=False, name="fill_lower_triangular"):
batch_ids = math_ops.range(m)
# Assemble the tril_ids into batch,tril_id pairs.
- idx = array_ops.pack([
- array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n*n]),
- array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])])
+ idx = array_ops.stack([
+ array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]),
+ array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])
+ ])
idx = array_ops.transpose(idx, [1, 2, 0])
# Gather up, reshape, and return.
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index a0d61ef558..63d67fe67f 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -232,7 +232,7 @@ class Mixture(distribution.Distribution):
cat_lp + d_lp
for (cat_lp, d_lp) in zip(cat_log_probs, distribution_log_probs)
]
- concat_log_probs = array_ops.pack(final_log_probs, 0)
+ concat_log_probs = array_ops.stack(final_log_probs, 0)
log_sum_exp = math_ops.reduce_logsumexp(concat_log_probs, [0])
return log_sum_exp
diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py
index 4341acbba1..0595ca89d8 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn.py
@@ -222,7 +222,7 @@ class _MultivariateNormalOperatorPD(distribution.Distribution):
return self._cov.get_batch_shape()
def _event_shape(self):
- return array_ops.pack([self._cov.vector_space_dimension()])
+ return array_ops.stack([self._cov.vector_space_dimension()])
def _get_event_shape(self):
return self._cov.get_shape()[-1:]
@@ -240,7 +240,7 @@ class _MultivariateNormalOperatorPD(distribution.Distribution):
# Move the last dimension to the front
perm = array_ops.concat_v2(
- (array_ops.pack([array_ops.rank(correlated_samples) - 1]),
+ (array_ops.stack([array_ops.rank(correlated_samples) - 1]),
math_ops.range(0, array_ops.rank(correlated_samples) - 1)), 0)
# TODO(ebrevdo): Once we get a proper tensor contraction op,
diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py
index e88d10c582..953dc6c55a 100644
--- a/tensorflow/contrib/integrate/__init__.py
+++ b/tensorflow/contrib/integrate/__init__.py
@@ -27,11 +27,11 @@ sigma = 10.0
beta = 8.0/3.0
def lorenz_equation(state, t):
- x, y, z = tf.unpack(state)
+ x, y, z = tf.unstack(state)
dx = sigma * (y - x)
dy = x * (rho - z) - y
dz = x * y - beta * z
- return tf.pack([dx, dy, dz])
+ return tf.stack([dx, dy, dz])
init_state = tf.constant([0, 2, 20], dtype=tf.float64)
t = np.linspace(0, 50, num=5000)
diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py
index cb036bf05a..55d92fe9cf 100644
--- a/tensorflow/contrib/integrate/python/ops/odes_test.py
+++ b/tensorflow/contrib/integrate/python/ops/odes_test.py
@@ -214,8 +214,8 @@ class InterpolationTest(tf.test.TestCase):
coeffs = odes._interp_fit(
f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0)
times = np.linspace(0, 10, dtype=np.float32)
- y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t)
- for t in times])
+ y_fit = tf.stack(
+ [odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times])
y_expected = f(times)
with self.test_session() as sess:
y_actual = sess.run(y_fit)
diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py
index 490156806f..fd4eda4ec2 100644
--- a/tensorflow/contrib/layers/python/layers/embedding_ops.py
+++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py
@@ -149,7 +149,7 @@ def safe_embedding_lookup_sparse(embedding_weights,
# for use in Select.
is_row_empty = array_ops.tile(
array_ops.reshape(is_row_empty, [-1, 1]),
- array_ops.pack([1, array_ops.shape(result)[1]]))
+ array_ops.stack([1, array_ops.shape(result)[1]]))
result = array_ops.where(is_row_empty,
array_ops.zeros_like(result),
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 4a414eb1d1..7547da1080 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -1688,8 +1688,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple(
i2 = array_ops.zeros([batch_size], dtype=dtypes.int32, name="zeros")
bucket_indices = array_ops.reshape(input_tensor, [-1], name="reshape")
- indices = math_ops.to_int64(array_ops.transpose(array_ops.pack((i1, i2))))
- shape = math_ops.to_int64(array_ops.pack([batch_size, dimension]))
+ indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2))))
+ shape = math_ops.to_int64(array_ops.stack([batch_size, dimension]))
sparse_id_values = sparse_tensor_py.SparseTensor(
indices, bucket_indices, shape)
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 47100d7d05..4d54886a94 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -2113,7 +2113,7 @@ def legacy_fully_connected(x,
out_shape = array_ops.unpack(array_ops.shape(x))
out_shape[-1] = num_output_units
- y = array_ops.reshape(y, array_ops.pack(out_shape))
+ y = array_ops.reshape(y, array_ops.stack(out_shape))
static_shape = x.get_shape().as_list()
static_shape[-1] = num_output_units
diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops.py b/tensorflow/contrib/layers/python/ops/sparse_ops.py
index 325f5ac97b..f4d67b370a 100644
--- a/tensorflow/contrib/layers/python/ops/sparse_ops.py
+++ b/tensorflow/contrib/layers/python/ops/sparse_ops.py
@@ -73,7 +73,7 @@ def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
# Computes the correct flattened indices for 2d (or higher) tensors.
if index_dims > 1:
higher_dims = indices[:, :index_dims - 1]
- shape_multipliers = array_ops.pack(
+ shape_multipliers = array_ops.stack(
_multiplier_helper(array_ops.unpack(dense_shape)[1:]))
offsets = math_ops.reduce_sum(
math_ops.mul(higher_dims, shape_multipliers), reduction_indices=[1])
diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
index 0582028b88..d2018bb219 100644
--- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
+++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py
@@ -654,7 +654,7 @@ def attention_decoder(decoder_inputs,
outputs = []
prev = None
- batch_attn_size = array_ops.pack([batch_size, attn_size])
+ batch_attn_size = array_ops.stack([batch_size, attn_size])
attns = [
array_ops.zeros(
batch_attn_size, dtype=dtype) for _ in xrange(num_heads)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
index ddc4fa9588..94299b451d 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py
@@ -211,9 +211,10 @@ class LinearOperatorComposition(linear_operator.LinearOperator):
# Don't check the matrix dimensions. That would add unnecessary Asserts to
# the graph. Things will fail at runtime naturally if shapes are
# incompatible.
- matrix_shape = array_ops.pack(
- [self.operators[0].range_dimension_dynamic(),
- self.operators[-1].domain_dimension_dynamic()])
+ matrix_shape = array_ops.stack([
+ self.operators[0].range_dimension_dynamic(),
+ self.operators[-1].domain_dimension_dynamic()
+ ])
# Dummy Tensor of zeros. Will never be materialized.
zeros = array_ops.zeros(shape=self.operators[0].batch_shape_dynamic())
diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py
index 88036c87dc..9c5ffc293f 100644
--- a/tensorflow/contrib/metrics/python/ops/metric_ops.py
+++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py
@@ -620,7 +620,7 @@ def _streaming_confusion_matrix_at_thresholds(
num_predictions = array_ops.shape(predictions_2d)[0]
thresh_tiled = array_ops.tile(
array_ops.expand_dims(array_ops.constant(thresholds), [1]),
- array_ops.pack([1, num_predictions]))
+ array_ops.stack([1, num_predictions]))
# Tile the predictions after thresholding them across different thresholds.
pred_is_pos = math_ops.greater(
@@ -2605,7 +2605,7 @@ def streaming_concat(values,
def reallocate():
next_size = _next_array_size(new_size)
- next_shape = array_ops.pack([next_size] + fixed_shape)
+ next_shape = array_ops.stack([next_size] + fixed_shape)
new_value = array_ops.zeros(next_shape, dtype=values.dtype)
old_value = array.value()
assign_op = state_ops.assign(array, new_value, validate_shape=False)
diff --git a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py
index 4306b68e7d..465607ff51 100644
--- a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py
@@ -100,7 +100,7 @@ class FusedRNNCellAdaptor(FusedRNNCell):
is_list = isinstance(inputs, list)
if self._use_dynamic_rnn:
if is_list:
- inputs = array_ops.pack(inputs)
+ inputs = array_ops.stack(inputs)
outputs, state = rnn.dynamic_rnn(
self._cell,
inputs,
@@ -123,7 +123,7 @@ class FusedRNNCellAdaptor(FusedRNNCell):
scope=scope)
if not is_list:
# Convert outputs back to tensor
- outputs = array_ops.pack(outputs)
+ outputs = array_ops.stack(outputs)
return outputs, state
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 4ad269ab4f..875d5e860f 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -206,7 +206,7 @@ def _block_lstm(seq_len_max,
# pylint: disable=protected-access
i, cs, f, o, ci, co, h = _lstm_ops_so.block_lstm(
seq_len_max=seq_len_max,
- x=array_ops.pack(x),
+ x=array_ops.stack(x),
cs_prev=cs_prev,
h_prev=h_prev,
w=w,
@@ -480,7 +480,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
with vs.variable_scope(scope or "lstm_block_wrapper"):
is_list = isinstance(inputs, list)
if is_list:
- inputs = array_ops.pack(inputs)
+ inputs = array_ops.stack(inputs)
inputs_shape = inputs.get_shape().with_rank(3)
if not inputs_shape[2]:
raise ValueError("Expecting inputs_shape[2] to be set: %s" %
@@ -498,7 +498,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
raise ValueError(
"Either initial_state or dtype needs to be specified")
z = array_ops.zeros(
- array_ops.pack([batch_size, self.num_units]), dtype=dtype)
+ array_ops.stack([batch_size, self.num_units]), dtype=dtype)
initial_state = z, z
else:
if len(initial_state) != 2:
diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
index a2c71ae334..bf6f14e8f5 100644
--- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
+++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py
@@ -192,7 +192,7 @@ class Tensor(ItemHandler):
if isinstance(shape_dim, sparse_tensor.SparseTensor):
shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim)
shape_dims.append(shape_dim)
- shape = array_ops.reshape(array_ops.pack(shape_dims), [-1])
+ shape = array_ops.reshape(array_ops.stack(shape_dims), [-1])
if isinstance(tensor, sparse_tensor.SparseTensor):
if shape is not None:
tensor = sparse_ops.sparse_reshape(tensor, shape)
@@ -251,7 +251,7 @@ class SparseTensor(ItemHandler):
rank = indices_shape[1]
ids = math_ops.to_int64(indices.values)
indices_columns_to_preserve = array_ops.slice(
- indices.indices, [0, 0], array_ops.pack([-1, rank - 1]))
+ indices.indices, [0, 0], array_ops.stack([-1, rank - 1]))
new_indices = array_ops.concat_v2(
[indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)
diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py
index ea9b9ce6de..59bc3fbc22 100644
--- a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py
+++ b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py
@@ -67,7 +67,7 @@ class HybridModel(object):
# results.
if isinstance(layer, collections.Iterable):
return math_ops.reduce_mean(
- array_ops.pack([l.inference_graph(data) for l in layer]), 0)
+ array_ops.stack([l.inference_graph(data) for l in layer]), 0)
# If this is a single layer, return its inference result.
else:
return layer.inference_graph(data)
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 5dff5f5f31..23873ba3be 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -416,7 +416,7 @@ class RandomForestGraphs(object):
probabilities.append(self.trees[i].inference_graph(
tree_data, data_spec, **inference_args))
with ops.device(self.device_assigner.get_device(0)):
- all_predict = array_ops.pack(probabilities)
+ all_predict = array_ops.stack(probabilities)
return math_ops.div(
math_ops.reduce_sum(all_predict, 0), self.params.num_trees,
name='probabilities')
@@ -431,7 +431,7 @@ class RandomForestGraphs(object):
for i in range(self.params.num_trees):
with ops.device(self.device_assigner.get_device(i)):
sizes.append(self.trees[i].size())
- return math_ops.reduce_mean(math_ops.to_float(array_ops.pack(sizes)))
+ return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes)))
# pylint: disable=unused-argument
def training_loss(self, features, labels, data_spec=None,
@@ -452,7 +452,7 @@ class RandomForestGraphs(object):
for i in range(self.params.num_trees):
with ops.device(self.device_assigner.get_device(i)):
impurities.append(self.trees[i].average_impurity())
- return math_ops.reduce_mean(array_ops.pack(impurities))
+ return math_ops.reduce_mean(array_ops.stack(impurities))
def get_stats(self, session):
tree_stats = []
diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
index b978daed3e..83f71856f0 100644
--- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
+++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py
@@ -1164,7 +1164,8 @@ class SequenceQueueingStateSaver(object):
insert_initial_state_ops = dict(
(name, self._barrier.insert_many(
self._get_barrier_index("state", name),
- array_ops.pack([current_keys[0]]), array_ops.pack([value]),
+ array_ops.stack([current_keys[0]]),
+ array_ops.stack([value]),
name="BarrierInitialInsertState_%s" % name))
for (name, value) in self._uninitialized_states.items())
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 711ea278c6..bcc4565b1a 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -1084,7 +1084,7 @@ class Conv2DTranspose(Conv2D):
output_shape = (batch_size, out_height, out_width, self.filters)
strides = (1, stride_h, stride_w, 1)
- output_shape_tensor = array_ops.pack(output_shape)
+ output_shape_tensor = array_ops.stack(output_shape)
outputs = nn.conv2d_transpose(
inputs,
self.kernel,
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 1ccf83441e..5c74b869d9 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -137,7 +137,7 @@ class Dense(base._Layer): # pylint: disable=protected-access
# Reshape the input to 2D.
output_shape_tensors = array_ops.unpack(array_ops.shape(inputs))
output_shape_tensors[-1] = self.units
- output_shape_tensor = array_ops.pack(output_shape_tensors)
+ output_shape_tensor = array_ops.stack(output_shape_tensors)
inputs = array_ops.reshape(inputs, [-1, input_dim])
outputs = standard_ops.matmul(inputs, self.w)
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index f828490647..ad683730c7 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -40,7 +40,7 @@ def _PackGrad(op, grad):
@ops.RegisterGradient("Unpack")
def _UnpackGrad(op, *grads):
"""Gradient for unpack op."""
- return array_ops.pack(grads, axis=op.get_attr("axis"))
+ return array_ops.stack(grads, axis=op.get_attr("axis"))
def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
@@ -114,9 +114,10 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index):
# pylint: disable=protected-access
if len(sizes) > 16:
# extract the size of each input along the concat dimension
- sizes = array_ops.squeeze(array_ops.slice(array_ops.pack(sizes, axis=1),
- [concat_dim, 0],
- [1, -1]))
+ sizes = array_ops.squeeze(
+ array_ops.slice(
+ array_ops.stack(
+ sizes, axis=1), [concat_dim, 0], [1, -1]))
out_grads = array_ops.split(grad, sizes, concat_dim)
else:
offset = gen_array_ops._concat_offset(concat_dim, sizes)
@@ -206,7 +207,7 @@ def _SliceGrad(op, grad):
input_rank = array_ops.rank(input_vec)
slice_size = array_ops.shape(op.outputs[0])
- shape = array_ops.pack([input_rank, 1])
+ shape = array_ops.stack([input_rank, 1])
before_pad = array_ops.reshape(begin_vec, shape)
after_pad = array_ops.reshape(
array_ops.shape(input_vec) - slice_size - begin_vec, shape)
@@ -435,8 +436,8 @@ def _TileGrad(op, grad):
# multiples = [2, 3, 4]
# split_shape = [2, 20, 3, 30, 4, 40]
# axes = [0, 2, 4]
- split_shape = array_ops.reshape(array_ops.transpose(
- array_ops.pack([op.inputs[1], input_shape])), [-1])
+ split_shape = array_ops.reshape(
+ array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1])
axes = math_ops.range(0, array_ops.size(split_shape), 2)
input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes)
# Fix shape inference
@@ -456,7 +457,7 @@ def _PadGrad(op, grad):
a = op.inputs[1] # [Rank(x), 2]
# Takes a slice of a. The 1st column. [Rank(x), 1].
pad_before = array_ops.slice(a, [0, 0],
- array_ops.pack([array_ops.rank(x), 1]))
+ array_ops.stack([array_ops.rank(x), 1]))
# Make it a 1-D tensor.
begin = array_ops.reshape(pad_before, [-1])
sizes = array_ops.shape(x)
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index e0e2fb1304..bda7212c8a 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -142,7 +142,7 @@ def global_norm(t_list, name=None):
with ops.colocate_with(v):
half_squared_norms.append(gen_nn_ops.l2_loss(v))
- half_squared_norm = math_ops.reduce_sum(array_ops.pack(half_squared_norms))
+ half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms))
norm = math_ops.sqrt(
half_squared_norm *
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index 576b78b15f..a45397b3af 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -152,8 +152,8 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
predictions.get_shape().assert_is_compatible_with(weights.get_shape())
weights = math_ops.cast(weights, dtype)
- shape = array_ops.pack([num_classes, num_classes])
- indices = array_ops.transpose(array_ops.pack([predictions, labels]))
+ shape = array_ops.stack([num_classes, num_classes])
+ indices = array_ops.transpose(array_ops.stack([predictions, labels]))
values = (array_ops.ones_like(predictions, dtype)
if weights is None else weights)
cm_sparse = sparse_tensor.SparseTensor(
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index aae65b194b..95bba3efc5 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -137,7 +137,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
with ops.colocate_with(params[p]):
dim_0_sizes.append(array_ops.shape(params[p])[0])
num_total_ids = math_ops.reduce_sum(
- math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype))
+ math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype))
ids_per_partition = num_total_ids // np
extras = num_total_ids % np
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 4075da9900..4010b95372 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -891,5 +891,5 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False,
# Compute the partial derivatives with respect to each element of the list
_hess = [gradients(_gradient, x, **kwargs)[0] for _gradient in _gradients]
# Pack the list into a matrix and add to the list of hessians
- hessians.append(array_ops.pack(_hess, name=name))
+ hessians.append(array_ops.stack(_hess, name=name))
return hessians
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 8b4fbeb7a9..91b147a3c1 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -86,7 +86,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g:
t1 = constant(1.0)
t2 = constant(2.0)
- t3 = array_ops.pack([t1, t2])
+ t3 = array_ops.stack([t1, t2])
# Full graph
self._assertOpListEqual([t3.op, t2.op, t1.op],
_OpsBetween(g, [t3.op], [t1.op, t2.op]))
@@ -98,10 +98,10 @@ class GradientsTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g:
t1 = constant(1.0)
t2 = constant(2.0)
- _ = array_ops.pack([t1, t2])
+ _ = array_ops.stack([t1, t2])
t4 = constant(1.0)
t5 = constant(2.0)
- t6 = array_ops.pack([t4, t5])
+ t6 = array_ops.stack([t4, t5])
# Elements of to_ops are always listed.
self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op]))
@@ -109,7 +109,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g:
t1 = constant(1.0)
t2 = constant(2.0)
- t3 = array_ops.pack([t1, t2])
+ t3 = array_ops.stack([t1, t2])
t4 = constant([1.0])
t5 = array_ops.concat_v2([t4, t3], 0)
t6 = constant([2.0])
@@ -121,7 +121,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
with ops.Graph().as_default() as g:
t1 = constant(1.0)
t2 = constant(2.0)
- t3 = array_ops.pack([t1, t2])
+ t3 = array_ops.stack([t1, t2])
t4 = array_ops.concat_v2([t3, t3, t3], 0)
t5 = constant([1.0])
t6 = array_ops.concat_v2([t4, t5], 0)
@@ -491,8 +491,8 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
numpy_list.append(np_val)
dense_list.append(c)
sparse_list.append(c_sparse)
- packed_dense = array_ops.pack(dense_list)
- packed_sparse = array_ops.pack(sparse_list)
+ packed_dense = array_ops.stack(dense_list)
+ packed_sparse = array_ops.stack(sparse_list)
self.assertAllClose(packed_dense.eval(), packed_sparse.eval())
def testInt64Indices(self):
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 15694d4b3f..209cf5aa63 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -372,8 +372,8 @@ def central_crop(image, central_fraction):
bbox_h_size = img_shape[0] - bbox_h_start * 2
bbox_w_size = img_shape[1] - bbox_w_start * 2
- bbox_begin = array_ops.pack([bbox_h_start, bbox_w_start, 0])
- bbox_size = array_ops.pack([bbox_h_size, bbox_w_size, -1])
+ bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0])
+ bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1])
image = array_ops.slice(image, bbox_begin, bbox_size)
# The first two dimensions are dynamic and unknown.
@@ -428,10 +428,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height,
# Do not pad on the depth dimensions.
paddings = array_ops.reshape(
- array_ops.pack([offset_height, after_padding_height,
- offset_width, after_padding_width,
- 0, 0]),
- [3, 2])
+ array_ops.stack([
+ offset_height, after_padding_height, offset_width,
+ after_padding_width, 0, 0
+ ]), [3, 2])
padded = array_ops.pad(image, paddings)
padded_shape = [None if _is_tensor(i) else i
@@ -488,10 +488,9 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height,
'height must be >= target + offset.')
image = control_flow_ops.with_dependencies(assert_ops, image)
- cropped = array_ops.slice(
- image,
- array_ops.pack([offset_height, offset_width, 0]),
- array_ops.pack([target_height, target_width, -1]))
+ cropped = array_ops.slice(image,
+ array_ops.stack([offset_height, offset_width, 0]),
+ array_ops.stack([target_height, target_width, -1]))
cropped_shape = [None if _is_tensor(i) else i
for i in [target_height, target_width, depth]]
@@ -1229,7 +1228,7 @@ def decode_image(contents, channels=None, name=None):
JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF
images.
"""
- with ops.name_scope(name, 'decode_image') as scope:
+ with ops.name_scope(name, 'decode_image') as scope:
if channels not in (None, 0, 1, 3):
raise ValueError('channels must be in (None, 0, 1, 3)')
substr = string_ops.substr(contents, 0, 4)
@@ -1247,14 +1246,14 @@ def decode_image(contents, channels=None, name=None):
assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
with ops.control_dependencies([assert_decode, assert_channels]):
return gen_image_ops.decode_gif(contents)
-
+
def _png():
return gen_image_ops.decode_png(contents, channels)
-
+
def check_png():
is_png = math_ops.equal(substr, b'\211PNG', name='is_png')
return control_flow_ops.cond(is_png, _png, _gif, name='cond_png')
-
+
def _jpeg():
return gen_image_ops.decode_jpeg(contents, channels)
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index e568cff352..4dce46f29f 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -59,8 +59,8 @@ class RGBToHSVTest(test_util.TensorFlowTestCase):
split0 = array_ops.unpack(batch0)
split1 = list(map(image_ops.rgb_to_hsv, split0))
split2 = list(map(image_ops.hsv_to_rgb, split1))
- join1 = array_ops.pack(split1)
- join2 = array_ops.pack(split2)
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
# Verify that processing batch elements together is the same as separate
diff --git a/tensorflow/python/ops/metrics.py b/tensorflow/python/ops/metrics.py
index 04cd4ecbd4..938ae0d71b 100644
--- a/tensorflow/python/ops/metrics.py
+++ b/tensorflow/python/ops/metrics.py
@@ -444,7 +444,7 @@ def _confusion_matrix_at_thresholds(
num_predictions = array_ops.shape(predictions_2d)[0]
thresh_tiled = array_ops.tile(
array_ops.expand_dims(array_ops.constant(thresholds), [1]),
- array_ops.pack([1, num_predictions]))
+ array_ops.stack([1, num_predictions]))
# Tile the predictions after thresholding them across different thresholds.
pred_is_pos = math_ops.greater(
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 2bc05d6f07..d2a1619f46 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -542,7 +542,7 @@ def _TopKGrad(op, grad, _):
ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1)
# Flatten indices to 2D.
- ind_2d = array_ops.reshape(op.outputs[1], array_ops.pack([-1, ind_lastdim]))
+ ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))
in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1)
outerdim = array_ops.shape(ind_2d)[0]
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 31d9725ce9..650984195c 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -846,7 +846,7 @@ def _sum_rows(x):
# we use _sum_rows(x) in the nce_loss() computation since the loss
# is mostly used for training.
cols = array_ops.shape(x)[1]
- ones_shape = array_ops.pack([cols, 1])
+ ones_shape = array_ops.stack([cols, 1])
ones = array_ops.ones(ones_shape, x.dtype)
return array_ops.reshape(math_ops.matmul(x, ones), [-1])
@@ -942,7 +942,7 @@ def _compute_sampled_logits(weights,
# true_w shape is [batch_size * num_true, dim]
# true_b is a [batch_size * num_true] tensor
true_w = array_ops.slice(
- all_w, [0, 0], array_ops.pack([array_ops.shape(labels_flat)[0], -1]))
+ all_w, [0, 0], array_ops.stack([array_ops.shape(labels_flat)[0], -1]))
true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
# inputs shape is [batch_size, dim]
@@ -965,7 +965,7 @@ def _compute_sampled_logits(weights,
# sampled_w shape is [num_sampled, dim]
# sampled_b is a [num_sampled] float tensor
sampled_w = array_ops.slice(
- all_w, array_ops.pack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
+ all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
# inputs has shape [batch_size, dim]
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 7ccb4a7fa9..4f50e3e1bc 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -333,8 +333,8 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None, #
# convention as conv2d.
pad_extra_start = pad_extra_shape // 2
pad_extra_end = pad_extra_shape - pad_extra_start
- base_paddings = array_ops.pack([[pad_extra_start[i], pad_extra_end[i]]
- for i in range(num_spatial_dims)])
+ base_paddings = array_ops.stack([[pad_extra_start[i], pad_extra_end[i]]
+ for i in range(num_spatial_dims)])
elif padding == "VALID":
base_paddings = np.zeros([num_spatial_dims, 2], np.int32)
else:
diff --git a/tensorflow/python/ops/resources.py b/tensorflow/python/ops/resources.py
index 05369825dd..41fb8a74a9 100644
--- a/tensorflow/python/ops/resources.py
+++ b/tensorflow/python/ops/resources.py
@@ -89,8 +89,8 @@ def report_uninitialized_resources(resource_list=None,
# size being 0 as an indication of model ready.
return array_ops.constant([], dtype=dtypes.string)
# Get a 1-D boolean tensor listing whether each resource is initialized.
- variables_mask = math_ops.logical_not(array_ops.pack(
- [r.is_initialized for r in resource_list]))
+ variables_mask = math_ops.logical_not(
+ array_ops.stack([r.is_initialized for r in resource_list]))
# Get a 1-D string tensor containing all the resource names.
variable_names_tensor = array_ops.constant(
[s.handle.name for s in resource_list])
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index c448cbd315..cfa7d78e67 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -140,16 +140,17 @@ class RNNCell(object):
state_size_flat = nest.flatten(state_size)
zeros_flat = [
array_ops.zeros(
- array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])),
- dtype=dtype)
- for s in state_size_flat]
+ array_ops.stack(_state_size_with_prefix(
+ s, prefix=[batch_size])),
+ dtype=dtype) for s in state_size_flat
+ ]
for s, z in zip(state_size_flat, zeros_flat):
z.set_shape(_state_size_with_prefix(s, prefix=[None]))
zeros = nest.pack_sequence_as(structure=state_size,
flat_sequence=zeros_flat)
else:
zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
- zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype)
+ zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype)
zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))
return zeros
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index ea807b7d2b..dd00a52a32 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -855,7 +855,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
# Slice off the last dimension of indices, then tack on the ids
indices_columns_to_preserve = array_ops.slice(
- sp_ids.indices, [0, 0], array_ops.pack([-1, rank - 1]))
+ sp_ids.indices, [0, 0], array_ops.stack([-1, rank - 1]))
new_indices = array_ops.concat_v2(
[indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1)
@@ -863,7 +863,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
new_shape = array_ops.concat_v2([
array_ops.slice(sp_ids.dense_shape, [0],
array_ops.expand_dims(rank - 1, 0)),
- math_ops.cast(array_ops.pack([vocab_size]), dtypes.int64)
+ math_ops.cast(array_ops.stack([vocab_size]), dtypes.int64)
], 0)
result = sparse_tensor.SparseTensor(new_indices, new_values, new_shape)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 32545e4eb3..ec45329cb0 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1300,7 +1300,7 @@ def assert_variables_initialized(var_list=None):
if len(ranks) == 1:
return ranks[0]
else:
- return array_ops.pack(ranks)
+ return array_ops.stack(ranks)
def report_uninitialized_variables(var_list=None,
@@ -1334,8 +1334,9 @@ def report_uninitialized_variables(var_list=None,
return array_ops.constant([], dtype=dtypes.string)
else:
# Get a 1-D boolean tensor listing whether each variable is initialized.
- variables_mask = math_ops.logical_not(array_ops.pack(
- [state_ops.is_variable_initialized(v) for v in var_list]))
+ variables_mask = math_ops.logical_not(
+ array_ops.stack(
+ [state_ops.is_variable_initialized(v) for v in var_list]))
# Get a 1-D string tensor containing all the variable names.
variable_names_tensor = array_ops.constant([s.op.name for s in var_list])
# Return a 1-D tensor containing all the names of uninitialized variables.