diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/shape_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/shape_ops_test.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index a9fc699b21..7368251ab6 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -258,6 +258,16 @@ class ShapeOpsTest(test.TestCase): self.assertAllEqual([True], array_ops.expand_dims(inp, 0).eval()) self.assertAllEqual([True], array_ops.expand_dims(inp, -1).eval()) + def testExpandDimsDimType(self): + for dtype in [dtypes.int32, dtypes.int64]: + x = np.zeros([2]) + np_ans = np.expand_dims(x, axis=0) + with self.test_session(use_gpu=True): + tensor = array_ops.expand_dims(x, constant_op.constant(0, dtype)) + tf_ans = tensor.eval() + self.assertShapeEqual(np_ans, tensor) + self.assertAllEqual(np_ans, tf_ans) + def _compareSqueeze(self, x, squeeze_dims, use_gpu): with self.test_session(use_gpu=use_gpu): if squeeze_dims: |