aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/client/session_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/client/session_test.py')
-rw-r--r--tensorflow/python/client/session_test.py138
1 files changed, 138 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index a1a9c57b1c..8073b76b3f 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -248,6 +248,144 @@ class SessionTest(test_util.TensorFlowTestCase):
self.assertAllEqual(sp2_out.values, values)
self.assertAllEqual(sp2_out.shape, shape)
+ def testFetchIndexedSlices(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ dense_shape = np.array([7, 9, 2]).astype(np.int64)
+ ind = ops.IndexedSlices(
+ constant_op.constant(values), constant_op.constant(indices),
+ constant_op.constant(dense_shape))
+ # Single fetch, use as tuple
+ ind_out = s.run(ind)
+ values_out, indices_out, dense_shape_out = ind_out
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # Single fetch, use as IndexedSlicesValue
+ ind_out = s.run(ind)
+ self.assertAllEqual(ind_out.values, values)
+ self.assertAllEqual(ind_out.indices, indices)
+ self.assertAllEqual(ind_out.dense_shape, dense_shape)
+ # Tuple fetch, use as tuple
+ values_out, indices_out, dense_shape_out = s.run(ind)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # List fetch, use as tuple
+ (values_out, indices_out, dense_shape_out), = s.run([ind])
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # List fetch, use as IndexedSlicesValue
+ ind_out, = s.run([ind])
+ self.assertAllEqual(ind_out.values, values)
+ self.assertAllEqual(ind_out.indices, indices)
+ self.assertAllEqual(ind_out.dense_shape, dense_shape)
+
+ def testFeedIndexedSlices(self):
+ with session.Session() as s:
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ dense_shape = np.array([7, 9, 2]).astype(np.int64)
+ ind = ops.IndexedSlices(
+ array_ops.placeholder(dtype=np.float32,
+ shape=(2,)),
+ array_ops.placeholder(dtype=np.int64,
+ shape=(2, 3)),
+ array_ops.placeholder(dtype=np.int64,
+ shape=(3,)),)
+ ind_values = array_ops.identity(ind.values)
+ ind_indices = array_ops.identity(ind.indices)
+ ind_dense_shape = array_ops.identity(ind.dense_shape)
+ ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
+ # Feed with tuple
+ values_out, indices_out, dense_shape_out = s.run(
+ [ind_values, ind_indices, ind_dense_shape],
+ {ind: (values, indices, dense_shape)})
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # Feed with IndexedSlicesValue
+ values_out, indices_out, dense_shape_out = s.run(
+ [ind_values, ind_indices, ind_dense_shape],
+ {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
+ ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices,
+ dense_shape)})
+ self.assertAllEqual(ind2_out.values, values)
+ self.assertAllEqual(ind2_out.indices, indices)
+ self.assertAllEqual(ind2_out.dense_shape, dense_shape)
+
+ def testFetchIndexedSlicesWithoutDenseShape(self):
+ with session.Session() as s:
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ dense_shape = None
+ ind = ops.IndexedSlices(
+ constant_op.constant(values), constant_op.constant(indices), None)
+ # Single fetch, use as tuple
+ ind_out = s.run(ind)
+ values_out, indices_out, dense_shape_out = ind_out
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # Single fetch, use as IndexedSlicesValue
+ ind_out = s.run(ind)
+ self.assertAllEqual(ind_out.values, values)
+ self.assertAllEqual(ind_out.indices, indices)
+ self.assertAllEqual(ind_out.dense_shape, dense_shape)
+ # Tuple fetch, use as tuple
+ values_out, indices_out, dense_shape_out = s.run(ind)
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # List fetch, use as tuple
+ (values_out, indices_out, dense_shape_out), = s.run([ind])
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ self.assertAllEqual(dense_shape_out, dense_shape)
+ # List fetch, use as IndexedSlicesValue
+ ind_out, = s.run([ind])
+ self.assertAllEqual(ind_out.values, values)
+ self.assertAllEqual(ind_out.indices, indices)
+ self.assertAllEqual(ind_out.dense_shape, dense_shape)
+
+ def testFeedIndexedSlicesWithoutDenseShape(self):
+ with session.Session() as s:
+ values = np.array([1.0, 2.0]).astype(np.float32)
+ indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
+ dense_shape = None
+ ind = ops.IndexedSlices(
+ array_ops.placeholder(dtype=np.float32,
+ shape=(2,)),
+ array_ops.placeholder(dtype=np.int64,
+ shape=(2, 3)),
+ None)
+ ind_values = array_ops.identity(ind.values)
+ ind_indices = array_ops.identity(ind.indices)
+ ind2 = ops.IndexedSlices(ind_values, ind_indices)
+ # Feed with tuple
+ values_out, indices_out = s.run(
+ [ind_values, ind_indices], {ind: (values, indices)})
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ # Feed with IndexedSlicesValue
+ values_out, indices_out = s.run(
+ [ind_values, ind_indices],
+ {ind: ops.IndexedSlicesValue(values, indices, dense_shape)})
+ self.assertAllEqual(values_out, values)
+ self.assertAllEqual(indices_out, indices)
+ # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
+ ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices,
+ dense_shape)})
+ self.assertAllEqual(ind2_out.values, values)
+ self.assertAllEqual(ind2_out.indices, indices)
+ self.assertAllEqual(ind2_out.dense_shape, dense_shape)
+
def testExtendWithStatelessOperations(self):
with session.Session() as s:
a = constant_op.constant(1.0, shape=[1, 2])