# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.session_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import session_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class SessionOpsTest(test.TestCase):
  """Tests for creating, feeding, evaluating and deleting session handles.

  Every test builds a small graph inside `self.cached_session()`, obtains one
  or more handles via `session_ops.get_session_handle`, and then checks the
  values produced when those handles are evaluated or fed back into the graph.
  """

  def testHandleBasic(self):
    """A handle to 10 * 5 can be fed back and multiplied by 10 to give 500."""
    with self.cached_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

  def testHandleEval(self):
    """Calling `eval()` on a fetched handle returns the stored value (50)."""
    with self.cached_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Get the tensor from its handle.
      self.assertEqual(50, h.eval())

  def testHandleAndValue(self):
    """A handle and a plain value can be fetched in the same `run` call."""
    with self.cached_session() as sess:
      # Return a handle and a value.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      v = math_ops.multiply(a, c)
      h, v = sess.run([h, v])

      self.assertEqual(50, h.eval())
      self.assertEqual(500, v)

  def testHandleCond(self):
    """A fetched predicate picks, in Python, which op consumes the handle.

    `less(10, 5)` is false, so the `x * 100` branch is built and the fed
    handle (50) yields 5000.
    """
    with self.cached_session() as sess:
      # Return a handle and a value
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      p = math_ops.less(a, b)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      p, h = sess.run([p, h])

      # Run by feeding a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      if p:
        y = math_ops.multiply(x, 10)
      else:
        y = math_ops.multiply(x, 100)
      result = sess.run(y, feed_dict={f: h.handle})

      self.assertEqual(5000, result)

  def testHandleForLoop(self):
    """Incrementing through a handle 100 times from 0 yields 100."""
    with self.cached_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      for _ in range(100):
        # This exercises garbage collection.
        h = sess.run(h_x, feed_dict={f: h.handle})

      self.assertEqual(100, h.eval())

  def testHandleWhileLoop(self):
    """A Python `while` loop driven by a fetched predicate ends at 101.

    Each step fetches both `x < 100` and a handle to `x + 1`; the step that
    sees x == 100 still produces the handle for 101 before the loop breaks.
    """
    with self.cached_session() as sess:
      # Initialize a handle.
      a = constant_op.constant(0)
      h = session_ops.get_session_handle(a)
      h = sess.run(h)

      # Do some computation.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      b = constant_op.constant(100)
      p = math_ops.less(x, b)
      # Must define the loop body outside the loop.
      h_x = session_ops.get_session_handle(math_ops.add(x, 1))
      while True:
        rp, h = sess.run([p, h_x], feed_dict={f: h.handle})
        if not rp:
          break

      self.assertEqual(101, h.eval())

  def testHandleMover(self):
    """A handle created under `test.gpu_device_name()` feeds the same input.

    The placeholder `f` is created for the first handle; a second handle,
    produced inside an `ops.device(test.gpu_device_name())` scope, is then fed
    through that same placeholder.
    """
    with self.cached_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Feed a tensor handle.
      f, x = session_ops.get_session_tensor(h.handle, dtypes.int32)
      y = math_ops.multiply(x, 10)
      self.assertEqual(500, sess.run(y, feed_dict={f: h.handle}))

      # Feed another tensor handle.
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(10)
        h = session_ops.get_session_handle(a)
        h = sess.run(h)
        self.assertEqual(100, sess.run(y, feed_dict={f: h.handle}))

  def testHandleDelete(self):
    """Calling `delete()` on a freshly fetched handle does not raise."""
    with self.cached_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      sess.run(h).delete()

  def testHandleDeleteRaw(self):
    """Deleting via `delete_session_tensor` on a raw handle does not raise."""
    with self.cached_session() as sess:
      # Return a handle.
      a = constant_op.constant(10)
      b = constant_op.constant(5)
      c = math_ops.multiply(a, b)
      h = session_ops.get_session_handle(c)
      h = sess.run(h)

      # Delete using a raw tensor handle.
      raw_h = h.get_raw_handle()
      f, x = session_ops.delete_session_tensor(raw_h)
      sess.run(x, feed_dict={f: raw_h})

  def testMultiDevices(self):
    """Handles made under different device scopes can be added (1 + 2 = 3)."""
    with self.cached_session() as sess:
      with ops.device(test.gpu_device_name()):
        a = constant_op.constant(1.0)
        a_handle = sess.run(session_ops.get_session_handle(a))
      with ops.device("/cpu:0"):
        b = constant_op.constant(2.0)
        b_handle = sess.run(session_ops.get_session_handle(b))

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)
      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())

  def testHandleGC(self):
    """Repeatedly replacing a handle keeps the accumulated value correct.

    Starts from two handles to 1.0 created under "/cpu:0", then 20 times
    rebinds `x_handle` to a handle of `one + x` computed under
    `test.gpu_device_name()`. Each rebinding drops the previous handle, and
    the final handle must still evaluate to 1.0 + 20 * 1.0.
    """
    with self.cached_session() as sess:
      # initial values live on CPU
      with ops.device("/cpu:0"):
        one = constant_op.constant(1, dtype=dtypes.float32)
        one_handle = sess.run(session_ops.get_session_handle(one))
        x_handle = sess.run(session_ops.get_session_handle(one))

      # addition lives on GPU
      with ops.device(test.gpu_device_name()):
        add_h1, add_t1 = session_ops.get_session_tensor(one_handle.handle,
                                                        dtypes.float32)
        add_h2, add_t2 = session_ops.get_session_tensor(x_handle.handle,
                                                        dtypes.float32)
        add_op = math_ops.add(add_t1, add_t2)
        add_output = session_ops.get_session_handle(add_op)

      # add 1 to tensor 20 times
      for _ in range(20):
        x_handle = sess.run(
            add_output,
            feed_dict={add_h1: one_handle.handle,
                       add_h2: x_handle.handle})

      # Without this check the loop result was never inspected, so a wrong
      # accumulated value would have passed silently.
      self.assertEqual(21.0, x_handle.eval())

  def testHandlePlacement(self):
    """Handles created without explicit device scopes can be added."""
    with self.cached_session() as sess:
      a = constant_op.constant(1.0)
      a_handle_op = session_ops.get_session_handle(a)
      b = constant_op.constant(2.0)
      b_handle_op = session_ops.get_session_handle(b)

      a_handle = sess.run(a_handle_op)
      b_handle = sess.run(b_handle_op)

      a_p, a_t = session_ops.get_session_tensor(a_handle.handle, dtypes.float32)
      b_p, b_t = session_ops.get_session_tensor(b_handle.handle, dtypes.float32)

      c = math_ops.add(a_t, b_t)
      c_handle = sess.run(
          session_ops.get_session_handle(c),
          feed_dict={a_p: a_handle.handle,
                     b_p: b_handle.handle})
      self.assertEqual(3.0, c_handle.eval())

  def testFeedOneHandleDirectly(self):
    """A handle object can be used directly as the feed value for a tensor."""
    with self.cached_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      d = math_ops.multiply(c, c)

      h_c = sess.run(session_ops.get_session_handle(c))

      self.assertAllClose(2500.0, sess.run(d, feed_dict={c: h_c}))

  def testDirectHandleFeedOverlappingWithFetches(self):
    """Direct handle feeds work when the fed tensors are also fetched.

    Covers feeding `c` only, and feeding both `c` and `d = identity(c)` with
    every mix of handle and plain value.
    """
    with self.cached_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      h_c = sess.run(session_ops.get_session_handle(c))
      d = array_ops.identity(c)

      c_val = sess.run(c, feed_dict={c: h_c})
      self.assertAllClose(50.0, c_val)

      d_val = sess.run(d, feed_dict={c: h_c})
      self.assertAllClose(50.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: 60.0})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(60.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: 60.0, d: h_c})
      self.assertAllClose(60.0, c_val)
      self.assertAllClose(50.0, d_val)

      c_val, d_val = sess.run([c, d], feed_dict={c: h_c, d: h_c})
      self.assertAllClose(50.0, c_val)
      self.assertAllClose(50.0, d_val)

  def testFeedTwoHandlesDirectly(self):
    """Two handles can be fed directly, including swapped between tensors.

    With c = 50 and d = 2, `c - d` is 48; feeding each handle into the other
    tensor gives 2 - 50 = -48.
    """
    with self.cached_session() as sess:
      a = constant_op.constant(10.0)
      b = constant_op.constant(5.0)
      c = math_ops.multiply(a, b)
      d = math_ops.div(a, b)
      e = math_ops.subtract(c, d)

      h_c = sess.run(session_ops.get_session_handle(c))
      h_d = sess.run(session_ops.get_session_handle(d))

      self.assertAllClose(48.0, sess.run(e, feed_dict={c: h_c, d: h_d}))
      self.assertAllClose(-48.0, sess.run(e, feed_dict={c: h_d, d: h_c}))

  def testFeedHandleToVariableDirectly(self):
    """A handle to a variable's read value can be fed keyed by the variable."""
    with self.cached_session() as sess:
      a = variables.Variable(12.0)
      inc_a = state_ops.assign_add(a, 2.0)
      b = math_ops.add(a, 5.0)
      sess.run(a.initializer)

      h_a_read = sess.run(session_ops.get_session_handle(a.read_value()))
      self.assertAllClose(12.0, sess.run(a))

      self.assertAllClose(17.0, sess.run(b, feed_dict={a: h_a_read}))
      sess.run(inc_a)
      # NOTE(review): 19.0 == 14.0 + 5.0, i.e. the expectation follows the
      # incremented variable rather than the 12.0 captured in `h_a_read`.
      # Confirm this is the intended contract before changing either value.
      self.assertAllClose(19.0, sess.run(b, feed_dict={a: h_a_read}))


if __name__ == "__main__":
  test.main()