aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/shape_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/shape_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index a9fc699b21..7368251ab6 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -258,6 +258,16 @@ class ShapeOpsTest(test.TestCase):
self.assertAllEqual([True], array_ops.expand_dims(inp, 0).eval())
self.assertAllEqual([True], array_ops.expand_dims(inp, -1).eval())
+ def testExpandDimsDimType(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ x = np.zeros([2])
+ np_ans = np.expand_dims(x, axis=0)
+ with self.test_session(use_gpu=True):
+ tensor = array_ops.expand_dims(x, constant_op.constant(0, dtype))
+ tf_ans = tensor.eval()
+ self.assertShapeEqual(np_ans, tensor)
+ self.assertAllEqual(np_ans, tf_ans)
+
def _compareSqueeze(self, x, squeeze_dims, use_gpu):
with self.test_session(use_gpu=use_gpu):
if squeeze_dims: