aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py34
1 files changed, 19 insertions, 15 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index f21915ffbc..63fdd91d36 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -1585,7 +1585,8 @@ class WeightNormLSTMCellTest(test.TestCase):
with self.test_session() as sess:
init = init_ops.constant_initializer(0.5)
- with variable_scope.variable_scope("root", initializer=init):
+ with variable_scope.variable_scope("root",
+ initializer=init):
x = array_ops.zeros([1, 2])
c0 = array_ops.zeros([1, 2])
h0 = array_ops.zeros([1, 2])
@@ -1595,12 +1596,11 @@ class WeightNormLSTMCellTest(test.TestCase):
xout, sout = cell()(x, state0)
sess.run([variables.global_variables_initializer()])
- res = sess.run(
- [xout, sout], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- })
+ res = sess.run([xout, sout], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ })
actual_state_c = res[1].c
actual_state_h = res[1].h
@@ -1611,8 +1611,9 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell w/o peepholes and w/o normalisation."""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(
- 2, norm=False, use_peepholes=False)
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=False,
+ use_peepholes=False)
actual_c, actual_h = self._cell_output(cell)
@@ -1626,8 +1627,9 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell with peepholes and w/o normalisation."""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(
- 2, norm=False, use_peepholes=True)
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=False,
+ use_peepholes=True)
actual_c, actual_h = self._cell_output(cell)
@@ -1641,8 +1643,9 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell w/o peepholes and with normalisation."""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(
- 2, norm=True, use_peepholes=False)
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=True,
+ use_peepholes=False)
actual_c, actual_h = self._cell_output(cell)
@@ -1656,8 +1659,9 @@ class WeightNormLSTMCellTest(test.TestCase):
"""Tests cell with peepholes and with normalisation."""
def cell():
- return contrib_rnn_cell.WeightNormLSTMCell(
- 2, norm=True, use_peepholes=True)
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=True,
+ use_peepholes=True)
actual_c, actual_h = self._cell_output(cell)