diff options
author | 2017-02-28 16:40:43 -0800 | |
---|---|---|
committer | 2017-02-28 17:02:09 -0800 | |
commit | 4ee4d1de7cf8c4c6f146a5b81baf288df17c944f (patch) | |
tree | 4f51f760c20c32fd00049fc5ee03801642e5f994 | |
parent | 676f94a952df76d36a134a5e3b92285fabb72cc8 (diff) |
Cleaner error messages for RNNCell scope failures.
Change: 148832474
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py | 6 |
2 files changed, 26 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 8b0a9a2bbf..b38f08d4d5 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -344,6 +344,29 @@ class RNNCellTest(test.TestCase): outputs, _ = cell(x, m) self.assertTrue("cpu:14159" in outputs.device.lower()) + def testUsingSecondCellInScopeWithExistingVariablesFails(self): + # This test should go away when this behavior is no longer an + # error (Approx. May 2017) + cell1 = core_rnn_cell_impl.LSTMCell(3) + cell2 = core_rnn_cell_impl.LSTMCell(3) + x = array_ops.zeros([1, 3]) + m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2) + cell1(x, m) + with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"): + cell2(x, m) + + def testUsingCellInDifferentScopeFromFirstCallFails(self): + # This test should go away when this behavior is no longer an + # error (Approx. May 2017) + cell = core_rnn_cell_impl.LSTMCell(3) + x = array_ops.zeros([1, 3]) + m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2) + with variable_scope.variable_scope("scope1"): + cell(x, m) + with variable_scope.variable_scope("scope2"): + with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"): + cell(x, m) + def testDropoutWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py index 75c8cdc6ae..148d2ef93a 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py @@ -81,11 +81,11 @@ def _checked_scope(cell, scope, reuse=None, **kwargs): if weights_found and reuse is None: raise ValueError( "Attempt to have a second RNNCell use the weights of a variable " - "scope that already has weights: '%s' (and RNNCell was not " - "constructed with reuse=True). " + "scope that already has weights: '%s'; and the cell was not " + "constructed as %s(..., reuse=True). " "To share the weights of an RNNCell, simply " "reuse it in your second calculation, or create a new one with " - "the argument reuse=True." % scope_name) + "the argument reuse=True." % (scope_name, type(cell).__name__)) # Everything is OK. Update the cell's scope and yield it. cell._scope = checking_scope # pylint: disable=protected-access |