diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index f21915ffbc..63fdd91d36 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1585,7 +1585,8 @@ class WeightNormLSTMCellTest(test.TestCase): with self.test_session() as sess: init = init_ops.constant_initializer(0.5) - with variable_scope.variable_scope("root", initializer=init): + with variable_scope.variable_scope("root", + initializer=init): x = array_ops.zeros([1, 2]) c0 = array_ops.zeros([1, 2]) h0 = array_ops.zeros([1, 2]) @@ -1595,12 +1596,11 @@ class WeightNormLSTMCellTest(test.TestCase): xout, sout = cell()(x, state0) sess.run([variables.global_variables_initializer()]) - res = sess.run( - [xout, sout], { - x.name: np.array([[1., 1.]]), - c0.name: 0.1 * np.asarray([[0, 1]]), - h0.name: 0.1 * np.asarray([[2, 3]]), - }) + res = sess.run([xout, sout], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + }) actual_state_c = res[1].c actual_state_h = res[1].h @@ -1611,8 +1611,9 @@ class WeightNormLSTMCellTest(test.TestCase): """Tests cell w/o peepholes and w/o normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=False) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=False) actual_c, actual_h = self._cell_output(cell) @@ -1626,8 +1627,9 @@ class WeightNormLSTMCellTest(test.TestCase): """Tests cell with peepholes and w/o normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=False, use_peepholes=True) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=True) actual_c, actual_h = self._cell_output(cell) @@ -1641,8 +1643,9 @@ class WeightNormLSTMCellTest(test.TestCase): """Tests cell w/o peepholes and with normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=False) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=False) actual_c, actual_h = self._cell_output(cell) @@ -1656,8 +1659,9 @@ class WeightNormLSTMCellTest(test.TestCase): """Tests cell with peepholes and with normalisation.""" def cell(): - return contrib_rnn_cell.WeightNormLSTMCell( - 2, norm=True, use_peepholes=True) + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=True) actual_c, actual_h = self._cell_output(cell) |