aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client
diff options
context:
space:
mode:
authorGravatar Skye Wanderman-Milne <skyewm@google.com>2018-05-17 09:22:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-17 09:25:06 -0700
commit8e9681486efc504b940683a4d0306c273e6179db (patch)
tree63b197cdc40f8740a7055fac7dcb98d8d1c2b5fc /tensorflow/python/client
parent9d3e17b333288f6e1f99f6c62f5469356b3429a6 (diff)
Update SessionTest.testFeedShapeCompatibility to work with C API enabled.
This test got lost in the transition. Prior to enabling the C API, some constant node whose values were used for shape inference would be marked as unfeedable in tensor_util.constant_value (https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/python/framework/tensor_util.py#L810). This shape inference path is no longer used with the C API enabled, so the constant node is successfully fed, triggering a runtime shape error. This is arguably a regression, but given that the Python code wouldn't mark all nodes evaluated during shape inference as unfeedable, it seems ok to relax the check a little more. PiperOrigin-RevId: 197002741
Diffstat (limited to 'tensorflow/python/client')
-rw-r--r--tensorflow/python/client/session_test.py9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index e9a7d9ac1d..482497078c 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1565,10 +1565,6 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertEquals(len(run_metadata.step_stats.dev_stats), 1)
def testFeedShapeCompatibility(self):
- # TODO(nolivia): C API doesn't yet handle marking nodes as not feedable.
- if ops._USE_C_API:
- return
-
with session.Session() as sess:
some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
new_shape = constant_op.constant([2, 2])
@@ -1577,7 +1573,10 @@ class SessionTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'Cannot feed value of shape'):
sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
- with self.assertRaisesRegexp(ValueError, 'may not be fed'):
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ 'Input to reshape is a tensor with 4 values, '
+ 'but the requested shape has 21'):
sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
def testInferShapesFalse(self):