aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/array_ops.py7
-rw-r--r--tensorflow/python/ops/functional_ops.py2
-rw-r--r--tensorflow/python/ops/gradients_impl.py1
-rw-r--r--tensorflow/python/ops/image_ops.py4
-rw-r--r--tensorflow/python/ops/image_ops_impl.py106
-rw-r--r--tensorflow/python/ops/image_ops_test.py112
-rw-r--r--tensorflow/python/ops/linalg_grad.py59
-rw-r--r--tensorflow/python/ops/losses/losses_impl.py7
-rw-r--r--tensorflow/python/ops/manip_grad.py31
-rw-r--r--tensorflow/python/ops/manip_ops.py38
-rw-r--r--tensorflow/python/ops/math_ops.py10
-rw-r--r--tensorflow/python/ops/rnn.py6
-rw-r--r--tensorflow/python/ops/standard_ops.py73
13 files changed, 379 insertions, 77 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index e3902f5a8a..ad409ad7e5 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -35,6 +35,7 @@ See the @{$python/array_ops} guide.
@@reshape
@@squeeze
@@expand_dims
+@@unravel_index
@@meshgrid
@@slice
@@strided_slice
@@ -1589,9 +1590,9 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True):
Args:
tensor: A `Tensor`.
- dtype: A type for the returned `Tensor`. Must be `float32`, `float64`,
- `int8`, `uint8`, `int16`, `uint16`, int32`, `int64`,
- `complex64`, `complex128` or `bool`.
+ dtype: A type for the returned `Tensor`. Must be `float16`, `float32`,
+ `float64`, `int8`, `uint8`, `int16`, `uint16`, `int32`, `int64`,
+ `complex64`, `complex128`, `bool` or `string`.
name: A name for the operation (optional).
optimize: if true, attempt to statically determine the shape of 'tensor'
and encode it as a constant.
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 7dbccf1caf..ac03d30fcd 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -458,7 +458,7 @@ def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
`[i1, i2]` then an appropriate signature for `fn` in `python2` is:
- `fn = lambda (acc_p1, acc_p2), (t1 [t2, t3]):` and `fn` must return a list,
+ `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
`[acc_n1, acc_n2]`. An alternative correct signature for `fn`, and the
one that works in `python3`, is:
`fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 28b26a09a5..9f06c0ee1f 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -44,6 +44,7 @@ from tensorflow.python.ops import image_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import
from tensorflow.python.ops import logging_ops # pylint: disable=unused-import
+from tensorflow.python.ops import manip_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 3b0b5a978c..de12c5f63f 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -49,6 +49,10 @@ See the @{$python/image} guide.
@@grayscale_to_rgb
@@hsv_to_rgb
@@rgb_to_hsv
+@@rgb_to_yiq
+@@yiq_to_rgb
+@@rgb_to_yuv
+@@yuv_to_rgb
@@convert_image_dtype
@@adjust_brightness
@@random_brightness
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 2c231ef56c..14a38f25d1 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -1508,7 +1508,7 @@ def sample_distorted_bounding_box(image_size,
bounding_boxes,
seed=None,
seed2=None,
- min_object_covered=None,
+ min_object_covered=0.1,
aspect_ratio_range=None,
area_range=None,
max_attempts=None,
@@ -1669,3 +1669,107 @@ def non_max_suppression(boxes,
return gen_image_ops._non_max_suppression_v2(boxes, scores, max_output_size,
iou_threshold)
# pylint: enable=protected-access
+
+
+_rgb_to_yiq_kernel = [[0.299, 0.59590059,
+ 0.2115], [0.587, -0.27455667, -0.52273617],
+ [0.114, -0.32134392, 0.31119955]]
+
+
+def rgb_to_yiq(images):
+ """Converts one or more images from RGB to YIQ.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the YIQ
+ value of the pixels.
+ The output is only well defined if the value in images are in [0,1].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _rgb_to_yiq_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
+
+
+_yiq_to_rgb_kernel = [[1, 1, 1], [0.95598634, -0.27201283, -1.10674021],
+ [0.6208248, -0.64720424, 1.70423049]]
+
+
+def yiq_to_rgb(images):
+ """Converts one or more images from YIQ to RGB.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+ value of the pixels.
+ The output is only well defined if the Y value in images are in [0,1],
+ I value are in [-0.5957,0.5957] and Q value are in [-0.5226,0.5226].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _yiq_to_rgb_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
+
+
+_rgb_to_yuv_kernel = [[0.299, -0.14714119,
+ 0.61497538], [0.587, -0.28886916, -0.51496512],
+ [0.114, 0.43601035, -0.10001026]]
+
+
+def rgb_to_yuv(images):
+ """Converts one or more images from RGB to YUV.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the YUV
+ value of the pixels.
+ The output is only well defined if the value in images are in [0,1].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _rgb_to_yuv_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
+
+
+_yuv_to_rgb_kernel = [[1, 1, 1], [0, -0.394642334, 2.03206185],
+ [1.13988303, -0.58062185, 0]]
+
+
+def yuv_to_rgb(images):
+ """Converts one or more images from YUV to RGB.
+
+ Outputs a tensor of the same shape as the `images` tensor, containing the RGB
+ value of the pixels.
+ The output is only well defined if the Y value in images are in [0,1],
+ U and V value are in [-0.5,0.5].
+
+ Args:
+ images: 2-D or higher rank. Image data to convert. Last dimension must be
+ size 3.
+
+ Returns:
+ images: tensor with the same shape as `images`.
+ """
+ images = ops.convert_to_tensor(images, name='images')
+ kernel = ops.convert_to_tensor(
+ _yuv_to_rgb_kernel, dtype=images.dtype, name='kernel')
+ ndims = images.get_shape().ndims
+ return math_ops.tensordot(images, kernel, axes=[[ndims - 1], [0]])
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 47dd8231c0..b12bd3d5b0 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -85,6 +85,64 @@ class RGBToHSVTest(test_util.TensorFlowTestCase):
self.assertAllClose(rgb_tf, rgb_np)
+class RGBToYIQTest(test_util.TensorFlowTestCase):
+
+ def testBatch(self):
+ # Build an arbitrary RGB image
+ np.random.seed(7)
+ batch_size = 5
+ shape = (batch_size, 2, 7, 3)
+
+ for nptype in [np.float32, np.float64]:
+ inp = np.random.rand(*shape).astype(nptype)
+
+ # Convert to YIQ and back, as a batch and individually
+ with self.test_session(use_gpu=True) as sess:
+ batch0 = constant_op.constant(inp)
+ batch1 = image_ops.rgb_to_yiq(batch0)
+ batch2 = image_ops.yiq_to_rgb(batch1)
+ split0 = array_ops.unstack(batch0)
+ split1 = list(map(image_ops.rgb_to_yiq, split0))
+ split2 = list(map(image_ops.yiq_to_rgb, split1))
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
+ batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+
+ # Verify that processing batch elements together is the same as separate
+ self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4)
+
+
+class RGBToYUVTest(test_util.TensorFlowTestCase):
+
+ def testBatch(self):
+ # Build an arbitrary RGB image
+ np.random.seed(7)
+ batch_size = 5
+ shape = (batch_size, 2, 7, 3)
+
+ for nptype in [np.float32, np.float64]:
+ inp = np.random.rand(*shape).astype(nptype)
+
+ # Convert to YUV and back, as a batch and individually
+ with self.test_session(use_gpu=True) as sess:
+ batch0 = constant_op.constant(inp)
+ batch1 = image_ops.rgb_to_yuv(batch0)
+ batch2 = image_ops.yuv_to_rgb(batch1)
+ split0 = array_ops.unstack(batch0)
+ split1 = list(map(image_ops.rgb_to_yuv, split0))
+ split2 = list(map(image_ops.yuv_to_rgb, split1))
+ join1 = array_ops.stack(split1)
+ join2 = array_ops.stack(split2)
+ batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2])
+
+ # Verify that processing batch elements together is the same as separate
+ self.assertAllClose(batch1, join1, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, join2, rtol=1e-4, atol=1e-4)
+ self.assertAllClose(batch2, inp, rtol=1e-4, atol=1e-4)
+
+
class GrayscaleToRGBTest(test_util.TensorFlowTestCase):
def _RGBToGrayscale(self, images):
@@ -1839,6 +1897,26 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase):
self.assertAllEqual([3], end.get_shape().as_list())
self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list())
+ def testDefaultMinObjectCovered(self):
+ # By default min_object_covered=0.1 if not provided
+ with self.test_session(use_gpu=True):
+ image_size = constant_op.constant(
+ [40, 50, 1], shape=[3], dtype=dtypes.int32)
+ bounding_box = constant_op.constant(
+ [0.0, 0.0, 1.0, 1.0],
+ shape=[4],
+ dtype=dtypes.float32,
+ )
+ begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box(
+ image_size=image_size,
+ bounding_boxes=bounding_box,
+ aspect_ratio_range=(0.75, 1.33),
+ area_range=(0.05, 1.0))
+
+ self.assertAllEqual([3], begin.get_shape().as_list())
+ self.assertAllEqual([3], end.get_shape().as_list())
+ self.assertAllEqual([1, 1, 4], bbox_for_drawing.get_shape().as_list())
+
class ResizeImagesTest(test_util.TensorFlowTestCase):
@@ -3092,6 +3170,40 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase):
boxes, scores, max_output_size, iou_threshold).eval()
self.assertAllClose(selected_indices, [3, 0, 5])
+ def testInvalidShape(self):
+ # The boxes should be 2D of shape [num_boxes, 4].
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 2 but is rank 1"):
+ boxes = constant_op.constant([0.0, 0.0, 1.0, 1.0])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
+ with self.assertRaisesRegexp(ValueError, "Dimension must be 4 but is 3"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0]])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
+ # The scores should be 1D of shape [num_boxes].
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 1 but is rank 2"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([[0.9]])
+ image_ops.non_max_suppression(boxes, scores, 3, 0.5)
+
+ # The max_output_size should be a scaler (0-D).
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 0 but is rank 1"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, [3], 0.5)
+
+ # The iou_threshold should be a scaler (0-D).
+ with self.assertRaisesRegexp(ValueError,
+ "Shape must be rank 0 but is rank 2"):
+ boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]])
+ scores = constant_op.constant([0.9])
+ image_ops.non_max_suppression(boxes, scores, 3, [[0.5]])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index 13a32c83d9..3cbbf3412a 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -277,20 +277,28 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
# https://j-towns.github.io/papers/svd-derivative.pdf
a = op.inputs[0]
a_shape = a.get_shape().with_rank_at_least(2)
+ grad_s_mat = array_ops.matrix_diag(grad_s)
- if op.get_attr("compute_uv"):
- # TODO(rmlarsen): Make this work with complex types.
- if a.dtype.is_complex:
- raise NotImplementedError(
- "SVD gradient is not implemented for complex types and "
- "compute_uv=True.")
- grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
- grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
- m = a_shape[-2].merge_with(grad_u_shape[-2])
- n = a_shape[-1].merge_with(grad_v_shape[-2])
- batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
- grad_v_shape[:-2])
- a_shape = batch_shape.concatenate([m, n])
+ if not op.get_attr("compute_uv"):
+ s, u, v = linalg_ops.svd(a, compute_uv=True)
+ grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True))
+ grad_a.set_shape(a_shape)
+ return grad_a
+
+ full_matrices = op.get_attr("full_matrices")
+
+ # TODO(rmlarsen): Make this work with complex types.
+ if a.dtype.is_complex:
+ raise NotImplementedError(
+ "SVD gradient is not implemented for complex types and "
+ "compute_uv=True.")
+ grad_u_shape = grad_u.get_shape().with_rank_at_least(2)
+ grad_v_shape = grad_v.get_shape().with_rank_at_least(2)
+ m = a_shape[-2].merge_with(grad_u_shape[-2])
+ n = a_shape[-1].merge_with(grad_v_shape[-2])
+ batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with(
+ grad_v_shape[:-2])
+ a_shape = batch_shape.concatenate([m, n])
m = a_shape[-2].value
n = a_shape[-1].value
@@ -300,12 +308,9 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
"SVD gradient has not been implemented for input with unknown "
"inner matrix shape.")
- if not op.get_attr("compute_uv"):
- s, u, v = linalg_ops.svd(a, compute_uv=True, full_matrices=True)
- else:
- s = op.outputs[0]
- u = op.outputs[1]
- v = op.outputs[2]
+ s = op.outputs[0]
+ u = op.outputs[1]
+ v = op.outputs[2]
use_adjoint = False
if m > n:
@@ -317,19 +322,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
grad_u, grad_v = grad_v, grad_u
with ops.control_dependencies([grad_s, grad_u, grad_v]):
- grad_s_mat = array_ops.matrix_diag(grad_s)
- if not op.get_attr("compute_uv"):
- if use_adjoint:
- grad_a = math_ops.matmul(
- v[..., :, :m], math_ops.matmul(u, grad_s_mat), adjoint_b=True)
- else:
- grad_a = math_ops.matmul(u,
- math_ops.matmul(
- grad_s_mat, v[..., :, :m], adjoint_b=True))
- grad_a.set_shape(a_shape)
- return grad_a
-
- if op.get_attr("full_matrices") and abs(m - n) > 1:
+ if full_matrices and abs(m - n) > 1:
raise NotImplementedError(
"svd gradient is not implemented for abs(m - n) > 1 "
"when full_matrices is True")
@@ -371,7 +364,7 @@ def _SvdGrad(op, grad_s, grad_u, grad_v):
gv1t_v1 = math_ops.matmul(gv1t, v1)
term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True)
- if op.get_attr("full_matrices"):
+ if full_matrices:
v2 = v[..., :, m:n]
grad_v2 = grad_v[..., :, m:n]
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index e75a9b22e4..84afbf0627 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -547,12 +547,13 @@ def mean_pairwise_squared_error(
num_present_per_batch = _num_present(diffs, weights, per_batch=True)
term1 = 2.0 * _safe_div(sum_squares_diff_per_batch,
- num_present_per_batch)
+ num_present_per_batch - 1)
sum_diff = math_ops.reduce_sum(
diffs, reduction_indices=reduction_indices, keep_dims=True)
- term2 = 2.0 * _safe_div(math_ops.square(sum_diff),
- math_ops.square(num_present_per_batch))
+ term2 = 2.0 * _safe_div(
+ math_ops.square(sum_diff),
+ math_ops.multiply(num_present_per_batch, num_present_per_batch - 1))
weighted_losses = math_ops.multiply(term1 - term2, weights)
loss = math_ops.reduce_sum(weighted_losses)
diff --git a/tensorflow/python/ops/manip_grad.py b/tensorflow/python/ops/manip_grad.py
new file mode 100644
index 0000000000..bb2069359d
--- /dev/null
+++ b/tensorflow/python/ops/manip_grad.py
@@ -0,0 +1,31 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gradients for operators defined in manip_ops.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import manip_ops
+
+
+@ops.RegisterGradient("Roll")
+def _RollGrad(op, grad):
+ # The gradient is just the roll reversed
+ shift = op.inputs[1]
+ axis = op.inputs[2]
+ roll_grad = manip_ops.roll(grad, -shift, axis)
+ return roll_grad, None, None
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
new file mode 100644
index 0000000000..91e15b47b9
--- /dev/null
+++ b/tensorflow/python/ops/manip_ops.py
@@ -0,0 +1,38 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operators for manipulating tensors.
+
+@@roll
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util.all_util import remove_undocumented
+
+
+# pylint: disable=protected-access
+def roll(input, shift, axis): # pylint: disable=redefined-builtin
+ return _gen_manip_ops.roll(input, shift, axis)
+
+
+roll.__doc__ = _gen_manip_ops.roll.__doc__
+# pylint: enable=protected-access
+
+_allowed_symbols = ['roll']
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 827e3caa36..9a8ac93de9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2826,10 +2826,14 @@ def tensordot(a, b, axes, name=None):
"""Generates two sets of contraction axes for the two tensor arguments."""
a_shape = a.get_shape()
if isinstance(axes, compat.integral_types):
- if axes < 1:
- raise ValueError("'axes' must be at least 1.")
+ if axes < 0:
+ raise ValueError("'axes' must be at least 0.")
if a_shape.ndims is not None:
- return range(a_shape.ndims - axes, a_shape.ndims), range(axes)
+ if axes > a_shape.ndims:
+ raise ValueError("'axes' must not be larger than the number of "
+ "dimensions of tensor %s." % a)
+ return (list(xrange(a_shape.ndims - axes, a_shape.ndims)),
+ list(xrange(axes)))
else:
rank = array_ops.rank(a)
return (range(rank - axes, rank, dtype=dtypes.int32),
diff --git a/tensorflow/python/ops/rnn.py b/tensorflow/python/ops/rnn.py
index 24c6f64f0a..da80e72071 100644
--- a/tensorflow/python/ops/rnn.py
+++ b/tensorflow/python/ops/rnn.py
@@ -1127,6 +1127,12 @@ def raw_rnn(cell, loop_fn,
def _copy_some_through(current, candidate):
"""Copy some tensors through via array_ops.where."""
def copy_fn(cur_i, cand_i):
+ # TensorArray and scalar get passed through.
+ if isinstance(cur_i, tensor_array_ops.TensorArray):
+ return cand_i
+ if cur_i.shape.ndims == 0:
+ return cand_i
+ # Otherwise propagate the old or the new value.
with ops.colocate_with(cand_i):
return array_ops.where(elements_finished, cur_i, cand_i)
return nest.map_structure(copy_fn, current, candidate)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index 30bf4e4ef1..009d1dc3b9 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -25,6 +25,7 @@ import sys as _sys
# Imports the following modules so that @RegisterGradient get executed.
from tensorflow.python.ops import array_grad
from tensorflow.python.ops import data_flow_grad
+from tensorflow.python.ops import manip_grad
from tensorflow.python.ops import math_grad
from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
@@ -42,11 +43,13 @@ from tensorflow.python.ops.special_math_ops import *
# TODO(vrv): Switch to import * once we're okay with exposing the module.
from tensorflow.python.ops.confusion_matrix import confusion_matrix
from tensorflow.python.ops.control_flow_ops import Assert
+from tensorflow.python.ops.control_flow_ops import case
+from tensorflow.python.ops.control_flow_ops import cond
from tensorflow.python.ops.control_flow_ops import group
from tensorflow.python.ops.control_flow_ops import no_op
+# pylint: disable=redefined-builtin
from tensorflow.python.ops.control_flow_ops import tuple
-from tensorflow.python.ops.control_flow_ops import cond
-from tensorflow.python.ops.control_flow_ops import case
+# pylint: enable=redefined-builtin
from tensorflow.python.ops.control_flow_ops import while_loop
from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.functional_ops import *
@@ -59,6 +62,7 @@ from tensorflow.python.ops.logging_ops import Print
from tensorflow.python.ops.logging_ops import get_summary_op
from tensorflow.python.ops.lookup_ops import initialize_all_tables
from tensorflow.python.ops.lookup_ops import tables_initializer
+from tensorflow.python.ops.manip_ops import *
from tensorflow.python.ops.math_ops import *
from tensorflow.python.ops.numerics import *
from tensorflow.python.ops.parsing_ops import *
@@ -105,6 +109,7 @@ from tensorflow.python.ops import init_ops as _init_ops
from tensorflow.python.ops import io_ops as _io_ops
from tensorflow.python.ops import linalg_ops as _linalg_ops
from tensorflow.python.ops import logging_ops as _logging_ops
+from tensorflow.python.ops import manip_ops as _manip_ops
from tensorflow.python.ops import math_ops as _math_ops
from tensorflow.python.ops import numerics as _numerics
from tensorflow.python.ops import parsing_ops as _parsing_ops
@@ -264,34 +269,36 @@ _allowed_symbols = (_allowed_symbols_array_ops +
_allowed_symbols_misc +
_allowed_symbols_partitioned_variables)
-remove_undocumented(__name__, _allowed_symbols,
- [_sys.modules[__name__],
- _array_ops,
- _check_ops,
- _clip_ops,
- _confusion_matrix,
- _control_flow_ops,
- _constant_op,
- _data_flow_ops,
- _functional_ops,
- _gradients,
- _histogram_ops,
- _init_ops,
- _io_ops,
- _linalg_ops,
- _logging_ops,
- _math_ops,
- _numerics,
- _parsing_ops,
- _partitioned_variables,
- _random_ops,
- _script_ops,
- _session_ops,
- _sparse_ops,
- _special_math_ops,
- _state_ops,
- _string_ops,
- _template,
- _tensor_array_ops,
- _variable_scope,
- _variables,])
+remove_undocumented(__name__, _allowed_symbols, [
+ _sys.modules[__name__],
+ _array_ops,
+ _check_ops,
+ _clip_ops,
+ _confusion_matrix,
+ _control_flow_ops,
+ _constant_op,
+ _data_flow_ops,
+ _functional_ops,
+ _gradients,
+ _histogram_ops,
+ _init_ops,
+ _io_ops,
+ _linalg_ops,
+ _logging_ops,
+ _manip_ops,
+ _math_ops,
+ _numerics,
+ _parsing_ops,
+ _partitioned_variables,
+ _random_ops,
+ _script_ops,
+ _session_ops,
+ _sparse_ops,
+ _special_math_ops,
+ _state_ops,
+ _string_ops,
+ _template,
+ _tensor_array_ops,
+ _variable_scope,
+ _variables,
+])