aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/session_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/client/session_test.py')
-rw-r--r--tensorflow/python/client/session_test.py555
1 files changed, 555 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
new file mode 100644
index 0000000000..4492840dcf
--- /dev/null
+++ b/tensorflow/python/client/session_test.py
@@ -0,0 +1,555 @@
+"""Tests for tensorflow.python.client.session.Session."""
+import threading
+import time
+
+import tensorflow.python.platform
+
+import numpy as np
+
+from tensorflow.core.framework import config_pb2
+from tensorflow.core.lib.core import error_codes_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.framework import types
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+# NOTE(mrry): Dummy shape registration for op used in the tests.
+ops.RegisterShape('ConstructionFails')(None)
+
+
+class SessionTest(test_util.TensorFlowTestCase):
+
+ def testUseExistingGraph(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(6.0, shape=[1, 1])
+ b = constant_op.constant(7.0, shape=[1, 1])
+ c = math_ops.matmul(a, b, name='matmul')
+ with session.Session(graph=g):
+ result = c.eval()
+ self.assertAllEqual(result, [[42.0]])
+
+ def testUseDefaultGraph(self):
+ with ops.Graph().as_default(), ops.device('/cpu:0'):
+ a = constant_op.constant(6.0, shape=[1, 1])
+ b = constant_op.constant(7.0, shape=[1, 1])
+ c = math_ops.matmul(a, b, name='matmul')
+ with session.Session():
+ result = c.eval()
+ self.assertAllEqual(result, [[42.0]])
+
+ def testCreate(self):
+ with session.Session():
+ inp = constant_op.constant(10.0, name='W1')
+ copy = array_ops.identity(inp)
+ # Test with feed.
+ # TODO(mrry): Investigate why order='F' didn't work.
+ arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
+ copy_val = copy.eval({'W1:0': arr})
+ self.assertAllEqual(arr, copy_val)
+ # Test without feed.
+ copy_val = copy.eval()
+ self.assertAllEqual(np.asarray(10.0, dtype=np.float32), copy_val)
+
+ def testManyCPUs(self):
+ # TODO(keveman): Implement ListDevices and test for the number of
+ # devices returned by ListDevices.
+ with session.Session(
+ config=config_pb2.ConfigProto(device_count={'CPU': 2})):
+ inp = constant_op.constant(10.0, name='W1')
+ self.assertAllEqual(inp.eval(), 10.0)
+
+ def testErrorsReported(self):
+ with session.Session() as s:
+ constant_op.constant(10.0, name='W1')
+ with self.assertRaises(ValueError):
+ s.run('foo:0')
+
+ def testErrorPayload(self):
+ with session.Session():
+ a = array_ops.placeholder(types.float32)
+ with self.assertRaisesOpError(lambda e: e.op == a.op):
+ a.eval()
+
+ def testOpConstructionErrorPayload(self):
+ with session.Session():
+ failing_op = ops.get_default_graph().create_op(
+ 'ConstructionFails', [], [], name='f')
+
+ def exc_predicate(e):
+ return (e.op == failing_op
+ and e.error_code == error_codes_pb2.INVALID_ARGUMENT)
+ with self.assertRaisesOpError(exc_predicate):
+ failing_op.run()
+
+ def testErrorBasedOn(self):
+ with session.Session() as sess:
+ a = constant_op.constant(0.0, shape=[2, 3])
+ # NOTE(mrry): The original_op is nonsense, but used here to test that the
+ # errors are reported correctly.
+ # pylint: disable=protected-access
+ with sess.graph._original_op(a.op):
+ b = array_ops.identity(a, name='id')
+ with sess.graph._original_op(b.op):
+ c = array_ops.placeholder(types.float32)
+ # pylint: enable=protected-access
+
+ def exc_predicate(e):
+ return (e.op == c.op
+ and e.op._original_op == b.op
+ and e.op._original_op._original_op == a.op)
+ with self.assertRaisesOpError(exc_predicate):
+ c.eval()
+
+ def testFetchTensorObject(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ results_with_list = s.run([c])
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
+ results_with_single = s.run(c)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
+ results_with_get = c.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
+ a_val, b_val = s.run([a, b]) # Test multiple fetches.
+ self.assertAllEqual([[1.0, 1.0]], a_val)
+ self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
+
+ def testFetchScalar(self):
+ with session.Session() as s:
+ for scalar in np.int32, np.int64, np.float32, np.float64:
+ x = scalar(7)
+ y = scalar(8)
+ tf_x = constant_op.constant(x, shape=[])
+ tf_y = constant_op.constant(y)
+ tf_xy = math_ops.add(tf_x, tf_y)
+ # Single fetch
+ xy = s.run(tf_xy)
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
+ # List fetch
+ xy, = s.run([tf_xy])
+ self.assertEqual(scalar, type(xy))
+ self.assertEqual(x + y, xy)
+
+ def testFetchOperationObject(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ v = variables.Variable(a, name='testFetchOperationObject_v')
+ s.run(v.initializer)
+ v_val = s.run(v)
+ self.assertAllEqual([[1.0, 1.0]], v_val)
+
+ def testFetchSparseTensor(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = ops.SparseTensor(
+ constant_op.constant(indices),
+ constant_op.constant(values),
+ constant_op.constant(shape))
+ # Single fetch, use as tuple
+ sp_out = s.run(sp)
+ indices_out, values_out, shape_out = sp_out
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Single fetch, use as SparseTensorValue
+ sp_out = s.run(sp)
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
+ # Tuple fetch, use as tuple
+ indices_out, values_out, shape_out = s.run(sp)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # List fetch, use as tuple
+ (indices_out, values_out, shape_out), = s.run([sp])
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # List fetch, use as SparseTensorValue
+ sp_out, = s.run([sp])
+ self.assertAllEqual(sp_out.indices, indices)
+ self.assertAllEqual(sp_out.values, values)
+ self.assertAllEqual(sp_out.shape, shape)
+
+ def testFeedSparseTensor(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ shape = np.array([7, 9, 2]).astype(np.int64)
+ sp = ops.SparseTensor(
+ array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
+ array_ops.placeholder(dtype=np.float32, shape=(2,)),
+ array_ops.placeholder(dtype=np.int64, shape=(3,)),)
+ sp_indices = array_ops.identity(sp.indices)
+ sp_values = array_ops.identity(sp.values)
+ sp_shape = array_ops.identity(sp.shape)
+ sp2 = ops.SparseTensor(sp_indices, sp_values, sp_shape)
+ # Feed with tuple
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape], {sp: (indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue
+ indices_out, values_out, shape_out = s.run(
+ [sp_indices, sp_values, sp_shape],
+ {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(shape_out, shape)
+ # Feed with SparseTensorValue, fetch SparseTensorValue
+ sp2_out = s.run(sp2, {sp: ops.SparseTensorValue(indices, values, shape)})
+ self.assertAllEqual(sp2_out.indices, indices)
+ self.assertAllEqual(sp2_out.values, values)
+ self.assertAllEqual(sp2_out.shape, shape)
+
+ def testExtendWithStatelessOperations(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ c_val = s.run(c)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
+ d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
+ e = math_ops.matmul(c, d)
+ # Extend will happen here.
+ e_val = s.run(e)
+ self.assertAllEqual([[24.0]], e_val)
+
+ def testExtendWithStatefulOperations(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
+ v.initializer.run()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ # Extend will happen here.
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+
+ def testExtendWithGroupBy(self):
+ with session.Session() as s:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ p = variables.Variable(a, name='testExtendWithGroupBy_p')
+ a_val = a.eval() # Force an Extend after this op.
+ self.assertAllEqual([[1.0, 1.0]], a_val)
+
+ b = constant_op.constant(2.0, shape=[1, 2])
+ q = variables.Variable(b, name='testExtendWithGroupBy_q')
+ # Extend will happen here.
+ init = control_flow_ops.group(p.initializer, q.initializer)
+ s.run(init)
+ p_val, q_val = s.run([p, q])
+
+ self.assertAllEqual([[1.0, 1.0]], p_val)
+ self.assertAllEqual([[2.0, 2.0]], q_val)
+
+ def testTensorGetMethod(self):
+ with session.Session():
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+
+ c_val = c.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
+
+ fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
+ self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
+
+ def testOperationRunMethod(self):
+ with session.Session():
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[1, 2], name='b')
+ v = variables.Variable(a, a.dtype)
+ assign_a_to_v = state_ops.assign(v, a)
+
+ assign_a_to_v.eval()
+
+ v_val = v.eval()
+ self.assertAllEqual([[1.0, 1.0]], v_val)
+
+ assign_b_to_v = state_ops.assign(v, b)
+
+ assign_b_to_v.eval()
+ v_val = v.eval()
+ self.assertAllEqual([[2.0, 2.0]], v_val)
+
+ assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
+ v_val = v.eval()
+ self.assertAllEqual([[3.0, 3.0]], v_val)
+
+ def testDefaultGraph(self):
+ with session.Session() as s:
+ self.assertEqual(ops.get_default_graph(), s.graph)
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ self.assertEqual(ops.get_default_graph(), a.graph)
+ self.assertEqual(ops.get_default_graph(), b.graph)
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='testDefaultGraph_v')
+ v.initializer.run()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+ self.assertEqual(ops.get_default_graph(), s.graph)
+
+ def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
+ with session.Session() as s:
+ self.assertEqual(ops.get_default_graph(), s.graph)
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ v = variables.Variable(c, name='var_%d' % i)
+
+ # Block here until all threads have constructed their graph.
+ constructed_event.set()
+ continue_event.wait()
+
+ assign_c_to_v = state_ops.assign(v, c)
+ v.initializer.run()
+ assign_c_to_v.eval()
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ d = constant_op.constant(3.0, shape=[2, 3])
+ e = math_ops.matmul(a, d)
+ assign_e_to_v = state_ops.assign(v, e)
+ e_val = e.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
+ v_val = v.eval()
+ self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
+ s.run(assign_e_to_v)
+ v_val = v.eval()
+ self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
+ self.assertEqual(ops.get_default_graph(), s.graph)
+
+ def testDefaultGraphWithThreads(self):
+ # Fork ten threads that use their thread-local default graph.
+ threads = []
+ constructed_events = [threading.Event() for _ in range(10)]
+ continue_event = threading.Event()
+ for i, constructed_event in enumerate(constructed_events):
+ t = self.checkedThread(target=self._testDefaultGraphInThread,
+ args=(constructed_event, continue_event, i))
+ threads.append(t)
+ for t in threads:
+ t.start()
+ for constructed_event in constructed_events:
+ constructed_event.wait()
+ continue_event.set()
+ for t in threads:
+ t.join()
+
+ def testParallelRun(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ ev = threading.Event()
+
+ def run_step():
+ ev.wait()
+ val = c.eval(session=sess)
+ self.assertEqual(val, 5.0)
+ threads = [self.checkedThread(target=run_step) for _ in range(100)]
+ for t in threads:
+ t.start()
+ ev.set()
+ for t in threads:
+ t.join()
+
+ def testRunFeedDict(self):
+ with session.Session() as s:
+ x = array_ops.zeros([2])
+
+ y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
+ self.assertAllEqual(y, 2 * np.ones(2))
+
+ y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
+ self.assertAllEqual(y, 2 * np.ones(2))
+
+ y = s.run(2 * x, feed_dict={x: [1, 1]})
+ assert (y == 2 * np.ones(2)).all()
+
+ def testGraphDef(self):
+ with session.Session() as sess:
+ self.assertProtoEquals('', sess.graph_def)
+ c = constant_op.constant(5.0, name='c')
+ self.assertEquals(len(sess.graph_def.node), 1)
+ d = constant_op.constant(6.0, name='d')
+ self.assertEquals(len(sess.graph_def.node), 2)
+ self.assertAllEqual(c.eval(), 5.0)
+ self.assertAllEqual(d.eval(), 6.0)
+ e = constant_op.constant(7.0, name='e')
+ self.assertEquals(len(sess.graph_def.node), 3)
+ self.assertAllEqual(e.eval(), 7.0)
+
+ def testUseAfterClose(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ self.assertAllEqual(sess.run(c), 5.0)
+ with self.assertRaisesWithPredicateMatch(
+ RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
+ sess.run(c)
+
+ def testUseAfterCloseConcurrent(self):
+ with session.Session() as sess:
+ c = constant_op.constant(5.0)
+ self.assertAllEqual(sess.run(c), 5.0)
+
+ def update_thread():
+ with self.assertRaisesWithPredicateMatch(
+ RuntimeError,
+ lambda e: 'Attempted to use a closed Session.' in str(e)):
+ while True:
+ sess.run(c)
+ t = threading.Thread(target=update_thread)
+ t.start()
+ time.sleep(0.1)
+ sess.close()
+ t.join()
+
+ def testNotEntered(self):
+ # pylint: disable=protected-access
+ self.assertEqual(ops._default_session_stack.get_default(), None)
+ # pylint: enable=protected-access
+ with ops.device('/cpu:0'):
+ sess = session.Session()
+ c_1 = constant_op.constant(5.0)
+ with sess.graph.as_default():
+ c_2 = constant_op.constant(5.0)
+ self.assertEqual(c_1.graph, c_2.graph)
+ self.assertEqual(sess.run(c_2), 5.0)
+ with self.assertRaisesWithPredicateMatch(
+ ValueError, lambda e: 'No default session is registered.' in str(e)):
+ c_2.eval()
+
+ def testInteractive(self):
+ with ops.device('/cpu:0'):
+ sess = session.InteractiveSession()
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+ self.assertAllEqual([[4.0, 4.0, 4.0]], c.eval())
+ d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
+ e = math_ops.matmul(c, d)
+ self.assertAllEqual([[24.0]], e.eval())
+ sess.close()
+
+ def testSharedGraph(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[2, 3])
+ c = math_ops.matmul(a, b)
+
+ with session.Session(graph=g) as sess1:
+ with session.Session(graph=g) as sess2:
+ self.assertAllEqual(sess1.run(c), sess2.run(c))
+
+ def testDuplicatedInputs(self):
+ with session.Session() as sess:
+ a = constant_op.constant(1.0, shape=[1, 2])
+ b = constant_op.constant(2.0, shape=[1, 3])
+ a_val, b_val, a2_val = sess.run([a, b, a])
+ self.assertAllEqual(a_val, [[1.0, 1.0]])
+ self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
+ self.assertAllEqual(a2_val, [[1.0, 1.0]])
+
+ def testFeedAndFetch(self):
+ with session.Session():
+ for dtype in [types.float32,
+ types.float64,
+ types.int32,
+ types.uint8,
+ types.int16,
+ types.int8,
+ types.int64,
+ types.bool,
+ types.complex64]:
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ np_dtype = dtype.as_numpy_dtype
+
+ feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
+ out_t = array_ops.identity(feed_t)
+
+ np_array = np.random.randint(-10, 10, shape)
+
+ if dtype == types.bool:
+ np_array = np_array > 0
+ elif dtype == types.complex64:
+ np_array = np.sqrt(np_array.astype(np_dtype))
+ else:
+ np_array = np_array.astype(np_dtype)
+
+ self.assertAllEqual(np_array,
+ out_t.eval(feed_dict={feed_t: np_array}))
+
+ def testStringFetch(self):
+ with session.Session():
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ size = 1
+ for s in shape:
+ size *= s
+ c_list = np.array([str(i) for i in xrange(size)],
+ dtype=np.object).reshape(shape) if size > 0 else []
+ c = constant_op.constant(c_list)
+ self.assertAllEqual(c.eval(), c_list)
+
+ def testStringFeed(self):
+ with session.Session():
+ for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
+ size = 1
+ for s in shape:
+ size *= s
+ c_list = np.array([str(i) for i in xrange(size)],
+ dtype=np.object).reshape(shape)
+ feed_t = array_ops.placeholder(dtype=types.string, shape=shape)
+ c = array_ops.identity(feed_t)
+ self.assertAllEqual(c.eval(feed_dict={feed_t: c_list}), c_list)
+
+ def testStringFeedWithNullCharacters(self):
+ with session.Session():
+ c_list = ['\n\x01\x00', '\n\x00\x01']
+ feed_t = array_ops.placeholder(dtype=types.string, shape=[2])
+ c = array_ops.identity(feed_t)
+ out = c.eval(feed_dict={feed_t: c_list})
+ self.assertEqual(c_list[0], out[0])
+ self.assertEqual(c_list[1], out[1])
+
+ def testInvalidTargetFails(self):
+ with self.assertRaises(RuntimeError):
+ session.Session("INVALID_TARGET")
+
+
+if __name__ == '__main__':
+ googletest.main()