aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-04-26 10:30:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-26 10:33:31 -0700
commitf63a8d6aaf251344631272d6b38327481f54fe55 (patch)
tree34aee2c0988b63304e908703119cb03fb0f0d08c /tensorflow/contrib/cudnn_rnn
parentf495e321026683359fac213b82a20f597d4ead2a (diff)
Remove "everything matched" assertions from CuDNN object-based checkpointing tests
After cl/194315742 the assertions correctly point out that there are some Python objects which aren't matched (they don't have variables). Another option would be to mark these as special/optional, which we can implement if there's a need. PiperOrigin-RevId: 194416864
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 012b17cee8..33ddfb8dee 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -717,7 +717,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
inputs = 3. * array_ops.ones([num_applications, num_layers, input_size],
dtype=dtypes.float32)
cudnn_output, _ = cudnn_layer(inputs)
- status.assert_consumed().run_restore_ops()
+ status.run_restore_ops()
second_save_path = cudnn_checkpoint.save(checkpoint_prefix)
restore_layer = compatible_cell_fn()
restore_layer_checkpoint = checkpointable_utils.Checkpoint(
@@ -728,7 +728,7 @@ class CudnnRNNTestSaveRestoreCheckpointable(test_util.TensorFlowTestCase):
restore_layer_output, current_state = restore_layer(
inputs=3. * array_ops.ones([1, input_size]),
state=current_state)
- status.assert_consumed().run_restore_ops()
+ status.run_restore_ops()
self.assertTrue(restore_layer.variables)
for variable, expected_value in zip(
restore_layer.variables, expected_variable_values):