aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-08 10:25:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-08 11:32:57 -0700
commit9205b55c6bbef400fa1cdb0140a99576608f5b3f (patch)
treed1511fd7822c7c5611c2d6d3ab0dd3f4c8ff8207 /tensorflow/python
parent71e3186fd3b3b62aeb43a697432565a9434fa9f5 (diff)
Switch several ops in array_ops.py to use C++ shape functions.
Change C++ shape function for ExpandDims to be more permissive - it now allows 'dim' to be any tensor with 1 element, although that is not currently converted to use C++ because of a separate issue to fix first (later change). Change C++ shape functions for SpaceToBatch and BatchToSpace to output rank-4 unknown shapes. Change: 132578764
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/framework/common_shapes.py2
-rw-r--r--tensorflow/python/kernel_tests/spacetobatch_op_test.py8
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py2
-rw-r--r--tensorflow/python/ops/array_ops.py255
4 files changed, 14 insertions, 253 deletions
diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py
index ba366cbc13..8c5251bdf1 100644
--- a/tensorflow/python/framework/common_shapes.py
+++ b/tensorflow/python/framework/common_shapes.py
@@ -646,7 +646,7 @@ def call_cpp_shape_fn(op, input_tensors_needed=None,
if str(result) != str(python_result):
raise ValueError(
("Python vs CPP shape mismatch. "
- "python: %s vs CPP: %s on node %s "
+ "CPP: %s vs python: %s on node %s "
"with input shapes %s") % (
str(result), str(python_result), str(op.node_def),
",".join([str(i.get_shape()) for i in op.inputs])))
diff --git a/tensorflow/python/kernel_tests/spacetobatch_op_test.py b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
index f3ff2d517a..b340394017 100644
--- a/tensorflow/python/kernel_tests/spacetobatch_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetobatch_op_test.py
@@ -166,7 +166,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
x_np = [[[[1], [2]], [[3], [4]]]]
paddings = np.zeros((2, 2), dtype=np.int32)
block_size = 10
- with self.assertRaises(IndexError):
+ with self.assertRaises(ValueError):
out_tf = tf.space_to_batch(x_np, paddings, block_size)
out_tf.eval()
@@ -175,7 +175,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
x_np = [[[[1], [2], [3]], [[3], [4], [7]]]]
paddings = np.zeros((2, 2), dtype=np.int32)
block_size = 3
- with self.assertRaises(IndexError):
+ with self.assertRaises(ValueError):
_ = tf.space_to_batch(x_np, paddings, block_size)
def testBlockSizeNotDivisibleHeight(self):
@@ -183,7 +183,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
x_np = [[[[1], [2]], [[3], [4]], [[5], [6]]]]
paddings = np.zeros((2, 2), dtype=np.int32)
block_size = 3
- with self.assertRaises(IndexError):
+ with self.assertRaises(ValueError):
_ = tf.space_to_batch(x_np, paddings, block_size)
def testBlockSizeNotDivisibleBoth(self):
@@ -191,7 +191,7 @@ class SpaceToBatchErrorHandlingTest(tf.test.TestCase):
x_np = [[[[1], [2]], [[3], [4]]]]
paddings = np.zeros((2, 2), dtype=np.int32)
block_size = 3
- with self.assertRaises(IndexError):
+ with self.assertRaises(ValueError):
_ = tf.space_to_batch(x_np, paddings, block_size)
def testUnknownShape(self):
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index ec8a41f59c..5bc3f5358a 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -226,7 +226,7 @@ class TransposeTest(tf.test.TestCase):
self._testError(np.arange(0., 2 ** 11).reshape([2] * 11),
np.arange(11),
"not implemented")
- with self.assertRaises(IndexError):
+ with self.assertRaises(ValueError):
tf.transpose(np.arange(0., 30).reshape([2, 3, 5]), [0, 1, 3])
self._testError(np.arange(0., 30).reshape([2, 3, 5]),
[0, 1, 1],
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 54e2298e35..9141b873fd 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -807,39 +807,7 @@ ops.RegisterShape("Unpack")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("Concat")
def _ConcatShape(op):
- concat_dim = tensor_util.constant_value(op.inputs[0])
- if concat_dim is None:
- # Return an unknown shape with the same rank as the inputs, or an
- # unknown rank if no input's rank is known.
- rank = None
- for value in op.inputs[1:]:
- if rank is not None:
- value.get_shape().assert_has_rank(rank)
- else:
- rank = value.get_shape().ndims
- if rank == 0:
- raise ValueError("Can't concatenate scalars (use tf.pack instead)")
- return [tensor_shape.unknown_shape(ndims=rank)]
-
- else:
- # Merge all the non-concat dims, and sum the concat dim to make an
- # output shape.
- concat_dim = int(concat_dim)
- if concat_dim < 0:
- raise ValueError("Expected concat_dim >= 0, but got %d" % concat_dim)
-
- output_shape = op.inputs[1].get_shape()
- for value in op.inputs[2:]:
- value_shape = value.get_shape()
- if value_shape.ndims is not None and concat_dim >= value_shape.ndims:
- raise ValueError("Expected concat_dim in range [0, %d), but got %d" %
- (value_shape.ndims, concat_dim))
- before = output_shape[:concat_dim].merge_with(value_shape[:concat_dim])
- at = output_shape[concat_dim] + value_shape[concat_dim]
- after = output_shape[
- concat_dim + 1:].merge_with(value_shape[concat_dim + 1:])
- output_shape = before.concatenate(at).concatenate(after)
- return [output_shape]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
ops.RegisterShape("ConcatOffset")(common_shapes.call_cpp_shape_fn)
@@ -1834,63 +1802,12 @@ ops.RegisterShape("ListDiff")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("Pad")
@ops.RegisterShape("MirrorPad")
def _PadShape(op):
- """Shape function for the Pad op.
-
- This op has two inputs:
-
- * input: A rank-N tensor.
- * paddings: An N-by-2 matrix, in which the i^th row contains the
- number of padding elements to add before and after `input` in the
- i^th dimension.
-
- It has one output, which has the same rank as input, and additional
- elements according to the values in paddings.
-
- Args:
- op: A Pad Operation.
-
- Returns:
- A single-element list containing the shape of the output.
-
- Raises:
- ValueError: If the input shapes are incompatible.
- """
- paddings_shape = op.inputs[1].get_shape().with_rank(2)
- input_shape = op.inputs[0].get_shape()
- input_shape = input_shape.with_rank(paddings_shape[0].value)
- paddings_shape = paddings_shape.merge_with(
- tensor_shape.matrix(input_shape.ndims, 2))
- paddings = tensor_util.constant_value(op.inputs[1])
- if paddings is None:
- return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
- else:
- output_dims = []
- for i, dim in enumerate(input_shape.dims):
- if paddings[i, 0] < 0 or paddings[i, 1] < 0:
- raise ValueError("paddings must be non-negative")
- output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
- return [tensor_shape.TensorShape(output_dims)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("MirrorPadGrad")
def _MirrorPadGradShape(op):
- """Shape function for the MirrorPadGrad op."""
- paddings_shape = op.inputs[1].get_shape().with_rank(2)
- input_shape = op.inputs[0].get_shape().with_rank(paddings_shape[0].value)
- paddings_shape = paddings_shape.merge_with(tensor_shape.matrix(
- input_shape.ndims, 2))
- paddings = tensor_util.constant_value(op.inputs[1])
- if paddings is None:
- return [tensor_shape.unknown_shape(ndims=input_shape.ndims)]
-
- output_dims = []
- for i, dim in enumerate(input_shape.dims):
- if paddings[i, 0] < 0 or paddings[i, 1] < 0:
- raise ValueError("Paddings must be non-negative.")
- if dim < paddings[i, 0] + paddings[i, 1]:
- raise ValueError("Output dimension is negative.")
- output_dims.append(dim - paddings[i, 0] - paddings[i, 1])
- return [tensor_shape.TensorShape(output_dims)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
ops.RegisterShape("ReverseSequence")(common_shapes.call_cpp_shape_fn)
@@ -1900,58 +1817,12 @@ ops.RegisterShape("ShapeN")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("Transpose")
def _TransposeShape(op):
- """Shape function for the Transpose op.
-
- This op takes two inputs:
-
- * input: a rank-N tensor of arbitrary shape.
- * shuffle: a length-N vector.
-
- Its output is the rank-N tensor computed by permuting the dimensions
- of input according to shuffle.
-
- Args:
- op: A Transpose op.
-
- Returns:
- A single-element list containing the shape of the output.
-
- Raises:
- ValueError: If the shapes of input and shuffle are incompatible.
- IndexError: If shuffle contains an index that is >= the rank of input.
- """
- input_shape = op.inputs[0].get_shape()
- transpose_shape = op.inputs[1].get_shape().merge_with(tensor_shape.vector(
- input_shape.ndims))
- transpose_vec = tensor_util.constant_value(op.inputs[1])
- if transpose_vec is None:
- return [tensor_shape.unknown_shape(ndims=transpose_shape[0].value)]
- else:
- return [tensor_shape.TensorShape([input_shape[i]
- for i in transpose_vec.tolist()])]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("Split")
def _SplitShape(op):
- """Shape function for the Split op."""
- split_dim = tensor_util.constant_value(op.inputs[0])
- num_split = len(op.outputs)
- input_shape = op.inputs[1].get_shape()
- if split_dim is None:
- return [tensor_shape.unknown_shape(ndims=input_shape.ndims)] * num_split
- else:
- split_dim = int(split_dim)
- input_shape = input_shape.with_rank_at_least(split_dim + 1)
- if not (input_shape[split_dim] % num_split).is_compatible_with(0):
- raise ValueError(
- "Number of ways to split should evenly divide the split "
- "dimension but got split_dim %d (size = %d) and num_split %d" %
- (split_dim, input_shape[split_dim].value, num_split))
- prefix = input_shape[:split_dim]
- size_in_split_dim = input_shape[split_dim] // num_split
- suffix = input_shape[split_dim + 1:]
- output_shape = prefix.concatenate(size_in_split_dim).concatenate(suffix)
- return [output_shape] * num_split
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
@ops.RegisterShape("Tile")
@@ -2088,18 +1959,7 @@ def edit_distance(hypothesis, truth, normalize=True, name="edit_distance"):
@ops.RegisterShape("EditDistance")
def _EditDistanceShape(op):
- """Shape function for the EditDistance op."""
- hypothesis_shape = tensor_util.constant_value(op.inputs[2])
- truth_shape = tensor_util.constant_value(op.inputs[5])
- if hypothesis_shape is not None and truth_shape is not None:
- if len(hypothesis_shape) != len(truth_shape):
- raise ValueError(
- "Inconsistent ranks in hypothesis and truth. Saw shapes: %s and %s" %
- (str(hypothesis_shape), str(truth_shape)))
- return [tensor_shape.TensorShape(
- [max(h, t) for h, t in zip(hypothesis_shape[:-1], truth_shape[:-1])])]
-
- return [tensor_shape.unknown_shape()]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[2, 5])
# The remaining ops do not change the shape of their inputs.
@@ -2164,80 +2024,7 @@ def _ExtractImagePatchesShape(op):
@ops.RegisterShape("SpaceToBatch")
def _SpaceToBatchShape(op):
- """Shape function for the SpaceToBatch op.
-
- The output shape is determined by the following inputs/ attributes:
-
- * input: A rank-4 tensor with shape [B, H, W, D]
- * paddings: A 2-by-2 matrix, specified as follows:
-
- paddings = [[pad_top, pad_bottom], [pad_left, pad_right]],
-
- implying effective padded spatial dimensions:
-
- Hp = pad_top + H + pad_bottom
- Wp = pad_left + W + pad_right
-
- Both Hp and Wp must be multiples of block_size.
- * block_size: an int.
-
- Its output is also a rank-4 tensor with shape:
-
- [B*block_size*block_size, Hp/block_size, Wp/block_size, D]
-
- Args:
- op: A SpaceToBatch op.
-
- Returns:
- A single-element list containing the shape of the output.
-
- Raises:
- ValueError: If the shapes of inputs are not as expected.
- IndexError: If block_size does not divide Wp or Hp.
- """
- # Check that the input tensor is 4-D.
- try:
- input_shape = op.inputs[0].get_shape().with_rank(4)
- except ValueError:
- raise ValueError(
- "tf.space_to_batch() requires 4-D input tensor.")
-
- # Check that the paddings tensor is a matrix with shape [2, 2].
- try:
- paddings_shape = op.inputs[1].get_shape().with_rank(2)
- except ValueError:
- raise ValueError(
- "tf.space_to_batch() requires 2-D paddings tensor.")
-
- if paddings_shape[0] != 2 or paddings_shape[1] != 2:
- raise ValueError(
- "tf.space_to_batch() requires input paddings with shape [2, 2].")
-
- block_size = op.get_attr("block_size")
- if block_size <= 1:
- raise ValueError("Attribute block_size has to be > 1.")
-
- paddings = tensor_util.constant_value(op.inputs[1])
- if paddings is not None:
- if (paddings[0, 0] < 0 or paddings[0, 1] < 0 or
- paddings[1, 0] < 0 or paddings[1, 1] < 0):
- raise ValueError("paddings cannot be negative.")
-
- input_height = input_shape[1] + paddings[0, 0] + paddings[0, 1]
- input_width = input_shape[2] + paddings[1, 0] + paddings[1, 1]
-
- if input_height % block_size > 0 or input_width % block_size > 0:
- raise IndexError("block_size needs to divide both width and height.")
- else:
- input_height = tensor_shape.Dimension(None)
- input_width = tensor_shape.Dimension(None)
-
- batch = input_shape[0] * block_size * block_size
- height = input_height // block_size
- width = input_width // block_size
- depth = input_shape[3]
-
- return [tensor_shape.TensorShape([batch, height, width, depth])]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("BatchToSpace")
@@ -2584,33 +2371,7 @@ def one_hot(indices, depth, on_value=None, off_value=None,
@ops.RegisterShape("OneHot")
def _OneHotShape(op):
- """Shape function for the OneHot op.
-
- It closely follows the code in the .cc implementation.
-
- Args:
- op: A OneHot Operation.
-
- Returns:
- A single-element list containing the shape of the output.
-
- Raises:
- ValueError: if axis < -1.
- """
- indices_shape = op.inputs[0].get_shape()
- indices_dims = indices_shape.ndims
- depth = tensor_util.constant_value(op.inputs[1])
- axis = op.get_attr("axis")
-
- if axis < -1:
- raise ValueError("axis must be >= -1")
-
- new_shape = None
- if indices_dims is not None:
- new_shape = indices_shape.as_list()
- new_shape.insert(axis % (indices_dims + 1), depth)
-
- return [tensor_shape.TensorShape(new_shape)]
+ return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("PlaceholderWithDefault")