aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-10-14 11:10:11 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 12:22:02 -0700
commit18ead0a3b6e223e082a5efdb6b769a3853cbdbc9 (patch)
treec04344e8b930d977afd50130f857f85f7c9212bb
parentd3ea52425bb52ebe2b3190a5f831f16c38505cdd (diff)
Add test for nested values passed to convert_to_tensor.
Change: 136183184
-rw-r--r--tensorflow/python/framework/ops_test.py51
1 files changed, 41 insertions, 10 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index 34c6b326b4..086f4e0d71 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -31,10 +31,10 @@ from tensorflow.python.framework import test_util
from tensorflow.python.framework import versions
# Import gradients to register _IndexedSlicesToTensor.
from tensorflow.python.ops import control_flow_ops
-import tensorflow.python.ops.gradients # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+import tensorflow.python.ops.gradients # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.util import compat
@@ -270,19 +270,50 @@ class OperationTest(test_util.TensorFlowTestCase):
ops.Operation(ops._NodeDef("op", "invalid:0"), g)
def testShapeFunctionAbsence(self):
- def _test():
- pass
g = ops.Graph()
with self.assertRaises(RuntimeError):
g.create_op("shapeless_op", [], [dtypes.float32])
def testNoShapeFunction(self):
g = ops.Graph()
- op = ops.Operation(ops._NodeDef("op", "an_op"), g,
- output_types = [dtypes.float32])
+ ops.Operation(ops._NodeDef("op", "an_op"), g,
+ output_types=[dtypes.float32])
self.assertEqual(tensor_shape.unknown_shape(),
_apply_op(g, "an_op", [], [dtypes.float32]).get_shape())
+ def testConvertToTensorNestedArray(self):
+ with self.test_session():
+ values = [[2], [3], [5], [7]]
+ tensor = ops.convert_to_tensor(values)
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, tensor.eval())
+
+ def testConvertToTensorNestedTuple(self):
+ with self.test_session():
+ values = ((2,), (3,), (5,), (7,))
+ tensor = ops.convert_to_tensor(values)
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
+
+ def testConvertToTensorNestedTensors(self):
+ with self.test_session():
+ values = ((2,), (3,), (5,), (7,))
+ tensor = ops.convert_to_tensor(
+ [constant_op.constant(row) for row in values])
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, tensor.eval())
+ tensor = ops.convert_to_tensor(
+ [[constant_op.constant(v) for v in row] for row in values])
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(values, tensor.eval())
+
+ def testConvertToTensorNestedMix(self):
+ with self.test_session():
+ values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
+ tensor = ops.convert_to_tensor(values)
+ self.assertAllEqual((4, 1), tensor.get_shape().as_list())
+ self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
+
def testConvertToTensorPreferred(self):
with self.test_session():
values = [2, 3, 5, 7]
@@ -290,7 +321,7 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(dtypes.float32, tensor.dtype)
with self.test_session():
- # Convert empty tensor to anything
+ # Convert empty tensor to anything.
values = []
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
self.assertEqual(dtypes.int64, tensor.dtype)
@@ -309,7 +340,7 @@ class OperationTest(test_util.TensorFlowTestCase):
_ = ops.convert_to_tensor(values, dtype=dtypes.int64)
def testNoConvert(self):
- # Operation cannot be converted to Tensor
+ # Operation cannot be converted to Tensor.
op = control_flow_ops.no_op()
with self.assertRaisesRegexp(TypeError,
r"Can't convert Operation '.*' to Tensor"):
@@ -919,13 +950,13 @@ def copy_op(x):
@ops.RegisterGradient("copy")
-def _CopyGrad(op, x_grad):
+def _CopyGrad(op, x_grad): # pylint: disable=invalid-name
_ = op
return x_grad
@ops.RegisterGradient("copy_override")
-def _CopyOverrideGrad(op, x_grad):
+def _CopyOverrideGrad(op, x_grad): # pylint: disable=invalid-name
_ = op
return x_grad
@@ -953,7 +984,7 @@ class RegistrationTest(test_util.TensorFlowTestCase):
with g.gradient_override_map({"copy": "unknown_override"}):
y = copy_op(x)
with self.assertRaisesRegexp(LookupError, "unknown_override"):
- fn = ops.get_gradient_function(y.op)
+ ops.get_gradient_function(y.op)
class ComparisonTest(test_util.TensorFlowTestCase):