diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-11 20:19:50 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-09-11 20:19:50 +0800 |
commit | b2896c3cc3a0656b838f58975338d7dd309e3e62 (patch) | |
tree | 14f25741ab43c15e945e6044833c0ff44f11d83f /tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | |
parent | 38f811077dd52820eaa3d5c684f41142de01c7eb (diff) | |
parent | e18f84a394bcbde62b344a3b32e8d8fd248fea58 (diff) |
Merge remote-tracking branch 'upstream/master' into ENH/div_no_nan_treate_negative_as_zero
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 73 |
1 files changed, 36 insertions, 37 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 15ce9d1ce7..be0306cb07 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -48,7 +48,7 @@ Linear = core_rnn_cell._Linear # pylint: disable=invalid-name class RNNCellTest(test.TestCase): def testLinear(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(1.0)): x = array_ops.zeros([1, 2]) @@ -69,7 +69,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(variables_lib.trainable_variables()), 2) def testBasicRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -89,7 +89,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testBasicRNNCellNotTrainable(self): - with self.test_session() as sess: + with self.cached_session() as sess: def not_trainable_getter(getter, *args, **kwargs): kwargs["trainable"] = False @@ -116,7 +116,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testIndRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -137,7 +137,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(res[0].shape, (1, 2)) def testGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -165,7 +165,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.156736, 0.156736]]) def testIndyGRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -193,7 +193,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.155127, 0.157328]]) def testSRUCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -208,7 +208,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.509682, 0.509682]]) def testSRUCellWithDiffSize(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -288,7 +288,7 @@ class RNNCellTest(test.TestCase): def testBasicLSTMCellDimension0Error(self): """Tests that dimension 0 in both(x and m) shape must be equal.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 @@ -309,7 +309,7 @@ class RNNCellTest(test.TestCase): def testBasicLSTMCellStateSizeError(self): """Tests that state_size must be num_units * 2.""" - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 @@ -329,7 +329,7 @@ class RNNCellTest(test.TestCase): }) def testBasicLSTMCellStateTupleType(self): - with self.test_session(): + with self.cached_session(): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -360,7 +360,7 @@ class RNNCellTest(test.TestCase): self.assertTrue(isinstance(out_m1, rnn_cell_impl.LSTMStateTuple)) def testBasicLSTMCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -459,7 +459,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(len(res), 2) def testLSTMCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 8 num_proj = 6 state_size = num_units + num_proj @@ -494,7 +494,7 @@ class RNNCellTest(test.TestCase): float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) > 1e-6) def testLSTMCellVariables(self): - with self.test_session(): + with self.cached_session(): num_units = 8 num_proj = 6 state_size = num_units + num_proj @@ -517,7 +517,7 @@ class RNNCellTest(test.TestCase): "root/lstm_cell/projection/kernel") def testLSTMCellLayerNorm(self): - with self.test_session() as sess: + with self.cached_session() as sess: num_units = 2 num_proj = 3 batch_size = 1 @@ -562,22 +562,21 @@ class RNNCellTest(test.TestCase): rnn_cell_impl.DropoutWrapper, rnn_cell_impl.ResidualWrapper, lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: - with self.test_session(): - cell = rnn_cell_impl.BasicRNNCell(1) - wrapper = wrapper_type(cell) - wrapper(array_ops.ones([1, 1]), - state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) - self.evaluate([v.initializer for v in cell.variables]) - checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(cell._bias.assign([40.])) - save_path = checkpoint.save(prefix) - self.evaluate(cell._bias.assign([0.])) - checkpoint.restore(save_path).assert_consumed().run_restore_ops() - self.assertAllEqual([40.], self.evaluate(cell._bias)) + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) def testOutputProjectionWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -594,7 +593,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.231907, 0.231907]]) def testInputProjectionWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -612,7 +611,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]]) def testResidualWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) @@ -638,7 +637,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[2], res[3]) def testResidualWrapperWithSlice(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 5]) @@ -716,7 +715,7 @@ class RNNCellTest(test.TestCase): self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) def testEmbeddingWrapper(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 1], dtype=dtypes.int32) @@ -735,7 +734,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.17139, 0.17139]]) def testEmbeddingWrapperWithDynamicRnn(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope("root"): inputs = ops.convert_to_tensor([[[0], [0]]], dtype=dtypes.int64) input_lengths = ops.convert_to_tensor([2], dtype=dtypes.int64) @@ -753,7 +752,7 @@ class RNNCellTest(test.TestCase): sess.run(outputs) def testMultiRNNCell(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -770,7 +769,7 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res, [[0.175991, 0.175991, 0.13248, 0.13248]]) def testMultiRNNCellWithStateTuple(self): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 2]) @@ -809,7 +808,7 @@ class DropoutWrapperTest(test.TestCase): time_steps=None, parallel_iterations=None, **kwargs): - with self.test_session() as sess: + with self.cached_session() as sess: with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): if batch_size is None and time_steps is None: |