aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/ops/lstm_ops.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-06 18:25:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-06 18:44:18 -0800
commitd4eb834824d79c6a64a3c4a1c4a88b434b73e63e (patch)
tree3a6a417a668e79bc588929450f1f7794bb9eee2c /tensorflow/contrib/rnn/python/ops/lstm_ops.py
parent7b306e8fcfb6db3f438c27e437194e78c1d73e23 (diff)
Switch all tf.concat(concat_dim, value, name) calls in third_party/tensorflow to tf.concat_v2(value, axis, name).
Change: 141255675
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/lstm_ops.py')
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 3e30f24310..4ad269ab4f 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -277,7 +277,7 @@ def _LSTMBlockCellGrad(op, *grad):
h_prev_grad.get_shape().merge_with(h_prev.get_shape())
# Backprop from dicfo to w.
- xh = array_ops.concat(1, [x, h_prev])
+ xh = array_ops.concat_v2([x, h_prev], 1)
w_grad = math_ops.matmul(xh, dicfo, transpose_a=True)
w_grad.get_shape().merge_with(w.get_shape())
@@ -527,10 +527,10 @@ class LSTMBlockWrapper(fused_rnn_cell.FusedRNNCell):
# correctly,since we want to access the last valid state at
# sequence_length - 1, which can even be -1, corresponding to the
# initial state.
- mod_cell_states = array_ops.concat(
- 0, [array_ops.expand_dims(initial_cell_state, [0]), cell_states])
- mod_outputs = array_ops.concat(
- 0, [array_ops.expand_dims(initial_output, [0]), outputs])
+ mod_cell_states = array_ops.concat_v2(
+ [array_ops.expand_dims(initial_cell_state, [0]), cell_states], 0)
+ mod_outputs = array_ops.concat_v2(
+ [array_ops.expand_dims(initial_output, [0]), outputs], 0)
final_cell_state = self._gather_states(mod_cell_states, sequence_length,
batch_size)
final_output = self._gather_states(mod_outputs, sequence_length,