diff options
Diffstat (limited to 'tensorflow/python/client/session_test.py')
-rw-r--r-- | tensorflow/python/client/session_test.py | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index a1a9c57b1c..8073b76b3f 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -248,6 +248,144 @@ class SessionTest(test_util.TensorFlowTestCase): self.assertAllEqual(sp2_out.values, values) self.assertAllEqual(sp2_out.shape, shape) + def testFetchIndexedSlices(self): + with session.Session() as s: + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + values = np.array([1.0, 2.0]).astype(np.float32) + dense_shape = np.array([7, 9, 2]).astype(np.int64) + ind = ops.IndexedSlices( + constant_op.constant(values), constant_op.constant(indices), + constant_op.constant(dense_shape)) + # Single fetch, use as tuple + ind_out = s.run(ind) + values_out, indices_out, dense_shape_out = ind_out + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Single fetch, use as IndexedSlicesValue + ind_out = s.run(ind) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) + # Tuple fetch, use as tuple + values_out, indices_out, dense_shape_out = s.run(ind) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as tuple + (values_out, indices_out, dense_shape_out), = s.run([ind]) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as IndexedSlicesValue + ind_out, = s.run([ind]) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) + + def testFeedIndexedSlices(self): + with session.Session() as s: + values = np.array([1.0, 2.0]).astype(np.float32) + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + dense_shape = np.array([7, 9, 2]).astype(np.int64) + ind = ops.IndexedSlices( + array_ops.placeholder(dtype=np.float32, + shape=(2,)), + array_ops.placeholder(dtype=np.int64, + shape=(2, 3)), + array_ops.placeholder(dtype=np.int64, + shape=(3,)),) + ind_values = array_ops.identity(ind.values) + ind_indices = array_ops.identity(ind.indices) + ind_dense_shape = array_ops.identity(ind.dense_shape) + ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape) + # Feed with tuple + values_out, indices_out, dense_shape_out = s.run( + [ind_values, ind_indices, ind_dense_shape], + {ind: (values, indices, dense_shape)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Feed with IndexedSlicesValue + values_out, indices_out, dense_shape_out = s.run( + [ind_values, ind_indices, ind_dense_shape], + {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Feed with IndexedSlicesValue, fetch IndexedSlicesValue + ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, + dense_shape)}) + self.assertAllEqual(ind2_out.values, values) + self.assertAllEqual(ind2_out.indices, indices) + self.assertAllEqual(ind2_out.dense_shape, dense_shape) + + def testFetchIndexedSlicesWithoutDenseShape(self): + with session.Session() as s: + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + values = np.array([1.0, 2.0]).astype(np.float32) + dense_shape = None + ind = ops.IndexedSlices( + constant_op.constant(values), constant_op.constant(indices), None) + # Single fetch, use as tuple + ind_out = s.run(ind) + values_out, indices_out, dense_shape_out = ind_out + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # Single fetch, use as IndexedSlicesValue + ind_out = s.run(ind) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) + # Tuple fetch, use as tuple + values_out, indices_out, dense_shape_out = s.run(ind) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as tuple + (values_out, indices_out, dense_shape_out), = s.run([ind]) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + self.assertAllEqual(dense_shape_out, dense_shape) + # List fetch, use as IndexedSlicesValue + ind_out, = s.run([ind]) + self.assertAllEqual(ind_out.values, values) + self.assertAllEqual(ind_out.indices, indices) + self.assertAllEqual(ind_out.dense_shape, dense_shape) + + def testFeedIndexedSlicesWithoutDenseShape(self): + with session.Session() as s: + values = np.array([1.0, 2.0]).astype(np.float32) + indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) + dense_shape = None + ind = ops.IndexedSlices( + array_ops.placeholder(dtype=np.float32, + shape=(2,)), + array_ops.placeholder(dtype=np.int64, + shape=(2, 3)), + None) + ind_values = array_ops.identity(ind.values) + ind_indices = array_ops.identity(ind.indices) + ind2 = ops.IndexedSlices(ind_values, ind_indices) + # Feed with tuple + values_out, indices_out = s.run( + [ind_values, ind_indices], {ind: (values, indices)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # Feed with IndexedSlicesValue + values_out, indices_out = s.run( + [ind_values, ind_indices], + {ind: ops.IndexedSlicesValue(values, indices, dense_shape)}) + self.assertAllEqual(values_out, values) + self.assertAllEqual(indices_out, indices) + # Feed with IndexedSlicesValue, fetch IndexedSlicesValue + ind2_out = s.run(ind2, {ind: ops.IndexedSlicesValue(values, indices, + dense_shape)}) + self.assertAllEqual(ind2_out.values, values) + self.assertAllEqual(ind2_out.indices, indices) + self.assertAllEqual(ind2_out.dense_shape, dense_shape) + def testExtendWithStatelessOperations(self): with session.Session() as s: a = constant_op.constant(1.0, shape=[1, 2]) |