diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-12-14 23:45:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-15 00:04:59 -0800 |
commit | c5dc750ba9fab7e7f1f05ee0e0cdb04ae96e0e32 (patch) | |
tree | 937edf17553f8d1f24abaf683dc83b10e7e730f4 | |
parent | 3bb102941e638617894facca6859b444154f8c2b (diff) |
Switch array_ops.pack/unpack to array_ops.stack/unstack. Also switch a few remaining references to tf.pack/unpack to tf.stack/unstack.
Change: 142108785
36 files changed, 94 insertions, 89 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 0b00d16e60..671aa6a513 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -549,9 +549,10 @@ def fill_lower_triangular(x, validate_args=False, name="fill_lower_triangular"): batch_ids = math_ops.range(m) # Assemble the tril_ids into batch,tril_id pairs. - idx = array_ops.pack([ - array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n*n]), - array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1])]) + idx = array_ops.stack([ + array_ops.tile(array_ops.expand_dims(batch_ids, 1), [1, n * n]), + array_ops.tile(array_ops.expand_dims(tril_ids(n), 0), [m, 1]) + ]) idx = array_ops.transpose(idx, [1, 2, 0]) # Gather up, reshape, and return. diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index a0d61ef558..63d67fe67f 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -232,7 +232,7 @@ class Mixture(distribution.Distribution): cat_lp + d_lp for (cat_lp, d_lp) in zip(cat_log_probs, distribution_log_probs) ] - concat_log_probs = array_ops.pack(final_log_probs, 0) + concat_log_probs = array_ops.stack(final_log_probs, 0) log_sum_exp = math_ops.reduce_logsumexp(concat_log_probs, [0]) return log_sum_exp diff --git a/tensorflow/contrib/distributions/python/ops/mvn.py b/tensorflow/contrib/distributions/python/ops/mvn.py index 4341acbba1..0595ca89d8 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn.py +++ b/tensorflow/contrib/distributions/python/ops/mvn.py @@ -222,7 +222,7 @@ class _MultivariateNormalOperatorPD(distribution.Distribution): return self._cov.get_batch_shape() def _event_shape(self): - return array_ops.pack([self._cov.vector_space_dimension()]) + return array_ops.stack([self._cov.vector_space_dimension()]) def _get_event_shape(self): return self._cov.get_shape()[-1:] @@ -240,7 +240,7 @@ class _MultivariateNormalOperatorPD(distribution.Distribution): # Move the last dimension to the front perm = array_ops.concat_v2( - (array_ops.pack([array_ops.rank(correlated_samples) - 1]), + (array_ops.stack([array_ops.rank(correlated_samples) - 1]), math_ops.range(0, array_ops.rank(correlated_samples) - 1)), 0) # TODO(ebrevdo): Once we get a proper tensor contraction op, diff --git a/tensorflow/contrib/integrate/__init__.py b/tensorflow/contrib/integrate/__init__.py index e88d10c582..953dc6c55a 100644 --- a/tensorflow/contrib/integrate/__init__.py +++ b/tensorflow/contrib/integrate/__init__.py @@ -27,11 +27,11 @@ sigma = 10.0 beta = 8.0/3.0 def lorenz_equation(state, t): - x, y, z = tf.unpack(state) + x, y, z = tf.unstack(state) dx = sigma * (y - x) dy = x * (rho - z) - y dz = x * y - beta * z - return tf.pack([dx, dy, dz]) + return tf.stack([dx, dy, dz]) init_state = tf.constant([0, 2, 20], dtype=tf.float64) t = np.linspace(0, 50, num=5000) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index cb036bf05a..55d92fe9cf 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -214,8 +214,8 @@ class InterpolationTest(tf.test.TestCase): coeffs = odes._interp_fit( f(0.0), f(10.0), f(5.0), f_prime(0.0), f_prime(10.0), 10.0) times = np.linspace(0, 10, dtype=np.float32) - y_fit = tf.pack([odes._interp_evaluate(coeffs, 0.0, 10.0, t) - for t in times]) + y_fit = tf.stack( + [odes._interp_evaluate(coeffs, 0.0, 10.0, t) for t in times]) y_expected = f(times) with self.test_session() as sess: y_actual = sess.run(y_fit) diff --git a/tensorflow/contrib/layers/python/layers/embedding_ops.py b/tensorflow/contrib/layers/python/layers/embedding_ops.py index 490156806f..fd4eda4ec2 100644 --- a/tensorflow/contrib/layers/python/layers/embedding_ops.py +++ b/tensorflow/contrib/layers/python/layers/embedding_ops.py @@ -149,7 +149,7 @@ def safe_embedding_lookup_sparse(embedding_weights, # for use in Select. is_row_empty = array_ops.tile( array_ops.reshape(is_row_empty, [-1, 1]), - array_ops.pack([1, array_ops.shape(result)[1]])) + array_ops.stack([1, array_ops.shape(result)[1]])) result = array_ops.where(is_row_empty, array_ops.zeros_like(result), diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 4a414eb1d1..7547da1080 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -1688,8 +1688,8 @@ class _BucketizedColumn(_FeatureColumn, collections.namedtuple( i2 = array_ops.zeros([batch_size], dtype=dtypes.int32, name="zeros") bucket_indices = array_ops.reshape(input_tensor, [-1], name="reshape") - indices = math_ops.to_int64(array_ops.transpose(array_ops.pack((i1, i2)))) - shape = math_ops.to_int64(array_ops.pack([batch_size, dimension])) + indices = math_ops.to_int64(array_ops.transpose(array_ops.stack((i1, i2)))) + shape = math_ops.to_int64(array_ops.stack([batch_size, dimension])) sparse_id_values = sparse_tensor_py.SparseTensor( indices, bucket_indices, shape) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 47100d7d05..4d54886a94 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -2113,7 +2113,7 @@ def legacy_fully_connected(x, out_shape = array_ops.unpack(array_ops.shape(x)) out_shape[-1] = num_output_units - y = array_ops.reshape(y, array_ops.pack(out_shape)) + y = array_ops.reshape(y, array_ops.stack(out_shape)) static_shape = x.get_shape().as_list() static_shape[-1] = num_output_units diff --git a/tensorflow/contrib/layers/python/ops/sparse_ops.py b/tensorflow/contrib/layers/python/ops/sparse_ops.py index 325f5ac97b..f4d67b370a 100644 --- a/tensorflow/contrib/layers/python/ops/sparse_ops.py +++ b/tensorflow/contrib/layers/python/ops/sparse_ops.py @@ -73,7 +73,7 @@ def dense_to_sparse_tensor(dense_tensor, ignore_value=None): # Computes the correct flattened indices for 2d (or higher) tensors. if index_dims > 1: higher_dims = indices[:, :index_dims - 1] - shape_multipliers = array_ops.pack( + shape_multipliers = array_ops.stack( _multiplier_helper(array_ops.unpack(dense_shape)[1:])) offsets = math_ops.reduce_sum( math_ops.mul(higher_dims, shape_multipliers), reduction_indices=[1]) diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index 0582028b88..d2018bb219 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -654,7 +654,7 @@ def attention_decoder(decoder_inputs, outputs = [] prev = None - batch_attn_size = array_ops.pack([batch_size, attn_size]) + batch_attn_size = array_ops.stack([batch_size, attn_size]) attns = [ array_ops.zeros( batch_attn_size, dtype=dtype) for _ in xrange(num_heads) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py index ddc4fa9588..94299b451d 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py @@ -211,9 +211,10 @@ class LinearOperatorComposition(linear_operator.LinearOperator): # Don't check the matrix dimensions. That would add unnecessary Asserts to # the graph. Things will fail at runtime naturally if shapes are # incompatible. - matrix_shape = array_ops.pack( - [self.operators[0].range_dimension_dynamic(), - self.operators[-1].domain_dimension_dynamic()]) + matrix_shape = array_ops.stack([ + self.operators[0].range_dimension_dynamic(), + self.operators[-1].domain_dimension_dynamic() + ]) # Dummy Tensor of zeros. Will never be materialized. zeros = array_ops.zeros(shape=self.operators[0].batch_shape_dynamic()) diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 88036c87dc..9c5ffc293f 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -620,7 +620,7 @@ def _streaming_confusion_matrix_at_thresholds( num_predictions = array_ops.shape(predictions_2d)[0] thresh_tiled = array_ops.tile( array_ops.expand_dims(array_ops.constant(thresholds), [1]), - array_ops.pack([1, num_predictions])) + array_ops.stack([1, num_predictions])) # Tile the predictions after thresholding them across different thresholds. pred_is_pos = math_ops.greater( @@ -2605,7 +2605,7 @@ def streaming_concat(values, def reallocate(): next_size = _next_array_size(new_size) - next_shape = array_ops.pack([next_size] + fixed_shape) + next_shape = array_ops.stack([next_size] + fixed_shape) new_value = array_ops.zeros(next_shape, dtype=values.dtype) old_value = array.value() assign_op = state_ops.assign(array, new_value, validate_shape=False) diff --git a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py index 4306b68e7d..465607ff51 100644 --- a/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/fused_rnn_cell.py @@ -100,7 +100,7 @@ class FusedRNNCellAdaptor(FusedRNNCell): is_list = isinstance(inputs, list) if self._use_dynamic_rnn: if is_list: - inputs = array_ops.pack(inputs) + inputs = array_ops.stack(inputs) outputs, state = rnn.dynamic_rnn( self._cell, inputs, @@ -123,7 +123,7 @@ class FusedRNNCellAdaptor(FusedRNNCell): scope=scope) if not is_list: # Convert outputs back to tensor - outputs = array_ops.pack(outputs) + outputs = array_ops.stack(outputs) return outputs, state diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 4ad269ab4f..875d5e860f 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -206,7 +206,7 @@ def _block_lstm(seq_len_max, # pylint: disable=protected-access i, cs, f, o, ci, co, h = _lstm_ops_so.block_lstm( seq_len_max=seq_len_max, - x=array_ops.pack(x), + x=array_ops.stack(x), cs_prev=cs_prev, h_prev=h_prev, w=w, @@ -480,7 +480,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell): with vs.variable_scope(scope or "lstm_block_wrapper"): is_list = isinstance(inputs, list) if is_list: - inputs = array_ops.pack(inputs) + inputs = array_ops.stack(inputs) inputs_shape = inputs.get_shape().with_rank(3) if not inputs_shape[2]: raise ValueError("Expecting inputs_shape[2] to be set: %s" % @@ -498,7 +498,7 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell): raise ValueError( "Either initial_state or dtype needs to be specified") z = array_ops.zeros( - array_ops.pack([batch_size, self.num_units]), dtype=dtype) + array_ops.stack([batch_size, self.num_units]), dtype=dtype) initial_state = z, z else: if len(initial_state) != 2: diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index a2c71ae334..bf6f14e8f5 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -192,7 +192,7 @@ class Tensor(ItemHandler): if isinstance(shape_dim, sparse_tensor.SparseTensor): shape_dim = sparse_ops.sparse_tensor_to_dense(shape_dim) shape_dims.append(shape_dim) - shape = array_ops.reshape(array_ops.pack(shape_dims), [-1]) + shape = array_ops.reshape(array_ops.stack(shape_dims), [-1]) if isinstance(tensor, sparse_tensor.SparseTensor): if shape is not None: tensor = sparse_ops.sparse_reshape(tensor, shape) @@ -251,7 +251,7 @@ class SparseTensor(ItemHandler): rank = indices_shape[1] ids = math_ops.to_int64(indices.values) indices_columns_to_preserve = array_ops.slice( - indices.indices, [0, 0], array_ops.pack([-1, rank - 1])) + indices.indices, [0, 0], array_ops.stack([-1, rank - 1])) new_indices = array_ops.concat_v2( [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1) diff --git a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py index ea9b9ce6de..59bc3fbc22 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py +++ b/tensorflow/contrib/tensor_forest/hybrid/python/hybrid_model.py @@ -67,7 +67,7 @@ class HybridModel(object): # results. if isinstance(layer, collections.Iterable): return math_ops.reduce_mean( - array_ops.pack([l.inference_graph(data) for l in layer]), 0) + array_ops.stack([l.inference_graph(data) for l in layer]), 0) # If this is a single layer, return its inference result. else: return layer.inference_graph(data) diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 5dff5f5f31..23873ba3be 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -416,7 +416,7 @@ class RandomForestGraphs(object): probabilities.append(self.trees[i].inference_graph( tree_data, data_spec, **inference_args)) with ops.device(self.device_assigner.get_device(0)): - all_predict = array_ops.pack(probabilities) + all_predict = array_ops.stack(probabilities) return math_ops.div( math_ops.reduce_sum(all_predict, 0), self.params.num_trees, name='probabilities') @@ -431,7 +431,7 @@ class RandomForestGraphs(object): for i in range(self.params.num_trees): with ops.device(self.device_assigner.get_device(i)): sizes.append(self.trees[i].size()) - return math_ops.reduce_mean(math_ops.to_float(array_ops.pack(sizes))) + return math_ops.reduce_mean(math_ops.to_float(array_ops.stack(sizes))) # pylint: disable=unused-argument def training_loss(self, features, labels, data_spec=None, @@ -452,7 +452,7 @@ class RandomForestGraphs(object): for i in range(self.params.num_trees): with ops.device(self.device_assigner.get_device(i)): impurities.append(self.trees[i].average_impurity()) - return math_ops.reduce_mean(array_ops.pack(impurities)) + return math_ops.reduce_mean(array_ops.stack(impurities)) def get_stats(self, session): tree_stats = [] diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index b978daed3e..83f71856f0 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -1164,7 +1164,8 @@ class SequenceQueueingStateSaver(object): insert_initial_state_ops = dict( (name, self._barrier.insert_many( self._get_barrier_index("state", name), - array_ops.pack([current_keys[0]]), array_ops.pack([value]), + array_ops.stack([current_keys[0]]), + array_ops.stack([value]), name="BarrierInitialInsertState_%s" % name)) for (name, value) in self._uninitialized_states.items()) diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 711ea278c6..bcc4565b1a 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -1084,7 +1084,7 @@ class Conv2DTranspose(Conv2D): output_shape = (batch_size, out_height, out_width, self.filters) strides = (1, stride_h, stride_w, 1) - output_shape_tensor = array_ops.pack(output_shape) + output_shape_tensor = array_ops.stack(output_shape) outputs = nn.conv2d_transpose( inputs, self.kernel, diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index 1ccf83441e..5c74b869d9 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -137,7 +137,7 @@ class Dense(base._Layer): # pylint: disable=protected-access # Reshape the input to 2D. output_shape_tensors = array_ops.unpack(array_ops.shape(inputs)) output_shape_tensors[-1] = self.units - output_shape_tensor = array_ops.pack(output_shape_tensors) + output_shape_tensor = array_ops.stack(output_shape_tensors) inputs = array_ops.reshape(inputs, [-1, input_dim]) outputs = standard_ops.matmul(inputs, self.w) diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index f828490647..ad683730c7 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -40,7 +40,7 @@ def _PackGrad(op, grad): @ops.RegisterGradient("Unpack") def _UnpackGrad(op, *grads): """Gradient for unpack op.""" - return array_ops.pack(grads, axis=op.get_attr("axis")) + return array_ops.stack(grads, axis=op.get_attr("axis")) def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): @@ -114,9 +114,10 @@ def _ConcatGradHelper(op, grad, start_value_index, end_value_index, dim_index): # pylint: disable=protected-access if len(sizes) > 16: # extract the size of each input along the concat dimension - sizes = array_ops.squeeze(array_ops.slice(array_ops.pack(sizes, axis=1), - [concat_dim, 0], - [1, -1])) + sizes = array_ops.squeeze( + array_ops.slice( + array_ops.stack( + sizes, axis=1), [concat_dim, 0], [1, -1])) out_grads = array_ops.split(grad, sizes, concat_dim) else: offset = gen_array_ops._concat_offset(concat_dim, sizes) @@ -206,7 +207,7 @@ def _SliceGrad(op, grad): input_rank = array_ops.rank(input_vec) slice_size = array_ops.shape(op.outputs[0]) - shape = array_ops.pack([input_rank, 1]) + shape = array_ops.stack([input_rank, 1]) before_pad = array_ops.reshape(begin_vec, shape) after_pad = array_ops.reshape( array_ops.shape(input_vec) - slice_size - begin_vec, shape) @@ -435,8 +436,8 @@ def _TileGrad(op, grad): # multiples = [2, 3, 4] # split_shape = [2, 20, 3, 30, 4, 40] # axes = [0, 2, 4] - split_shape = array_ops.reshape(array_ops.transpose( - array_ops.pack([op.inputs[1], input_shape])), [-1]) + split_shape = array_ops.reshape( + array_ops.transpose(array_ops.stack([op.inputs[1], input_shape])), [-1]) axes = math_ops.range(0, array_ops.size(split_shape), 2) input_grad = math_ops.reduce_sum(array_ops.reshape(grad, split_shape), axes) # Fix shape inference @@ -456,7 +457,7 @@ def _PadGrad(op, grad): a = op.inputs[1] # [Rank(x), 2] # Takes a slice of a. The 1st column. [Rank(x), 1]. pad_before = array_ops.slice(a, [0, 0], - array_ops.pack([array_ops.rank(x), 1])) + array_ops.stack([array_ops.rank(x), 1])) # Make it a 1-D tensor. begin = array_ops.reshape(pad_before, [-1]) sizes = array_ops.shape(x) diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index e0e2fb1304..bda7212c8a 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -142,7 +142,7 @@ def global_norm(t_list, name=None): with ops.colocate_with(v): half_squared_norms.append(gen_nn_ops.l2_loss(v)) - half_squared_norm = math_ops.reduce_sum(array_ops.pack(half_squared_norms)) + half_squared_norm = math_ops.reduce_sum(array_ops.stack(half_squared_norms)) norm = math_ops.sqrt( half_squared_norm * diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index 576b78b15f..a45397b3af 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -152,8 +152,8 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, predictions.get_shape().assert_is_compatible_with(weights.get_shape()) weights = math_ops.cast(weights, dtype) - shape = array_ops.pack([num_classes, num_classes]) - indices = array_ops.transpose(array_ops.pack([predictions, labels])) + shape = array_ops.stack([num_classes, num_classes]) + indices = array_ops.transpose(array_ops.stack([predictions, labels])) values = (array_ops.ones_like(predictions, dtype) if weights is None else weights) cm_sparse = sparse_tensor.SparseTensor( diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py index aae65b194b..95bba3efc5 100644 --- a/tensorflow/python/ops/embedding_ops.py +++ b/tensorflow/python/ops/embedding_ops.py @@ -137,7 +137,7 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None, with ops.colocate_with(params[p]): dim_0_sizes.append(array_ops.shape(params[p])[0]) num_total_ids = math_ops.reduce_sum( - math_ops.cast(array_ops.pack(dim_0_sizes), flat_ids.dtype)) + math_ops.cast(array_ops.stack(dim_0_sizes), flat_ids.dtype)) ids_per_partition = num_total_ids // np extras = num_total_ids % np diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 4075da9900..4010b95372 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -891,5 +891,5 @@ def hessians(ys, xs, name="hessians", colocate_gradients_with_ops=False, # Compute the partial derivatives with respect to each element of the list _hess = [gradients(_gradient, x, **kwargs)[0] for _gradient in _gradients] # Pack the list into a matrix and add to the list of hessians - hessians.append(array_ops.pack(_hess, name=name)) + hessians.append(array_ops.stack(_hess, name=name)) return hessians diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py index 8b4fbeb7a9..91b147a3c1 100644 --- a/tensorflow/python/ops/gradients_test.py +++ b/tensorflow/python/ops/gradients_test.py @@ -86,7 +86,7 @@ class GradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default() as g: t1 = constant(1.0) t2 = constant(2.0) - t3 = array_ops.pack([t1, t2]) + t3 = array_ops.stack([t1, t2]) # Full graph self._assertOpListEqual([t3.op, t2.op, t1.op], _OpsBetween(g, [t3.op], [t1.op, t2.op])) @@ -98,10 +98,10 @@ class GradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default() as g: t1 = constant(1.0) t2 = constant(2.0) - _ = array_ops.pack([t1, t2]) + _ = array_ops.stack([t1, t2]) t4 = constant(1.0) t5 = constant(2.0) - t6 = array_ops.pack([t4, t5]) + t6 = array_ops.stack([t4, t5]) # Elements of to_ops are always listed. self._assertOpListEqual([t6.op], _OpsBetween(g, [t6.op], [t1.op])) @@ -109,7 +109,7 @@ class GradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default() as g: t1 = constant(1.0) t2 = constant(2.0) - t3 = array_ops.pack([t1, t2]) + t3 = array_ops.stack([t1, t2]) t4 = constant([1.0]) t5 = array_ops.concat_v2([t4, t3], 0) t6 = constant([2.0]) @@ -121,7 +121,7 @@ class GradientsTest(test_util.TensorFlowTestCase): with ops.Graph().as_default() as g: t1 = constant(1.0) t2 = constant(2.0) - t3 = array_ops.pack([t1, t2]) + t3 = array_ops.stack([t1, t2]) t4 = array_ops.concat_v2([t3, t3, t3], 0) t5 = constant([1.0]) t6 = array_ops.concat_v2([t4, t5], 0) @@ -491,8 +491,8 @@ class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase): numpy_list.append(np_val) dense_list.append(c) sparse_list.append(c_sparse) - packed_dense = array_ops.pack(dense_list) - packed_sparse = array_ops.pack(sparse_list) + packed_dense = array_ops.stack(dense_list) + packed_sparse = array_ops.stack(sparse_list) self.assertAllClose(packed_dense.eval(), packed_sparse.eval()) def testInt64Indices(self): diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 15694d4b3f..209cf5aa63 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -372,8 +372,8 @@ def central_crop(image, central_fraction): bbox_h_size = img_shape[0] - bbox_h_start * 2 bbox_w_size = img_shape[1] - bbox_w_start * 2 - bbox_begin = array_ops.pack([bbox_h_start, bbox_w_start, 0]) - bbox_size = array_ops.pack([bbox_h_size, bbox_w_size, -1]) + bbox_begin = array_ops.stack([bbox_h_start, bbox_w_start, 0]) + bbox_size = array_ops.stack([bbox_h_size, bbox_w_size, -1]) image = array_ops.slice(image, bbox_begin, bbox_size) # The first two dimensions are dynamic and unknown. @@ -428,10 +428,10 @@ def pad_to_bounding_box(image, offset_height, offset_width, target_height, # Do not pad on the depth dimensions. paddings = array_ops.reshape( - array_ops.pack([offset_height, after_padding_height, - offset_width, after_padding_width, - 0, 0]), - [3, 2]) + array_ops.stack([ + offset_height, after_padding_height, offset_width, + after_padding_width, 0, 0 + ]), [3, 2]) padded = array_ops.pad(image, paddings) padded_shape = [None if _is_tensor(i) else i @@ -488,10 +488,9 @@ def crop_to_bounding_box(image, offset_height, offset_width, target_height, 'height must be >= target + offset.') image = control_flow_ops.with_dependencies(assert_ops, image) - cropped = array_ops.slice( - image, - array_ops.pack([offset_height, offset_width, 0]), - array_ops.pack([target_height, target_width, -1])) + cropped = array_ops.slice(image, + array_ops.stack([offset_height, offset_width, 0]), + array_ops.stack([target_height, target_width, -1])) cropped_shape = [None if _is_tensor(i) else i for i in [target_height, target_width, depth]] @@ -1229,7 +1228,7 @@ def decode_image(contents, channels=None, name=None): JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF images. """ - with ops.name_scope(name, 'decode_image') as scope: + with ops.name_scope(name, 'decode_image') as scope: if channels not in (None, 0, 1, 3): raise ValueError('channels must be in (None, 0, 1, 3)') substr = string_ops.substr(contents, 0, 4) @@ -1247,14 +1246,14 @@ def decode_image(contents, channels=None, name=None): assert_channels = control_flow_ops.Assert(good_channels, [channels_msg]) with ops.control_dependencies([assert_decode, assert_channels]): return gen_image_ops.decode_gif(contents) - + def _png(): return gen_image_ops.decode_png(contents, channels) - + def check_png(): is_png = math_ops.equal(substr, b'\211PNG', name='is_png') return control_flow_ops.cond(is_png, _png, _gif, name='cond_png') - + def _jpeg(): return gen_image_ops.decode_jpeg(contents, channels) diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index e568cff352..4dce46f29f 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -59,8 +59,8 @@ class RGBToHSVTest(test_util.TensorFlowTestCase): split0 = array_ops.unpack(batch0) split1 = list(map(image_ops.rgb_to_hsv, split0)) split2 = list(map(image_ops.hsv_to_rgb, split1)) - join1 = array_ops.pack(split1) - join2 = array_ops.pack(split2) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2]) # Verify that processing batch elements together is the same as separate diff --git a/tensorflow/python/ops/metrics.py b/tensorflow/python/ops/metrics.py index 04cd4ecbd4..938ae0d71b 100644 --- a/tensorflow/python/ops/metrics.py +++ b/tensorflow/python/ops/metrics.py @@ -444,7 +444,7 @@ def _confusion_matrix_at_thresholds( num_predictions = array_ops.shape(predictions_2d)[0] thresh_tiled = array_ops.tile( array_ops.expand_dims(array_ops.constant(thresholds), [1]), - array_ops.pack([1, num_predictions])) + array_ops.stack([1, num_predictions])) # Tile the predictions after thresholding them across different thresholds. pred_is_pos = math_ops.greater( diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index 2bc05d6f07..d2a1619f46 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -542,7 +542,7 @@ def _TopKGrad(op, grad, _): ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1) # Flatten indices to 2D. - ind_2d = array_ops.reshape(op.outputs[1], array_ops.pack([-1, ind_lastdim])) + ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim])) in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1) outerdim = array_ops.shape(ind_2d)[0] diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 31d9725ce9..650984195c 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -846,7 +846,7 @@ def _sum_rows(x): # we use _sum_rows(x) in the nce_loss() computation since the loss # is mostly used for training. cols = array_ops.shape(x)[1] - ones_shape = array_ops.pack([cols, 1]) + ones_shape = array_ops.stack([cols, 1]) ones = array_ops.ones(ones_shape, x.dtype) return array_ops.reshape(math_ops.matmul(x, ones), [-1]) @@ -942,7 +942,7 @@ def _compute_sampled_logits(weights, # true_w shape is [batch_size * num_true, dim] # true_b is a [batch_size * num_true] tensor true_w = array_ops.slice( - all_w, [0, 0], array_ops.pack([array_ops.shape(labels_flat)[0], -1])) + all_w, [0, 0], array_ops.stack([array_ops.shape(labels_flat)[0], -1])) true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat)) # inputs shape is [batch_size, dim] @@ -965,7 +965,7 @@ def _compute_sampled_logits(weights, # sampled_w shape is [num_sampled, dim] # sampled_b is a [num_sampled] float tensor sampled_w = array_ops.slice( - all_w, array_ops.pack([array_ops.shape(labels_flat)[0], 0]), [-1, -1]) + all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1]) sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1]) # inputs has shape [batch_size, dim] diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 7ccb4a7fa9..4f50e3e1bc 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -333,8 +333,8 @@ def with_space_to_batch(input, dilation_rate, padding, op, filter_shape=None, # # convention as conv2d. pad_extra_start = pad_extra_shape // 2 pad_extra_end = pad_extra_shape - pad_extra_start - base_paddings = array_ops.pack([[pad_extra_start[i], pad_extra_end[i]] - for i in range(num_spatial_dims)]) + base_paddings = array_ops.stack([[pad_extra_start[i], pad_extra_end[i]] + for i in range(num_spatial_dims)]) elif padding == "VALID": base_paddings = np.zeros([num_spatial_dims, 2], np.int32) else: diff --git a/tensorflow/python/ops/resources.py b/tensorflow/python/ops/resources.py index 05369825dd..41fb8a74a9 100644 --- a/tensorflow/python/ops/resources.py +++ b/tensorflow/python/ops/resources.py @@ -89,8 +89,8 @@ def report_uninitialized_resources(resource_list=None, # size being 0 as an indication of model ready. return array_ops.constant([], dtype=dtypes.string) # Get a 1-D boolean tensor listing whether each resource is initialized. - variables_mask = math_ops.logical_not(array_ops.pack( - [r.is_initialized for r in resource_list])) + variables_mask = math_ops.logical_not( + array_ops.stack([r.is_initialized for r in resource_list])) # Get a 1-D string tensor containing all the resource names. variable_names_tensor = array_ops.constant( [s.handle.name for s in resource_list]) diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index c448cbd315..cfa7d78e67 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -140,16 +140,17 @@ class RNNCell(object): state_size_flat = nest.flatten(state_size) zeros_flat = [ array_ops.zeros( - array_ops.pack(_state_size_with_prefix(s, prefix=[batch_size])), - dtype=dtype) - for s in state_size_flat] + array_ops.stack(_state_size_with_prefix( + s, prefix=[batch_size])), + dtype=dtype) for s in state_size_flat + ] for s, z in zip(state_size_flat, zeros_flat): z.set_shape(_state_size_with_prefix(s, prefix=[None])) zeros = nest.pack_sequence_as(structure=state_size, flat_sequence=zeros_flat) else: zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size]) - zeros = array_ops.zeros(array_ops.pack(zeros_size), dtype=dtype) + zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype) zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None])) return zeros diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py index ea807b7d2b..dd00a52a32 100644 --- a/tensorflow/python/ops/sparse_ops.py +++ b/tensorflow/python/ops/sparse_ops.py @@ -855,7 +855,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None, # Slice off the last dimension of indices, then tack on the ids indices_columns_to_preserve = array_ops.slice( - sp_ids.indices, [0, 0], array_ops.pack([-1, rank - 1])) + sp_ids.indices, [0, 0], array_ops.stack([-1, rank - 1])) new_indices = array_ops.concat_v2( [indices_columns_to_preserve, array_ops.reshape(ids, [-1, 1])], 1) @@ -863,7 +863,7 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None, new_shape = array_ops.concat_v2([ array_ops.slice(sp_ids.dense_shape, [0], array_ops.expand_dims(rank - 1, 0)), - math_ops.cast(array_ops.pack([vocab_size]), dtypes.int64) + math_ops.cast(array_ops.stack([vocab_size]), dtypes.int64) ], 0) result = sparse_tensor.SparseTensor(new_indices, new_values, new_shape) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 32545e4eb3..ec45329cb0 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -1300,7 +1300,7 @@ def assert_variables_initialized(var_list=None): if len(ranks) == 1: return ranks[0] else: - return array_ops.pack(ranks) + return array_ops.stack(ranks) def report_uninitialized_variables(var_list=None, @@ -1334,8 +1334,9 @@ def report_uninitialized_variables(var_list=None, return array_ops.constant([], dtype=dtypes.string) else: # Get a 1-D boolean tensor listing whether each variable is initialized. - variables_mask = math_ops.logical_not(array_ops.pack( - [state_ops.is_variable_initialized(v) for v in var_list])) + variables_mask = math_ops.logical_not( + array_ops.stack( + [state_ops.is_variable_initialized(v) for v in var_list])) # Get a 1-D string tensor containing all the variable names. variable_names_tensor = array_ops.constant([s.op.name for s in var_list]) # Return a 1-D tensor containing all the names of uninitialized variables. |