diff options
Diffstat (limited to 'tensorflow/python/client/session_test.py')
-rw-r--r-- | tensorflow/python/client/session_test.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index bf0a964867..a20376b91d 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1554,6 +1554,33 @@ class SessionTest(test_util.TensorFlowTestCase): sess.run(enqueue_op) self.assertEqual(sess.run(q.size()), num_epochs * 2) + def testRegisterFetchAndFeedConversionFunctions(self): + class SquaredTensor(object): + def __init__(self, tensor): + self.sq = math_ops.square(tensor) + + fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0]) + feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)] + feed_fn2 = lambda feed: [feed.sq] + + session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, + feed_fn1, feed_fn2) + with self.assertRaises(ValueError): + session.register_session_run_conversion_functions(SquaredTensor, + fetch_fn, feed_fn1, feed_fn2) + with self.test_session() as sess: + np1 = np.array([1.0, 1.5, 2.0, 2.5]) + np2 = np.array([3.0, 3.5, 4.0, 4.5]) + squared_tensor = SquaredTensor(np2) + squared_eval = sess.run(squared_tensor) + self.assertAllClose(np2 * np2, squared_eval) + squared_eval = sess.run(squared_tensor, feed_dict={ + squared_tensor : np1 * np1}) + self.assertAllClose(np1 * np1, squared_eval) + partial_run = sess.partial_run_setup([squared_tensor], []) + squared_eval = sess.partial_run(partial_run, squared_tensor) + self.assertAllClose(np2 * np2, squared_eval) + if __name__ == '__main__': googletest.main() |