aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/framework
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-18 17:31:22 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-18 18:19:22 +0800
commit42c116791beddd071e669f6455b8bd3f55cc1bcc (patch)
tree6399cf5e8b8d2fcc28dcda8879809d907628603a /tensorflow/contrib/framework
parent46522659b41d7c2fe93fec54feb93a8e4b56505d (diff)
TST: add test case for with_shape
Diffstat (limited to 'tensorflow/contrib/framework')
-rw-r--r--tensorflow/contrib/framework/python/framework/tensor_util_test.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
index af1b404cb5..2fa1d33328 100644
--- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py
@@ -29,7 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_lib
@@ -185,6 +185,16 @@ class WithShapeTest(test.TestCase):
shape,
unexpected_shapes)
+ def test_with_shape_2x2_with_partial_expected_shape(self):
+ with self.test_session():
+ value = [[42, 43], [44, 45]]
+ actual_shape = [2, 2]
+ tensor = constant_op.constant(value, shape=actual_shape)
+ partial_expected_shape = tensor_shape.TensorShape([None, 2])
+ # Won't raise any exception here:
+ tensor_with_shape = tensor_util.with_shape(partial_expected_shape, tensor)
+ np.testing.assert_array_equal(value, tensor_with_shape.eval())
+
def test_with_shape_none(self):
with self.test_session():
tensor_no_shape = array_ops.placeholder(dtypes.float32)