about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py | 3
-rw-r--r-- tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py | 35
-rw-r--r-- tensorflow/contrib/rnn/python/ops/rnn_cell.py | 23
-rw-r--r-- tensorflow/contrib/tensor_forest/python/tensor_forest.py | 3
-rw-r--r-- tensorflow/python/kernel_tests/split_op_test.py | 76
-rw-r--r-- tensorflow/python/ops/array_ops.py | 61
-rw-r--r-- tensorflow/python/ops/gradients_test.py | 2
-rw-r--r-- tensorflow/python/ops/rnn_cell_impl.py | 13
-rw-r--r-- tensorflow/python/ops/split_benchmark.py | 2
9 files changed, 73 insertions, 145 deletions
diff --git a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
index 6247cad380..aef4ce3fdb 100644
--- a/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
+++ b/tensorflow/contrib/grid_rnn/python/ops/grid_rnn_cell.py
@@ -170,7 +170,8 @@ class GridRNNCell(rnn.RNNCell):
# project input
if inputs is not None and sum(inputs.get_shape().as_list()) > 0 and len(
conf.inputs) > 0:
- input_splits = array_ops.split(1, len(conf.inputs), inputs)
+ input_splits = array_ops.split(
+ value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
input_sz = input_splits[0].get_shape().as_list()[1]
for i, j in enumerate(conf.inputs):
diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
index 644347f0b5..9edb00e7b0 100644
--- a/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
+++ b/tensorflow/contrib/linear_optimizer/python/sdca_optimizer.py
@@ -89,11 +89,12 @@ class SDCAOptimizer(object):
# very sparse features with weights and not weights.
return SparseFeatureColumn(
array_ops.reshape(
- array_ops.split(1, 2, sparse_indices)[0], [-1]),
+ array_ops.split(
+ value=sparse_indices, num_or_size_splits=2, axis=1)[0], [-1]),
array_ops.reshape(
- array_ops.split(1, 2, sparse_indices)[1], [-1]),
- array_ops.reshape(
- math_ops.to_float(sparse_values), [-1]))
+ array_ops.split(
+ value=sparse_indices, num_or_size_splits=2, axis=1)[1], [-1]),
+ array_ops.reshape(math_ops.to_float(sparse_values), [-1]))
def _training_examples_and_variables():
"""Returns dictionaries for training examples and variables."""
@@ -135,19 +136,27 @@ class SDCAOptimizer(object):
columns_to_variables[column][0])
elif isinstance(column, (layers.feature_column._CrossedColumn,
layers.feature_column._SparseColumn)):
- sparse_features.append(SparseFeatureColumn(
- array_ops.reshape(
- array_ops.split(1, 2, transformed_tensor.indices)[0], [-1]),
- array_ops.reshape(transformed_tensor.values, [-1]), None))
+ sparse_features.append(
+ SparseFeatureColumn(
+ array_ops.reshape(
+ array_ops.split(
+ value=transformed_tensor.indices,
+ num_or_size_splits=2,
+ axis=1)[0], [-1]),
+ array_ops.reshape(transformed_tensor.values, [-1]),
+ None))
sparse_feature_weights.append(columns_to_variables[column][0])
elif isinstance(column, layers.feature_column._WeightedSparseColumn):
id_tensor = column.id_tensor(transformed_tensor)
weight_tensor = column.weight_tensor(transformed_tensor)
- sparse_feature_with_values.append(SparseFeatureColumn(
- array_ops.reshape(
- array_ops.split(1, 2, id_tensor.indices)[0], [-1]),
- array_ops.reshape(id_tensor.values, [-1]), array_ops.reshape(
- weight_tensor.values, [-1])))
+ sparse_feature_with_values.append(
+ SparseFeatureColumn(
+ array_ops.reshape(
+ array_ops.split(
+ value=id_tensor.indices, num_or_size_splits=2, axis=1)
+ [0], [-1]),
+ array_ops.reshape(id_tensor.values, [-1]),
+ array_ops.reshape(weight_tensor.values, [-1])))
sparse_feature_with_values_weights.append(
columns_to_variables[column][0])
else:
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 6f6f19b41a..761fabba8d 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -216,7 +216,7 @@ class CoupledInputForgetGateLSTMCell(rnn_cell.RNNCell):
# j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat_v2([inputs, m_prev], 1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
- j, f, o = array_ops.split(1, 3, lstm_matrix)
+ j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
# Diagonal connections
if self._use_peepholes:
@@ -363,7 +363,8 @@ class TimeFreqLSTMCell(rnn_cell.RNNCell):
cell_inputs = array_ops.concat_v2(
[freq_inputs[fq], m_prev, m_prev_freq], 1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
- i, j, f, o = array_ops.split(1, 4, lstm_matrix)
+ i, j, f, o = array_ops.split(
+ value=lstm_matrix, num_or_size_splits=4, axis=1)
if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
@@ -670,12 +671,12 @@ class GridLSTMCell(rnn_cell.RNNCell):
lstm_matrix_freq = nn_ops.bias_add(math_ops.matmul(cell_inputs,
concat_w_f), b_f)
if self._couple_input_forget_gates:
- i_freq, j_freq, o_freq = array_ops.split(1, num_gates,
- lstm_matrix_freq)
+ i_freq, j_freq, o_freq = array_ops.split(
+ value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
f_freq = None
else:
- i_freq, j_freq, f_freq, o_freq = array_ops.split(1, num_gates,
- lstm_matrix_freq)
+ i_freq, j_freq, f_freq, o_freq = array_ops.split(
+ value=lstm_matrix_freq, num_or_size_splits=num_gates, axis=1)
# T-LSTM
if self._share_time_frequency_weights:
i_time = i_freq
@@ -686,12 +687,12 @@ class GridLSTMCell(rnn_cell.RNNCell):
lstm_matrix_time = nn_ops.bias_add(math_ops.matmul(cell_inputs,
concat_w_t), b_t)
if self._couple_input_forget_gates:
- i_time, j_time, o_time = array_ops.split(1, num_gates,
- lstm_matrix_time)
+ i_time, j_time, o_time = array_ops.split(
+ value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
f_time = None
else:
- i_time, j_time, f_time, o_time = array_ops.split(1, 4,
- lstm_matrix_time)
+ i_time, j_time, f_time, o_time = array_ops.split(
+ value=lstm_matrix_time, num_or_size_splits=num_gates, axis=1)
# F-LSTM c_freq
# input gate activations
@@ -1229,7 +1230,7 @@ class LayerNormBasicLSTMCell(rnn_cell.RNNCell):
args = array_ops.concat_v2([inputs, h], 1)
concat = self._linear(args)
- i, j, f, o = array_ops.split(1, 4, concat)
+ i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._layer_norm:
i = self._norm(i, "input")
j = self._norm(j, "transform")
diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
index 84fb8dab61..5dff5f5f31 100644
--- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py
+++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py
@@ -334,7 +334,8 @@ class RandomForestGraphs(object):
]
def _bag_features(self, tree_num, input_data):
- split_data = array_ops.split(1, self.params.num_features, input_data)
+ split_data = array_ops.split(
+ value=input_data, num_or_size_splits=self.params.num_features, axis=1)
return array_ops.concat_v2(
[split_data[ind] for ind in self.params.bagged_features[tree_num]], 1)
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 60a1f4bcf3..a76acea4cd 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -22,7 +22,7 @@ import numpy as np
import tensorflow as tf
-class SplitVOpTest(tf.test.TestCase):
+class SplitOpTest(tf.test.TestCase):
def testExplicitNum(self):
size_splits = tf.placeholder(dtype=tf.int32, shape=[None])
@@ -31,11 +31,11 @@ class SplitVOpTest(tf.test.TestCase):
with self.test_session(use_gpu=False) as sess:
with self.assertRaises(ValueError) as context:
- sess.run(tf.split_v(value, size_splits), {size_splits: [2, 2, 6]})
+ sess.run(tf.split(value, size_splits), {size_splits: [2, 2, 6]})
self.assertTrue("Cannot infer num from shape" in str(context.exception))
- result = sess.run(tf.split_v(value, size_splits, num=3),
+ result = sess.run(tf.split(value, size_splits, num=3),
{size_splits: [2, 2, 6]})
self.assertAllEqual(result[0], value[0:2])
@@ -49,12 +49,12 @@ class SplitVOpTest(tf.test.TestCase):
value = np.random.rand(11, 11)
with self.test_session(use_gpu=False) as sess:
- result = sess.run(tf.split_v(value, [a, b]))
+ result = sess.run(tf.split(value, [a, b]))
self.assertAllEqual(result[0], value[0:5, :])
self.assertAllEqual(result[1], value[5:, :])
- def _RunAndVerify(self, use_gpu, large_num_splits=False):
+ def _RunAndVerifyVariable(self, use_gpu, large_num_splits=False):
# Random dims of rank 5
shape = np.random.randint(1, 5, size=5)
split_dim = np.random.randint(0, 5)
@@ -66,7 +66,7 @@ class SplitVOpTest(tf.test.TestCase):
shape[split_dim] = np.sum(size_splits)
inp = np.random.rand(*shape).astype("f")
with self.test_session(use_gpu=use_gpu) as sess:
- result = sess.run(tf.split_v(inp, size_splits, split_dim))
+ result = sess.run(tf.split(inp, size_splits, split_dim))
slices = [slice(0, x) for x in shape]
offset = 0
for i in range(num_split):
@@ -74,53 +74,25 @@ class SplitVOpTest(tf.test.TestCase):
offset += size_splits[i]
self.assertAllEqual(result[i], inp[slices])
- def _RunAndVerifyScalar(self, use_gpu, large_num_splits=False):
- shape = np.random.randint(0, 5, size=5)
- split_dim = np.random.randint(0, 5)
- if large_num_splits:
- num_split = np.random.randint(16, 25)
- else:
- num_split = np.random.randint(2, 8)
- shape[split_dim] = np.random.randint(2, 5) * num_split
- inp = np.random.rand(*shape).astype("f")
- with self.test_session(use_gpu=use_gpu) as sess:
- result = sess.run(tf.split_v(inp, num_split, split_dim))
- slices = [slice(0, x) for x in shape]
- offset = 0
- length = shape[split_dim] // num_split
- for i in range(num_split):
- slices[split_dim] = slice(offset, offset + length)
- offset += length
- self.assertAllEqual(result[i], inp[slices])
-
- def testRandom(self):
- for _ in range(5):
- self._RunAndVerify(use_gpu=False)
- self._RunAndVerify(use_gpu=True)
- self._RunAndVerify(use_gpu=True, large_num_splits=True)
- self._RunAndVerifyScalar(use_gpu=False)
- self._RunAndVerifyScalar(use_gpu=True)
- self._RunAndVerifyScalar(use_gpu=True, large_num_splits=True)
-
- def _testSpecialCases(self, use_gpu):
+ def _testSpecialCasesVariable(self, use_gpu):
inp = np.random.rand(4, 4).astype("f")
with self.test_session(use_gpu=use_gpu) as sess:
- result = sess.run(tf.split_v(inp, [4], 0))
+ result = sess.run(tf.split(inp, [4], 0))
self.assertAllEqual(result[0], inp)
- result = sess.run(tf.split_v(inp, [-1, 3], 0))
+ result = sess.run(tf.split(inp, [-1, 3], 0))
self.assertAllEqual(result[0], inp[0:1, :])
self.assertAllEqual(result[1], inp[1:4, :])
- def _testHugeNumberOfTensors(self, use_gpu):
+ def _testHugeNumberOfTensorsVariable(self, use_gpu):
num_split = 10000
size_splits = np.random.randint(1, 3, num_split)
shape = [3, np.sum(size_splits)]
split_dim = 1
inp = np.random.rand(*shape).astype("f")
with self.test_session(use_gpu=use_gpu) as sess:
- result = sess.run(tf.split_v(inp, size_splits, split_dim))
+ result = sess.run(tf.split(inp, size_splits, split_dim))
slices = [slice(0, x) for x in shape]
offset = 0
for i in range(num_split):
@@ -128,17 +100,17 @@ class SplitVOpTest(tf.test.TestCase):
offset += size_splits[i]
self.assertAllEqual(result[i], inp[slices])
- def testSpecialCases(self):
- self._testSpecialCases(False)
- self._testSpecialCases(True)
- self._testHugeNumberOfTensors(False)
- self._testHugeNumberOfTensors(True)
+ def testSpecialCasesVariable(self):
+ self._testSpecialCasesVariable(False)
+ self._testSpecialCasesVariable(True)
+ self._testHugeNumberOfTensorsVariable(False)
+ self._testHugeNumberOfTensorsVariable(True)
- def _testGradientsSimple(self, use_gpu):
+ def _testGradientsSimpleVariable(self, use_gpu):
inp = np.random.rand(4, 4).astype("f")
with self.test_session(use_gpu=use_gpu):
inp_tensor = tf.convert_to_tensor(inp)
- s = tf.split_v(inp_tensor, [1, 4], 1)
+ s = tf.split(inp_tensor, [1, 4], 1)
inp_grads = [
np.random.rand(4, 1).astype("f"), np.random.rand(4, 3).astype("f")
]
@@ -149,13 +121,6 @@ class SplitVOpTest(tf.test.TestCase):
self.assertAllEqual(result[:, 0:1], inp_grads[0])
self.assertAllEqual(result[:, 1:4], inp_grads[1])
- def testGradientsAll(self):
- self._testGradientsSimple(use_gpu=False)
- self._testGradientsSimple(use_gpu=True)
-
-
-class SplitOpTest(tf.test.TestCase):
-
def _compare(self, x, dim, num, use_gpu):
np_ans = np.split(x, num, dim)
with self.test_session(use_gpu=use_gpu) as sess:
@@ -244,6 +209,9 @@ class SplitOpTest(tf.test.TestCase):
self._RunAndVerify(use_gpu=False)
self._RunAndVerify(use_gpu=True)
self._RunAndVerify(use_gpu=True, large_num_splits=True)
+ self._RunAndVerifyVariable(use_gpu=False)
+ self._RunAndVerifyVariable(use_gpu=True)
+ self._RunAndVerifyVariable(use_gpu=True, large_num_splits=True)
def _testGradientsSimple(self, use_gpu):
inp = np.random.rand(4, 4).astype("f")
@@ -260,6 +228,8 @@ class SplitOpTest(tf.test.TestCase):
def testGradientsAll(self):
self._testGradientsSimple(use_gpu=False)
self._testGradientsSimple(use_gpu=True)
+ self._testGradientsSimpleVariable(use_gpu=False)
+ self._testGradientsSimpleVariable(use_gpu=True)
def testShapeFunctionEdgeCases(self):
# split_dim greater than rank of input.
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 02bea9aa57..d78d7817f1 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -50,7 +50,6 @@ or join multiple tensors together.
@@slice
@@strided_slice
@@split
-@@split_v
@@tile
@@pad
@@concat
@@ -1226,63 +1225,7 @@ def sparse_mask(a, mask_indices, name=None):
return ops.IndexedSlices(out_values, out_indices, a.dense_shape)
-def split(axis=None,
- num_or_size_splits=None,
- value=None,
- name="split",
- split_dim=None):
- """DEPRECATED: use split_v; split_v rename to split happening soon.
-
- Splits `value` along dimension `axis` into `num_or_size_splits` smaller
- tensors. Requires that `num_or_size_splits` evenly divide `value.shape[axis]`.
-
- For example:
-
- ```python
- # 'value' is a tensor with shape [5, 30]
- # Split 'value' into 3 tensors along dimension 1
- split0, split1, split2 = tf.split(value=value, num_or_size_splits=3, axis=1)
- tf.shape(split0) ==> [5, 10]
- ```
-
- Note: If you are splitting along an axis by the length of that axis, consider
- using unpack, e.g.
-
- ```python
- num_items = t.get_shape()[axis].value
- [tf.squeeze(s, [axis]) for s in
- tf.split(value=t, num_or_size_splits=num_items, axis=axis)]
- ```
-
- can be rewritten as
-
- ```python
- tf.unpack(t, axis=axis)
- ```
-
- Args:
- axis: A 0-D `int32` `Tensor`. The dimension along which to split.
- Must be in the range `[0, rank(value))`.
- num_or_size_splits: A Python integer. The number of ways to split. Has a
- different meaning in split_v (see docs).
- value: The `Tensor` to split.
- name: A name for the operation (optional).
- split_dim: The old (deprecated) name for axis.
-
- Returns:
- `num_or_size_splits` `Tensor` objects resulting from splitting `value`.
- """
- axis = deprecation.deprecated_argument_lookup("axis", axis, "split_dim",
- split_dim)
- return gen_array_ops._split(
- split_dim=axis, num_split=num_or_size_splits, value=value, name=name)
-
-
-def split_v(value=None,
- num_or_size_splits=None,
- axis=0,
- num=None,
- name="split_v"):
+def split(value, num_or_size_splits, axis=0, num=None, name="split"):
"""Splits a tensor into sub tensors.
If `num_or_size_splits` is a scalar, `num_split`, then splits `value` along
@@ -1298,7 +1241,7 @@ def split_v(value=None,
```python
# 'value' is a tensor with shape [5, 30]
# Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1
- split0, split1, split2 = tf.split_v(value, [4, 15, 11], 1)
+ split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) ==> [5, 4]
tf.shape(split1) ==> [5, 15]
tf.shape(split2) ==> [5, 11]
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 7272b1729a..8b4fbeb7a9 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -151,7 +151,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
w = constant(1.0, shape=[2, 2])
x = constant(1.0, shape=[2, 2])
wx = math_ops.matmul(w, x)
- split_wx = array_ops.split(0, 2, wx)
+ split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
c = math_ops.reduce_sum(split_wx[1])
gw = gradients.gradients(c, [w])[0]
self.assertEquals("MatMul", gw.op.type)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 7e0abf854d..c448cbd315 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -203,8 +203,10 @@ class GRUCell(RNNCell):
with vs.variable_scope("gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
r, u = array_ops.split(
- 1, 2, _linear([inputs, state], 2 * self._num_units, True, 1.0,
- scope=scope))
+ value=_linear(
+ [inputs, state], 2 * self._num_units, True, 1.0, scope=scope),
+ num_or_size_splits=2,
+ axis=1)
r, u = sigmoid(r), sigmoid(u)
with vs.variable_scope("candidate"):
c = self._activation(_linear([inputs, r * state],
@@ -288,11 +290,11 @@ class BasicLSTMCell(RNNCell):
if self._state_is_tuple:
c, h = state
else:
- c, h = array_ops.split(1, 2, state)
+ c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
concat = _linear([inputs, h], 4 * self._num_units, True, scope=scope)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
- i, j, f, o = array_ops.split(1, 4, concat)
+ i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
self._activation(j))
@@ -449,7 +451,8 @@ class LSTMCell(RNNCell):
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
scope=scope)
- i, j, f, o = array_ops.split(1, 4, lstm_matrix)
+ i, j, f, o = array_ops.split(
+ value=lstm_matrix, num_or_size_splits=4, axis=1)
# Diagonal connections
if self._use_peepholes:
diff --git a/tensorflow/python/ops/split_benchmark.py b/tensorflow/python/ops/split_benchmark.py
index d0647ecc90..5da0eac4b7 100644
--- a/tensorflow/python/ops/split_benchmark.py
+++ b/tensorflow/python/ops/split_benchmark.py
@@ -41,7 +41,7 @@ def build_graph(device, input_shape, output_sizes, axis):
outputs = []
for _ in range(100):
- outputs.extend(tf.split_v(inp, output_sizes, axis))
+ outputs.extend(tf.split(inp, output_sizes, axis))
return tf.group(*outputs)