path: root/tensorflow/contrib/rnn
author    Yifei Feng <yifeif@google.com>    2018-01-29 10:42:32 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-01-29 10:46:04 -0800
commit  fd63d4e30a01cf860baf60b990b223cd54bc895c (patch)
tree    fcea79b1e89bcf30ac80d087edf051c3711d06b1 /tensorflow/contrib/rnn
parent  730071d0dca35a9e08f3bdc49661ae34d109da74 (diff)
Add C0326 bad-whitespace error to pylint sanity check.
PiperOrigin-RevId: 183689499
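For context on what this check enforces: pylint's C0326 (bad-whitespace) message flags misplaced spaces around commas, brackets, and operators, which is why this commit reflows so many call sites below. The following is a minimal, hypothetical sketch of the style C0326 rejects and accepts; the file name and examples are illustrative, not taken from this commit, and since C0326 was removed in pylint 2.6+, this assumes an older pylint such as the one in use in 2018.

```python
# Hypothetical file whitespace_demo.py -- not part of this commit.
# With a pylint version that still ships C0326 (pre-2.6), lines like
# these trigger "bad-whitespace":
#
#     x = { 'a': 1 }      # No space allowed after/before bracket
#     y = foo(a ,b)       # Exactly one space required after comma
#     z=1                 # Exactly one space required around assignment

def foo(a, b):
  """Return the sum of two values."""
  return a + b

x = {'a': 1}   # brackets hug their contents
y = foo(1, 2)  # one space after the comma, none before
z = 1          # single spaces around '='
```

To check one file for just this message class (again assuming a pre-2.6 pylint where C0326 exists), something like `pylint --disable=all --enable=C0326 whitespace_demo.py` should report only bad-whitespace findings.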
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 258
1 file changed, 149 insertions(+), 109 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index cafeb56ad8..e1838c1739 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -42,7 +42,6 @@ from tensorflow.python.platform import test
from tensorflow.python.framework import test_util
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
-
# pylint: enable=protected-access
Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
@@ -84,19 +83,22 @@ class RNNCellTest(test.TestCase):
], [v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g], {x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
with self.test_session() as sess:
+
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
return getter(*args, **kwargs)
with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5),
+ "root",
+ initializer=init_ops.constant_initializer(0.5),
custom_getter=not_trainable_getter):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
@@ -108,9 +110,10 @@ class RNNCellTest(test.TestCase):
"root/basic_rnn_cell/%s:0" % rnn_cell_impl._BIAS_VARIABLE_NAME
], [v.name for v in cell.non_trainable_variables])
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g], {x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
self.assertEqual(res[0].shape, (1, 2))
def testGRUCell(self):
@@ -121,9 +124,10 @@ class RNNCellTest(test.TestCase):
m = array_ops.zeros([1, 2])
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g], {x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
# Smoke test
self.assertAllClose(res[0], [[0.175991, 0.175991]])
with variable_scope.variable_scope(
@@ -133,10 +137,10 @@ class RNNCellTest(test.TestCase):
m = array_ops.zeros([1, 2])
g, _ = rnn_cell_impl.GRUCell(2)(x, m)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g],
- {x.name: np.array([[1., 1., 1.]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
@@ -148,11 +152,12 @@ class RNNCellTest(test.TestCase):
m = array_ops.zeros([1, 2])
g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g], {x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
# Smoke test
- self.assertAllClose(res[0], [[0.509682, 0.509682]])
+ self.assertAllClose(res[0], [[0.509682, 0.509682]])
def testBasicLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
@@ -164,8 +169,7 @@ class RNNCellTest(test.TestCase):
m = array_ops.zeros([1, 8], dtype=dtype)
cell = rnn_cell_impl.MultiRNNCell(
[
- rnn_cell_impl.BasicLSTMCell(
- 2, state_is_tuple=False)
+ rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)
for _ in range(2)
],
state_is_tuple=False)
@@ -183,22 +187,21 @@ class RNNCellTest(test.TestCase):
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" %
rnn_cell_impl._BIAS_VARIABLE_NAME
]
- self.assertEqual(
- expected_variable_names,
- [v.name for v in cell.trainable_variables])
+ self.assertEqual(expected_variable_names,
+ [v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g, out_m],
- {x.name: np.array([[1., 1.]]),
- m.name: 0.1 * np.ones([1, 8])})
+ res = sess.run([g, out_m], {
+ x.name: np.array([[1., 1.]]),
+ m.name: 0.1 * np.ones([1, 8])
+ })
self.assertEqual(len(res), 2)
variables = variables_lib.global_variables()
self.assertEqual(expected_variable_names, [v.name for v in variables])
# The numbers in results were not calculated, this is just a
# smoke test.
- self.assertAllClose(
- res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2)
+ self.assertAllClose(res[0], np.array(
+ [[0.240, 0.240]], dtype=np_dtype), 1e-2)
expected_mem = np.array(
[[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]],
dtype=np_dtype)
@@ -208,13 +211,13 @@ class RNNCellTest(test.TestCase):
# Test BasicLSTMCell with input_size != num_units.
x = array_ops.zeros([1, 3], dtype=dtype)
m = array_ops.zeros([1, 4], dtype=dtype)
- g, out_m = rnn_cell_impl.BasicLSTMCell(
- 2, state_is_tuple=False)(x, m)
+ g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
- [g, out_m],
- {x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
- m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)})
+ [g, out_m], {
+ x.name: np.array([[1., 1., 1.]], dtype=np_dtype),
+ m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)
+ })
self.assertEqual(len(res), 2)
def testBasicLSTMCellDimension0Error(self):
@@ -232,9 +235,11 @@ class RNNCellTest(test.TestCase):
g, out_m = rnn_cell_impl.BasicLSTMCell(
num_units, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
- sess.run([g, out_m],
- {x.name: 1 * np.ones([batch_size, input_size]),
- m.name: 0.1 * np.ones([batch_size - 1, state_size])})
+ sess.run(
+ [g, out_m], {
+ x.name: 1 * np.ones([batch_size, input_size]),
+ m.name: 0.1 * np.ones([batch_size - 1, state_size])
+ })
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
@@ -251,9 +256,11 @@ class RNNCellTest(test.TestCase):
g, out_m = rnn_cell_impl.BasicLSTMCell(
num_units, state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
- sess.run([g, out_m],
- {x.name: 1 * np.ones([batch_size, input_size]),
- m.name: 0.1 * np.ones([batch_size, state_size])})
+ sess.run(
+ [g, out_m], {
+ x.name: 1 * np.ones([batch_size, input_size]),
+ m.name: 0.1 * np.ones([batch_size, state_size])
+ })
def testBasicLSTMCellStateTupleType(self):
with self.test_session():
@@ -301,11 +308,12 @@ class RNNCellTest(test.TestCase):
state_is_tuple=True)
g, (out_m0, out_m1) = cell(x, (m0, m1))
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([g, out_m0, out_m1], {
- x.name: np.array([[1., 1.]]),
- m0.name: 0.1 * np.ones([1, 4]),
- m1.name: 0.1 * np.ones([1, 4])
- })
+ res = sess.run(
+ [g, out_m0, out_m1], {
+ x.name: np.array([[1., 1.]]),
+ m0.name: 0.1 * np.ones([1, 4]),
+ m1.name: 0.1 * np.ones([1, 4])
+ })
self.assertEqual(len(res), 3)
# The numbers in results were not calculated, this is just a smoke test.
# Note, however, these values should match the original
@@ -336,10 +344,11 @@ class RNNCellTest(test.TestCase):
state_is_tuple=False)
output, state = cell(x, m)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([output, state], {
- x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
- m.name: 0.1 * np.ones((batch_size, state_size))
- })
+ res = sess.run(
+ [output, state], {
+ x.name: np.array([[1., 1.], [2., 2.], [3., 3.]]),
+ m.name: 0.1 * np.ones((batch_size, state_size))
+ })
self.assertEqual(len(res), 2)
# The numbers in results were not calculated, this is mostly just a
# smoke test.
@@ -442,10 +451,10 @@ class RNNCellTest(test.TestCase):
rnn_cell_impl.GRUCell(3), num_proj=3)
g, new_m = cell(x, m)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g, new_m],
- {x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1, 0.1]])})
+ res = sess.run([g, new_m], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1, 0.1]])
+ })
self.assertEqual(res[1].shape, (1, 3))
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
@@ -479,9 +488,11 @@ class RNNCellTest(test.TestCase):
base_cell = rnn_cell_impl.GRUCell(3)
g, m_new = base_cell(x, m)
variable_scope.get_variable_scope().reuse_variables()
+
def residual_with_slice_fn(inp, out):
inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
return inp_sliced + out
+
g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
base_cell, residual_with_slice_fn)(x, m)
sess.run([variables_lib.global_variables_initializer()])
@@ -551,10 +562,10 @@ class RNNCellTest(test.TestCase):
self.assertEqual(embedding_cell.output_size, 2)
g, new_m = embedding_cell(x, m)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g, new_m],
- {x.name: np.array([[1]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g, new_m], {
+ x.name: np.array([[1]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
self.assertEqual(res[1].shape, (1, 2))
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.17139, 0.17139]])
@@ -584,8 +595,8 @@ class RNNCellTest(test.TestCase):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 4])
_, ml = rnn_cell_impl.MultiRNNCell(
- [rnn_cell_impl.GRUCell(2)
- for _ in range(2)], state_is_tuple=False)(x, m)
+ [rnn_cell_impl.GRUCell(2) for _ in range(2)],
+ state_is_tuple=False)(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(ml, {
x.name: np.array([[1., 1.]]),
@@ -605,19 +616,20 @@ class RNNCellTest(test.TestCase):
# Test incorrectness of state
with self.assertRaisesRegexp(ValueError, "Expected state .* a tuple"):
rnn_cell_impl.MultiRNNCell(
- [rnn_cell_impl.GRUCell(2)
- for _ in range(2)], state_is_tuple=True)(x, m_bad)
+ [rnn_cell_impl.GRUCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, m_bad)
_, ml = rnn_cell_impl.MultiRNNCell(
- [rnn_cell_impl.GRUCell(2)
- for _ in range(2)], state_is_tuple=True)(x, m_good)
+ [rnn_cell_impl.GRUCell(2) for _ in range(2)],
+ state_is_tuple=True)(x, m_good)
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(ml, {
- x.name: np.array([[1., 1.]]),
- m_good[0].name: np.array([[0.1, 0.1]]),
- m_good[1].name: np.array([[0.1, 0.1]])
- })
+ res = sess.run(
+ ml, {
+ x.name: np.array([[1., 1.]]),
+ m_good[0].name: np.array([[0.1, 0.1]]),
+ m_good[1].name: np.array([[0.1, 0.1]])
+ })
# The numbers in results were not calculated, this is just a
# smoke test. However, these numbers should match those of
@@ -628,8 +640,11 @@ class RNNCellTest(test.TestCase):
class DropoutWrapperTest(test.TestCase):
- def _testDropoutWrapper(self, batch_size=None, time_steps=None,
- parallel_iterations=None, **kwargs):
+ def _testDropoutWrapper(self,
+ batch_size=None,
+ time_steps=None,
+ parallel_iterations=None,
+ **kwargs):
with self.test_session() as sess:
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
@@ -640,14 +655,14 @@ class DropoutWrapperTest(test.TestCase):
x = constant_op.constant(
[[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
m = rnn_cell_impl.LSTMStateTuple(
- *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)
- ] * 2)
+ *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32
+ )] * 2)
else:
x = constant_op.constant(
np.random.randn(time_steps, batch_size, 3).astype(np.float32))
m = rnn_cell_impl.LSTMStateTuple(*[
- constant_op.constant(
- [[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
+ constant_op.
+ constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
] * 2)
outputs, final_state = rnn.dynamic_rnn(
cell=rnn_cell_impl.DropoutWrapper(
@@ -674,8 +689,8 @@ class DropoutWrapperTest(test.TestCase):
res = self._testDropoutWrapper(
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
true_full_output = np.array(
- [[[0.751109, 0.751109, 0.751109]],
- [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+ [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+ dtype=np.float32)
true_full_final_c = np.array(
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
self.assertAllClose(true_full_output, res[0])
@@ -687,8 +702,8 @@ class DropoutWrapperTest(test.TestCase):
res = self._testDropoutWrapper(
input_keep_prob=keep, output_keep_prob=keep, state_keep_prob=keep)
true_full_output = np.array(
- [[[0.751109, 0.751109, 0.751109]],
- [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+ [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+ dtype=np.float32)
true_full_final_c = np.array(
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
self.assertAllClose(true_full_output, res[0])
@@ -703,16 +718,20 @@ class DropoutWrapperTest(test.TestCase):
## consistent across both calls. Otherwise the seed may not end
## up being munged consistently across both graphs.
res_standard_1 = self._testDropoutWrapper(
- input_keep_prob=keep_some, output_keep_prob=keep_some,
- state_keep_prob=keep_some, seed=10,
+ input_keep_prob=keep_some,
+ output_keep_prob=keep_some,
+ state_keep_prob=keep_some,
+ seed=10,
parallel_iterations=1)
# Clear away the graph and the test session (which keeps variables around)
ops.reset_default_graph()
self._ClearCachedSession()
random_seed.set_random_seed(2)
res_standard_2 = self._testDropoutWrapper(
- input_keep_prob=keep_some, output_keep_prob=keep_some,
- state_keep_prob=keep_some, seed=10,
+ input_keep_prob=keep_some,
+ output_keep_prob=keep_some,
+ state_keep_prob=keep_some,
+ seed=10,
parallel_iterations=1)
self.assertAllClose(res_standard_1[0], res_standard_2[0])
self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
@@ -722,11 +741,12 @@ class DropoutWrapperTest(test.TestCase):
keep_all = variable_scope.get_variable("all", initializer=1.0)
keep_none = variable_scope.get_variable("none", initializer=1e-10)
res = self._testDropoutWrapper(
- input_keep_prob=keep_all, output_keep_prob=keep_none,
+ input_keep_prob=keep_all,
+ output_keep_prob=keep_none,
state_keep_prob=keep_all)
true_full_output = np.array(
- [[[0.751109, 0.751109, 0.751109]],
- [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+ [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+ dtype=np.float32)
true_full_final_c = np.array(
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
self.assertAllClose(np.zeros(res[0].shape), res[0])
@@ -739,13 +759,13 @@ class DropoutWrapperTest(test.TestCase):
# Even though we dropout state, by default DropoutWrapper never
# drops out the memory ("c") term of an LSTMStateTuple.
res = self._testDropoutWrapper(
- input_keep_prob=keep_all, output_keep_prob=keep_all,
+ input_keep_prob=keep_all,
+ output_keep_prob=keep_all,
state_keep_prob=keep_none)
- true_c_state = np.array(
- [[1.713925, 1.713925, 1.713925]], dtype=np.float32)
+ true_c_state = np.array([[1.713925, 1.713925, 1.713925]], dtype=np.float32)
true_full_output = np.array(
- [[[0.751109, 0.751109, 0.751109]],
- [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+ [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+ dtype=np.float32)
self.assertAllClose(true_full_output[0], res[0][0])
# Second output is modified by zero input state
self.assertGreater(np.linalg.norm(true_full_output[1] - res[0][1]), 1e-4)
@@ -758,13 +778,14 @@ class DropoutWrapperTest(test.TestCase):
keep_all = variable_scope.get_variable("all", initializer=1.0)
keep_none = variable_scope.get_variable("none", initializer=1e-10)
true_full_output = np.array(
- [[[0.751109, 0.751109, 0.751109]],
- [[0.895509, 0.895509, 0.895509]]], dtype=np.float32)
+ [[[0.751109, 0.751109, 0.751109]], [[0.895509, 0.895509, 0.895509]]],
+ dtype=np.float32)
true_full_final_c = np.array(
[[1.949385, 1.949385, 1.949385]], dtype=np.float32)
# All outputs are different because inputs are zeroed out
res = self._testDropoutWrapper(
- input_keep_prob=keep_none, output_keep_prob=keep_all,
+ input_keep_prob=keep_none,
+ output_keep_prob=keep_all,
state_keep_prob=keep_all)
self.assertGreater(np.linalg.norm(res[0] - true_full_output), 1e-4)
self.assertGreater(np.linalg.norm(res[1].h - true_full_output[1]), 1e-4)
@@ -774,9 +795,13 @@ class DropoutWrapperTest(test.TestCase):
keep_some = 0.8
keep_all = variable_scope.get_variable("all", initializer=1.0)
res = self._testDropoutWrapper(
- input_keep_prob=keep_all, output_keep_prob=keep_some,
- state_keep_prob=keep_all, variational_recurrent=True,
- input_size=3, batch_size=5, time_steps=7)
+ input_keep_prob=keep_all,
+ output_keep_prob=keep_some,
+ state_keep_prob=keep_all,
+ variational_recurrent=True,
+ input_size=3,
+ batch_size=5,
+ time_steps=7)
# Ensure the same dropout pattern for all time steps
output_mask = np.abs(res[0]) > 1e-6
for m in output_mask[1:]:
@@ -785,9 +810,13 @@ class DropoutWrapperTest(test.TestCase):
def testDropoutWrapperRecurrentStateInputAndOutput(self):
keep_some = 0.9
res = self._testDropoutWrapper(
- input_keep_prob=keep_some, output_keep_prob=keep_some,
- state_keep_prob=keep_some, variational_recurrent=True,
- input_size=3, batch_size=5, time_steps=7)
+ input_keep_prob=keep_some,
+ output_keep_prob=keep_some,
+ state_keep_prob=keep_some,
+ variational_recurrent=True,
+ input_size=3,
+ batch_size=5,
+ time_steps=7)
# Smoke test for the state/input masks.
output_mask = np.abs(res[0]) > 1e-6
@@ -811,17 +840,27 @@ class DropoutWrapperTest(test.TestCase):
random_seed.set_random_seed(2347)
np.random.seed(23487)
res0 = self._testDropoutWrapper(
- input_keep_prob=keep_some, output_keep_prob=keep_some,
- state_keep_prob=keep_some, variational_recurrent=True,
- input_size=3, batch_size=5, time_steps=7, seed=-234987)
+ input_keep_prob=keep_some,
+ output_keep_prob=keep_some,
+ state_keep_prob=keep_some,
+ variational_recurrent=True,
+ input_size=3,
+ batch_size=5,
+ time_steps=7,
+ seed=-234987)
ops.reset_default_graph()
self._ClearCachedSession()
random_seed.set_random_seed(2347)
np.random.seed(23487)
res1 = self._testDropoutWrapper(
- input_keep_prob=keep_some, output_keep_prob=keep_some,
- state_keep_prob=keep_some, variational_recurrent=True,
- input_size=3, batch_size=5, time_steps=7, seed=-234987)
+ input_keep_prob=keep_some,
+ output_keep_prob=keep_some,
+ state_keep_prob=keep_some,
+ variational_recurrent=True,
+ input_size=3,
+ batch_size=5,
+ time_steps=7,
+ seed=-234987)
output_mask = np.abs(res0[0]) > 1e-6
for time_step in output_mask:
@@ -858,9 +897,10 @@ class SlimRNNCellTest(test.TestCase):
g, _ = rnn_cell_impl._SlimRNNCell(my_cell)(x, m)
# pylint: enable=protected-access
sess.run([variables_lib.global_variables_initializer()])
- res = sess.run(
- [g], {x.name: np.array([[1., 1.]]),
- m.name: np.array([[0.1, 0.1]])})
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellMatch(self):