diff options
author | Skye Wanderman-Milne <skyewm@google.com> | 2018-05-17 09:22:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-17 09:25:06 -0700 |
commit | 8e9681486efc504b940683a4d0306c273e6179db (patch) | |
tree | 63b197cdc40f8740a7055fac7dcb98d8d1c2b5fc /tensorflow/python/client | |
parent | 9d3e17b333288f6e1f99f6c62f5469356b3429a6 (diff) |
Update SessionTest.testFeedShapeCompatibility to work with C API enabled.
This test got lost in the transition. Prior to enabling the C API,
some constant node whose values were used for shape inference would be
marked as unfeedable in tensor_util.constant_value
(https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/python/framework/tensor_util.py#L810).
This shape inference path is no longer used with the C API enabled, so
the constant node is successfully fed, triggering a runtime shape error.
This is arguably a regression, but given that the Python code wouldn't
mark all nodes evaluated during shape inference as unfeedable, it
seems ok to relax the check a little more.
PiperOrigin-RevId: 197002741
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r-- | tensorflow/python/client/session_test.py | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index e9a7d9ac1d..482497078c 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1565,10 +1565,6 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertEquals(len(run_metadata.step_stats.dev_stats), 1) def testFeedShapeCompatibility(self): - # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable. - if ops._USE_C_API: - return - with session.Session() as sess: some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0]) new_shape = constant_op.constant([2, 2]) @@ -1577,7 +1573,10 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'): sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]}) - with self.assertRaisesRegexp(ValueError, 'may not be fed'): + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + 'Input to reshape is a tensor with 4 values, ' + 'but the requested shape has 21'): sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]}) def testInferShapesFalse(self): |