Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/array_ops.py  15
-rw-r--r--  tensorflow/python/ops/distributions/categorical.py  2
-rw-r--r--  tensorflow/python/ops/embedding_ops.py  26
-rw-r--r--  tensorflow/python/ops/histogram_ops.py  1
-rw-r--r--  tensorflow/python/ops/image_ops_impl.py  74
-rw-r--r--  tensorflow/python/ops/init_ops.py  18
-rw-r--r--  tensorflow/python/ops/linalg_ops.py  77
-rw-r--r--  tensorflow/python/ops/linalg_ops_impl.py  73
-rw-r--r--  tensorflow/python/ops/losses/losses_impl.py  23
-rw-r--r--  tensorflow/python/ops/math_ops.py  38
-rw-r--r--  tensorflow/python/ops/nn.py  1
-rw-r--r--  tensorflow/python/ops/nn_impl.py  11
-rw-r--r--  tensorflow/python/ops/nn_ops.py  8
-rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py  4
14 files changed, 238 insertions, 133 deletions
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index fa26e07c85..ceeabe090d 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -144,6 +144,7 @@ def identity(input, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin,protected-access
@tf_export("expand_dims")
+@deprecation.deprecated_args(None, "Use the `axis` argument instead", "dim")
def expand_dims(input, axis=None, name=None, dim=None):
"""Inserts a dimension of 1 into a tensor's shape.
@@ -193,11 +194,7 @@ def expand_dims(input, axis=None, name=None, dim=None):
Raises:
ValueError: if both `dim` and `axis` are specified.
"""
- # TODO(aselle): Remove argument dim
- if dim is not None:
- if axis is not None:
- raise ValueError("can't specify both 'dim' and 'axis'")
- axis = dim
+ axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
return gen_array_ops.expand_dims(input, axis, name)
@@ -2581,6 +2578,8 @@ def sequence_mask(lengths, maxlen=None, dtype=dtypes.bool, name=None):
@tf_export("squeeze")
+@deprecation.deprecated_args(None, "Use the `axis` argument instead",
+ "squeeze_dims")
def squeeze(input, axis=None, name=None, squeeze_dims=None):
# pylint: disable=redefined-builtin
"""Removes dimensions of size 1 from the shape of a tensor.
@@ -2621,10 +2620,8 @@ def squeeze(input, axis=None, name=None, squeeze_dims=None):
Raises:
ValueError: When both `squeeze_dims` and `axis` are specified.
"""
- if squeeze_dims is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'squeeze_dims' and 'axis'")
- axis = squeeze_dims
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "squeeze_dims", squeeze_dims)
if np.isscalar(axis):
axis = [axis]
return gen_array_ops.squeeze(input, axis, name)
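The two rewrites above delegate the mutual-exclusion check to a shared helper. A minimal standalone sketch of what `deprecation.deprecated_argument_lookup` does (for illustration only; the real helper in `tensorflow.python.util.deprecation` additionally logs a deprecation warning when the old name is used):

```python
def deprecated_argument_lookup(new_name, new_value, old_name, old_value):
    """Return the value supplied under either alias, preferring the new name.

    Mirrors the inline checks this change replaces: supplying both
    aliases is an error.
    """
    if old_value is not None:
        if new_value is not None:
            raise ValueError(
                "Cannot specify both '%s' and '%s'" % (new_name, old_name))
        return old_value
    return new_value


# Usage mirroring the rewritten expand_dims:
assert deprecated_argument_lookup("axis", None, "dim", 2) == 2
assert deprecated_argument_lookup("axis", 1, "dim", None) == 1
```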
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index 66fa9e110c..8f25b1149c 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -311,7 +311,7 @@ class Categorical(distribution.Distribution):
nn_ops.log_softmax(self.logits) * self.probs, axis=-1)
def _mode(self):
- ret = math_ops.argmax(self.logits, dimension=self._batch_rank)
+ ret = math_ops.argmax(self.logits, axis=self._batch_rank)
ret = math_ops.cast(ret, self.dtype)
ret.set_shape(self.batch_shape)
return ret
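`dimension` is the deprecated alias of `axis` (see the `math_ops.argmax` change below), so this is a pure rename. A quick sketch of the call, with toy logits assumed for illustration:

```python
import tensorflow as tf

logits = tf.constant([[1.0, 3.0, 2.0],
                      [4.0, 0.0, 1.0]])
mode = tf.argmax(logits, axis=-1)  # index of the largest logit per row

with tf.Session() as sess:
    print(sess.run(mode))  # [1 0]
```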
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index f0120f2957..9e46739bc1 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -331,11 +331,11 @@ def embedding_lookup_sparse(params,
representing sharded embedding tensors. Alternatively, a
`PartitionedVariable`, created by partitioning along dimension 0. Each
element must be appropriately sized for the given `partition_strategy`.
- sp_ids: N x M SparseTensor of int64 ids (typically from FeatureValueToId),
+ sp_ids: N x M `SparseTensor` of int64 ids (typically from FeatureValueToId),
where N is typically batch size and M is arbitrary.
- sp_weights: either a SparseTensor of float / double weights, or None to
- indicate all weights should be taken to be 1. If specified, sp_weights
- must have exactly the same shape and indices as sp_ids.
+ sp_weights: either a `SparseTensor` of float / double weights, or `None` to
+ indicate all weights should be taken to be 1. If specified, `sp_weights`
+ must have exactly the same shape and indices as `sp_ids`.
partition_strategy: A string specifying the partitioning strategy, relevant
if `len(params) > 1`. Currently `"div"` and `"mod"` are supported. Default
is `"mod"`. See `tf.nn.embedding_lookup` for more details.
@@ -351,39 +351,43 @@ def embedding_lookup_sparse(params,
Returns:
A dense tensor representing the combined embeddings for the
- sparse ids. For each row in the dense tensor represented by sp_ids, the op
+ sparse ids. For each row in the dense tensor represented by `sp_ids`, the op
looks up the embeddings for all ids in that row, multiplies them by the
corresponding weight, and combines these embeddings as specified.
In other words, if
- shape(combined params) = [p0, p1, ..., pm]
+ `shape(combined params) = [p0, p1, ..., pm]`
and
- shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]
+ `shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]`
then
- shape(output) = [d0, d1, ..., dn-1, p1, ..., pm].
+ `shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]`.
For instance, if params is a 10x20 matrix, and sp_ids / sp_weights are
+ ```python
[0, 0]: id 1, weight 2.0
[0, 1]: id 3, weight 0.5
[1, 0]: id 0, weight 1.0
[2, 3]: id 1, weight 3.0
+ ```
with `combiner`="mean", then the output will be a 3x20 matrix where
+ ```python
output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
output[1, :] = (params[0, :] * 1.0) / 1.0
output[2, :] = (params[1, :] * 3.0) / 3.0
+ ```
Raises:
- TypeError: If sp_ids is not a SparseTensor, or if sp_weights is neither
- None nor SparseTensor.
- ValueError: If combiner is not one of {"mean", "sqrtn", "sum"}.
+ TypeError: If `sp_ids` is not a `SparseTensor`, or if `sp_weights` is
+ neither `None` nor `SparseTensor`.
+ ValueError: If `combiner` is not one of {"mean", "sqrtn", "sum"}.
"""
if combiner is None:
logging.warn("The default value of combiner will change from \"mean\" "
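A runnable sketch of the docstring example above, with the ids and weights of the 10x20 case spelled out as `SparseTensor`s (values chosen for illustration):

```python
import tensorflow as tf

# 10x20 embedding table, as in the docstring example.
params = tf.reshape(tf.range(200, dtype=tf.float32), [10, 20])

indices = [[0, 0], [0, 1], [1, 0], [2, 3]]
sp_ids = tf.SparseTensor(
    indices=indices,
    values=tf.constant([1, 3, 0, 1], dtype=tf.int64),
    dense_shape=[3, 4])
sp_weights = tf.SparseTensor(
    indices=indices,
    values=tf.constant([2.0, 0.5, 1.0, 3.0]),
    dense_shape=[3, 4])

# 3x20 output: per-row weighted mean of the looked-up embedding rows.
combined = tf.nn.embedding_lookup_sparse(
    params, sp_ids, sp_weights, combiner="mean")
```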
diff --git a/tensorflow/python/ops/histogram_ops.py b/tensorflow/python/ops/histogram_ops.py
index 4a1ef54fb5..ec38d89a0e 100644
--- a/tensorflow/python/ops/histogram_ops.py
+++ b/tensorflow/python/ops/histogram_ops.py
@@ -32,7 +32,6 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.python.util.tf_export import tf_export
@tf_export('histogram_fixed_width_bins')
diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py
index 3369fe3c9b..601010bce9 100644
--- a/tensorflow/python/ops/image_ops_impl.py
+++ b/tensorflow/python/ops/image_ops_impl.py
@@ -269,17 +269,7 @@ def random_flip_up_down(image, seed=None):
Raises:
ValueError: if the shape of `image` is not supported.
"""
- with ops.name_scope(None, 'random_flip_up_down', [image]) as scope:
- image = ops.convert_to_tensor(image, name='image')
- image = _Assert3DImage(image)
- uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
- mirror_cond = math_ops.less(uniform_random, .5)
- result = control_flow_ops.cond(
- mirror_cond,
- lambda: array_ops.reverse(image, [0]),
- lambda: image,
- name=scope)
- return fix_image_flip_shape(image, result)
+ return _random_flip(image, 0, seed, 'random_flip_up_down')
@tf_export('image.random_flip_left_right')
@@ -301,14 +291,34 @@ def random_flip_left_right(image, seed=None):
Raises:
ValueError: if the shape of `image` is not supported.
"""
- with ops.name_scope(None, 'random_flip_left_right', [image]) as scope:
+ return _random_flip(image, 1, seed, 'random_flip_left_right')
+
+
+def _random_flip(image, flip_index, seed, scope_name):
+ """Randomly (50% chance) flip an image along axis `flip_index`.
+ Args:
+ image: A 3-D tensor of shape `[height, width, channels]`.
+ flip_index: The dimension along which to flip the image.
+ Vertical: 0, Horizontal: 1
+ seed: A Python integer. Used to create a random seed. See
+ @{tf.set_random_seed}
+ for behavior.
+ scope_name: Name of the scope in which the ops are added.
+
+ Returns:
+ A 3-D tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ with ops.name_scope(None, scope_name, [image]) as scope:
image = ops.convert_to_tensor(image, name='image')
image = _Assert3DImage(image)
uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
mirror_cond = math_ops.less(uniform_random, .5)
result = control_flow_ops.cond(
mirror_cond,
- lambda: array_ops.reverse(image, [1]),
+ lambda: array_ops.reverse(image, [flip_index]),
lambda: image,
name=scope)
return fix_image_flip_shape(image, result)
@@ -332,16 +342,7 @@ def flip_left_right(image):
Raises:
ValueError: if the shape of `image` is not supported.
"""
- with ops.name_scope(None, 'flip_left_right', [image]):
- image = ops.convert_to_tensor(image, name='image')
- image = _AssertAtLeast3DImage(image)
- shape = image.get_shape()
- if shape.ndims == 3 or shape.ndims is None:
- return fix_image_flip_shape(image, array_ops.reverse(image, [1]))
- elif shape.ndims == 4:
- return array_ops.reverse(image, [2])
- else:
- raise ValueError('\'image\' must have either 3 or 4 dimensions.')
+ return _flip(image, 1, 'flip_left_right')
@tf_export('image.flip_up_down')
@@ -362,14 +363,35 @@ def flip_up_down(image):
Raises:
ValueError: if the shape of `image` is not supported.
"""
- with ops.name_scope(None, 'flip_up_down', [image]):
+ return _flip(image, 0, 'flip_up_down')
+
+
+def _flip(image, flip_index, scope_name):
+ """Flip an image either horizontally or vertically.
+
+ Outputs the contents of `image` flipped along the dimension `flip_index`.
+
+ See also `reverse()`.
+
+ Args:
+ image: 4-D Tensor of shape `[batch, height, width, channels]` or
+ 3-D Tensor of shape `[height, width, channels]`.
+ flip_index: 0 For vertical, 1 for horizontal.
+
+ Returns:
+ A tensor of the same type and shape as `image`.
+
+ Raises:
+ ValueError: if the shape of `image` is not supported.
+ """
+ with ops.name_scope(None, scope_name, [image]):
image = ops.convert_to_tensor(image, name='image')
image = _AssertAtLeast3DImage(image)
shape = image.get_shape()
if shape.ndims == 3 or shape.ndims is None:
- return fix_image_flip_shape(image, array_ops.reverse(image, [0]))
+ return fix_image_flip_shape(image, array_ops.reverse(image, [flip_index]))
elif shape.ndims == 4:
- return array_ops.reverse(image, [1])
+ return array_ops.reverse(image, [flip_index + 1])
else:
raise ValueError('\'image\' must have either 3 or 4 dimensions.')
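The refactoring is behavior-preserving; a quick NumPy cross-check of the two public flips on a small 3-D image (sketch, TF 1.x session style):

```python
import numpy as np
import tensorflow as tf

base = np.arange(6, dtype=np.float32).reshape(2, 3, 1)
image = tf.constant(base)

flipped_lr = tf.image.flip_left_right(image)  # now _flip(image, 1, ...)
flipped_ud = tf.image.flip_up_down(image)     # now _flip(image, 0, ...)

with tf.Session() as sess:
    lr, ud = sess.run([flipped_lr, flipped_ud])

assert np.array_equal(lr, base[:, ::-1])  # width axis reversed
assert np.array_equal(ud, base[::-1])     # height axis reversed
```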
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index 39b7295124..f93bf0a17f 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -39,10 +39,10 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
+from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
-from tensorflow.python.ops import random_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -529,7 +529,7 @@ class Orthogonal(Initializer):
# Generate a random matrix
a = random_ops.random_normal(flat_shape, dtype=dtype, seed=self.seed)
# Compute the qr factorization
- q, r = linalg_ops.qr(a, full_matrices=False)
+ q, r = gen_linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
q *= math_ops.sign(d)
@@ -577,7 +577,7 @@ class ConvolutionDeltaOrthogonal(Initializer):
a = random_ops.random_normal([shape[-1], shape[-1]],
dtype=dtype, seed=self.seed)
# Compute the qr factorization
- q, r = linalg_ops.qr(a, full_matrices=False)
+ q, r = gen_linalg_ops.qr(a, full_matrices=False)
# Make Q uniform
d = array_ops.diag_part(r)
q *= math_ops.sign(d)
@@ -636,7 +636,7 @@ class ConvolutionOrthogonal(Initializer):
a = random_ops.random_normal([n, n], dtype=self.dtype, seed=self.seed)
if self.seed:
self.seed += 1
- q, r = linalg_ops.qr(a)
+ q, r = gen_linalg_ops.qr(a)
d = array_ops.diag_part(r)
# make q uniform
q *= math_ops.sign(d)
@@ -723,7 +723,7 @@ class ConvolutionOrthogonal2D(ConvolutionOrthogonal):
raise ValueError("The dimension of the matrices must be the same.")
n = p1.shape.as_list()[0]
kernel2x2 = {}
- eye = linalg_ops.eye(n, dtype=self.dtype)
+ eye = linalg_ops_impl.eye(n, dtype=self.dtype)
kernel2x2[0, 0] = math_ops.matmul(p1, p2)
kernel2x2[0, 1] = math_ops.matmul(p1, (eye - p2))
kernel2x2[1, 0] = math_ops.matmul((eye - p1), p2)
@@ -848,7 +848,7 @@ class ConvolutionOrthogonal1D(ConvolutionOrthogonal):
"""
n = projection_matrix.shape.as_list()[0]
kernel = {}
- eye = linalg_ops.eye(n, dtype=self.dtype)
+ eye = linalg_ops_impl.eye(n, dtype=self.dtype)
kernel[0] = projection_matrix
kernel[1] = eye - projection_matrix
return kernel
@@ -976,7 +976,7 @@ class ConvolutionOrthogonal3D(ConvolutionOrthogonal):
if p1_shape != p2.shape.as_list() or p1_shape != p3.shape.as_list():
raise ValueError("The dimension of the matrices must be the same.")
n = p1_shape[0]
- eye = linalg_ops.eye(n, dtype=self.dtype)
+ eye = linalg_ops_impl.eye(n, dtype=self.dtype)
kernel2x2x2 = {}
def matmul(p1, p2, p3):
return math_ops.matmul(math_ops.matmul(p1, p2), p3)
@@ -1084,7 +1084,7 @@ class Identity(Initializer):
"Identity matrix initializer can only be used for 2D matrices.")
if dtype is None:
dtype = self.dtype
- initializer = linalg_ops.eye(*full_shape, dtype=dtype)
+ initializer = linalg_ops_impl.eye(*full_shape, dtype=dtype)
if partition_info is not None:
initializer = array_ops.slice(initializer, partition_info.var_offset,
shape)
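All of these initializers share the same QR trick: the sign correction `q *= sign(diag(r))` makes the factorization unique and the sampled `Q` uniformly distributed over orthogonal matrices. A NumPy sketch of just that step:

```python
import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(4, 4)
q, r = np.linalg.qr(a)

# Scale the columns of Q so that diag(R) is positive; this is the
# "Make Q uniform" step in the initializers above.
q *= np.sign(np.diag(r))

assert np.allclose(q.T.dot(q), np.eye(4))  # Q is still orthogonal
```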
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index 170861b43f..a0dfa543f9 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -24,12 +24,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -159,36 +160,11 @@ def eye(num_rows,
Returns:
A `Tensor` of shape `batch_shape + [num_rows, num_columns]`
"""
- with ops.name_scope(
- name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
- is_square = num_columns is None
- batch_shape = [] if batch_shape is None else batch_shape
- num_columns = num_rows if num_columns is None else num_columns
- if isinstance(num_rows, ops.Tensor) or isinstance(
- num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
- batch_shape = ops.convert_to_tensor(
- batch_shape, name='shape', dtype=dtypes.int32)
- diag_size = math_ops.minimum(num_rows, num_columns)
- diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
- if not is_square:
- shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
- else:
- if not isinstance(num_rows, compat.integral_types) or not isinstance(
- num_columns, compat.integral_types):
- raise TypeError(
- 'num_rows and num_columns must be positive integer values.')
- batch_shape = [dim for dim in batch_shape]
- is_square = num_rows == num_columns
- diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
- if not is_square:
- shape = batch_shape + [num_rows, num_columns]
-
- diag_ones = array_ops.ones(diag_shape, dtype=dtype)
- if is_square:
- return array_ops.matrix_diag(diag_ones)
- else:
- zero_matrix = array_ops.zeros(shape, dtype=dtype)
- return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+ return linalg_ops_impl.eye(num_rows,
+ num_columns=num_columns,
+ batch_shape=batch_shape,
+ dtype=dtype,
+ name=name)
@tf_export('matrix_solve_ls', 'linalg.lstsq')
@@ -454,7 +430,7 @@ def norm(tensor,
This function can compute several different vector norms (the 1-norm, the
Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) and
- matrix norms (Frobenius, 1-norm, and inf-norm).
+ matrix norms (Frobenius, 1-norm, 2-norm, and inf-norm).
Args:
tensor: `Tensor` of types `float32`, `float64`, `complex64`, `complex128`
@@ -465,7 +441,7 @@ def norm(tensor,
Some restrictions apply:
a) The Frobenius norm `fro` is not defined for vectors,
b) If axis is a 2-tuple (matrix norm), only 'euclidean', 'fro', `1`,
- `np.inf` are supported.
+ `2`, `np.inf` are supported.
See the description of `axis` on how to compute norms for a batch of
vectors or matrices stored in a tensor.
axis: If `axis` is `None` (the default), the input is considered a vector
@@ -521,8 +497,7 @@ def norm(tensor,
axis[0] == axis[1]):
raise ValueError(
"'axis' must be None, an integer, or a tuple of 2 unique integers")
- # TODO(rmlarsen): Implement matrix 2-norm using tf.svd().
- supported_matrix_norms = ['euclidean', 'fro', 1, np.inf]
+ supported_matrix_norms = ['euclidean', 'fro', 1, 2, np.inf]
if ord not in supported_matrix_norms:
raise ValueError("'ord' must be a supported matrix norm in %s, got %s" %
(supported_matrix_norms, ord))
@@ -539,12 +514,34 @@ def norm(tensor,
with ops.name_scope(name, 'norm', [tensor]):
tensor = ops.convert_to_tensor(tensor)
+
if ord in ['fro', 'euclidean', 2, 2.0]:
- # TODO(rmlarsen): Move 2-norm to a separate clause once we support it for
- # matrices.
- result = math_ops.sqrt(
- math_ops.reduce_sum(
- tensor * math_ops.conj(tensor), axis, keepdims=True))
+ if is_matrix_norm and ord in [2, 2.0]:
+ rank = array_ops.rank(tensor)
+ positive_axis = functional_ops.map_fn(
+ lambda i: control_flow_ops.cond(i >= 0, lambda: i, lambda: i + rank),
+ ops.convert_to_tensor(axis))
+ axes = math_ops.range(rank)
+ perm_before = array_ops.concat(
+ [array_ops.setdiff1d(axes, positive_axis)[0], positive_axis],
+ axis=0)
+ perm_after = functional_ops.map_fn(
+ lambda i: math_ops.cast(
+ array_ops.squeeze(
+ array_ops.where(math_ops.equal(perm_before, i))),
+ dtype=dtypes.int32), axes)
+ permed = array_ops.transpose(tensor, perm=perm_before)
+ matrix_2_norm = array_ops.expand_dims(
+ math_ops.reduce_max(
+ math_ops.abs(gen_linalg_ops.svd(permed, compute_uv=False)[0]),
+ axis=-1,
+ keepdims=True),
+ axis=-1)
+ result = array_ops.transpose(matrix_2_norm, perm=perm_after)
+ else:
+ result = math_ops.sqrt(
+ math_ops.reduce_sum(
+ tensor * math_ops.conj(tensor), axis, keepdims=True))
else:
result = math_ops.abs(tensor)
if ord == 1:
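The new branch computes the matrix 2-norm (spectral norm) as the largest singular value, transposing the two `axis` dimensions to the back so `svd` can batch over the rest. A NumPy cross-check, assuming a build that includes this change:

```python
import numpy as np
import tensorflow as tf

a = np.random.RandomState(0).randn(3, 4).astype(np.float32)
spectral = tf.norm(tf.constant(a), ord=2, axis=(-2, -1))

with tf.Session() as sess:
    result = sess.run(spectral)

# Matches the largest singular value, i.e. np.linalg.norm(a, ord=2).
assert np.isclose(result, np.linalg.svd(a, compute_uv=False).max(),
                  rtol=1e-5)
```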
diff --git a/tensorflow/python/ops/linalg_ops_impl.py b/tensorflow/python/ops/linalg_ops_impl.py
new file mode 100644
index 0000000000..e7c89f6ae3
--- /dev/null
+++ b/tensorflow/python/ops/linalg_ops_impl.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Operations for linear algebra."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util import compat
+
+# Names below are lower_case.
+# pylint: disable=invalid-name
+
+
+def eye(num_rows,
+ num_columns=None,
+ batch_shape=None,
+ dtype=dtypes.float32,
+ name=None):
+ """Construct an identity matrix, or a batch of matrices.
+
+ See `linalg_ops.eye`.
+ """
+ with ops.name_scope(
+ name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
+ is_square = num_columns is None
+ batch_shape = [] if batch_shape is None else batch_shape
+ num_columns = num_rows if num_columns is None else num_columns
+ if isinstance(num_rows, ops.Tensor) or isinstance(
+ num_columns, ops.Tensor) or isinstance(batch_shape, ops.Tensor):
+ batch_shape = ops.convert_to_tensor(
+ batch_shape, name='shape', dtype=dtypes.int32)
+ diag_size = math_ops.minimum(num_rows, num_columns)
+ diag_shape = array_ops.concat((batch_shape, [diag_size]), 0)
+ if not is_square:
+ shape = array_ops.concat((batch_shape, [num_rows, num_columns]), 0)
+ else:
+ if not isinstance(num_rows, compat.integral_types) or not isinstance(
+ num_columns, compat.integral_types):
+ raise TypeError(
+ 'num_rows and num_columns must be positive integer values.')
+ batch_shape = [dim for dim in batch_shape]
+ is_square = num_rows == num_columns
+ diag_shape = batch_shape + [np.minimum(num_rows, num_columns)]
+ if not is_square:
+ shape = batch_shape + [num_rows, num_columns]
+
+ diag_ones = array_ops.ones(diag_shape, dtype=dtype)
+ if is_square:
+ return array_ops.matrix_diag(diag_ones)
+ else:
+ zero_matrix = array_ops.zeros(shape, dtype=dtype)
+ return array_ops.matrix_set_diag(zero_matrix, diag_ones)
+
+# pylint: enable=invalid-name
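Usage is unchanged, since the public `tf.eye` (see linalg_ops.py above) simply forwards here. A sketch of the three code paths:

```python
import tensorflow as tf

# Static square shape: built with matrix_diag.
batch_eye = tf.eye(3, batch_shape=[2])  # shape [2, 3, 3]

# Static rectangular shape: built with matrix_set_diag over zeros.
rect_eye = tf.eye(2, num_columns=4)     # shape [2, 4]

# Tensor-valued num_rows takes the dynamic-shape branch.
n = tf.placeholder(tf.int32, shape=[])
dyn_eye = tf.eye(n)
```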
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 34ca1adc3e..9fc545c967 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.ops.losses import util
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@@ -306,11 +307,8 @@ def cosine_distance(
ValueError: If `predictions` shape doesn't match `labels` shape, or
`axis`, `labels`, `predictions` or `weights` is `None`.
"""
- if dim is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dim'")
- axis = dim
- if axis is None and dim is None:
+ axis = deprecated_argument_lookup("axis", axis, "dim", dim)
+ if axis is None:
raise ValueError("You must specify 'axis'.")
if labels is None:
raise ValueError("labels must not be None.")
@@ -696,7 +694,7 @@ def softmax_cross_entropy(
onehot_labels, logits, weights=1.0, label_smoothing=0, scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=Reduction.SUM_BY_NONZERO_WEIGHTS):
- """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.
+ """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits_v2.
`weights` acts as a coefficient for the loss. If a scalar is provided,
then the loss is simply scaled by the given value. If `weights` is a
@@ -707,11 +705,16 @@ def softmax_cross_entropy(
new_onehot_labels = onehot_labels * (1 - label_smoothing)
+ label_smoothing / num_classes
+ Note that `onehot_labels` and `logits` must have the same shape,
+ e.g. `[batch_size, num_classes]`. The shape of `weights` must be
+ broadcastable to loss, whose shape is decided by the shape of `logits`.
+ In case the shape of `logits` is `[batch_size, num_classes]`, loss is
+ a `Tensor` of shape `[batch_size]`.
+
Args:
- onehot_labels: `[batch_size, num_classes]` target one-hot-encoded labels.
- logits: `[batch_size, num_classes]` logits outputs of the network .
- weights: Optional `Tensor` whose rank is either 0, or rank 1 and is
- broadcastable to the loss which is a `Tensor` of shape `[batch_size]`.
+ onehot_labels: One-hot-encoded labels.
+ logits: Logits outputs of the network.
+ weights: Optional `Tensor` that is broadcastable to loss.
label_smoothing: If greater than 0 then smooth the labels.
scope: the scope for the operations performed in computing the loss.
loss_collection: collection to which the loss will be added.
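A sketch of the shapes the clarified docstring describes (sizes assumed for illustration): the per-example loss has shape `[batch_size]`, so `weights` may be a scalar or anything broadcastable to that:

```python
import tensorflow as tf

logits = tf.random_normal([8, 5])  # [batch_size, num_classes]
onehot_labels = tf.one_hot(tf.zeros([8], dtype=tf.int32), depth=5)

loss = tf.losses.softmax_cross_entropy(
    onehot_labels, logits,
    weights=tf.fill([8], 2.0),  # broadcastable to the [8] per-example loss
    label_smoothing=0.1)
```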
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 2b04866fef..2feb88cb7b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -211,11 +211,9 @@ def argmax(input,
name=None,
dimension=None,
output_type=dtypes.int64):
- if dimension is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dimension'")
- axis = dimension
- elif axis is None:
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "dimension", dimension)
+ if axis is None:
axis = 0
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
@@ -231,11 +229,9 @@ def argmin(input,
name=None,
dimension=None,
output_type=dtypes.int64):
- if dimension is not None:
- if axis is not None:
- raise ValueError("Cannot specify both 'axis' and 'dimension'")
- axis = dimension
- elif axis is None:
+ axis = deprecation.deprecated_argument_lookup(
+ "axis", axis, "dimension", dimension)
+ if axis is None:
axis = 0
return gen_math_ops.arg_min(input, axis, name=name, output_type=output_type)
@@ -761,13 +757,25 @@ def cast(x, dtype, name=None):
tf.cast(x, tf.int32) # [1, 2], dtype=tf.int32
```
+ The operation supports data types (for `x` and `dtype`) of
+ `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`, `float16`, `float32`,
+ `float64`, `complex64`, `complex128`, `bfloat16`. In case of casting from
+ complex types (`complex64`, `complex128`) to real types, only the real part
+ of `x` is returned. In case of casting from real types to complex types
+ (`complex64`, `complex128`), the imaginary part of the returned value is set
+ to `0`. The handling of complex types here matches the behavior of numpy.
+
Args:
- x: A `Tensor` or `SparseTensor`.
- dtype: The destination type.
+ x: A `Tensor` or `SparseTensor` of numeric type. It could be
+ `uint8`, `int8`, `uint16`, `int16`, `int32`, `int64`,
+ `float16`, `float32`, `float64`, `complex64`, `complex128`, `bfloat16`.
+ dtype: The destination type. The list of supported dtypes is the same
+ as `x`.
name: A name for the operation (optional).
Returns:
- A `Tensor` or `SparseTensor` with same shape as `x`.
+ A `Tensor` or `SparseTensor` with same shape as `x` and
+ same type as `dtype`.
Raises:
TypeError: If `x` cannot be cast to the `dtype`.
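A short sketch of the complex-type behavior the new paragraph documents, mirroring NumPy's `astype`:

```python
import tensorflow as tf

x = tf.constant([1.0 + 2.0j, 3.0 - 4.0j], dtype=tf.complex64)
real = tf.cast(x, tf.float32)       # keeps the real part: [1., 3.]
back = tf.cast(real, tf.complex64)  # imaginary part set to 0: [1+0j, 3+0j]

with tf.Session() as sess:
    print(sess.run([real, back]))
```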
@@ -1634,7 +1642,7 @@ def reduce_min(input_tensor,
tensor with a single element is returned.
Args:
- input_tensor: The tensor to reduce. Should have numeric type.
+ input_tensor: The tensor to reduce. Should have real numeric type.
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
@@ -1683,7 +1691,7 @@ def reduce_max(input_tensor,
tensor with a single element is returned.
Args:
- input_tensor: The tensor to reduce. Should have numeric type.
+ input_tensor: The tensor to reduce. Should have real numeric type.
axis: The dimensions to reduce. If `None` (the default),
reduces all dimensions. Must be in the range
`[-rank(input_tensor), rank(input_tensor))`.
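The "real numeric type" wording reflects that the underlying `Min`/`Max` reduction kernels are not registered for complex dtypes; a sketch:

```python
import tensorflow as tf

x = tf.constant([[1.0, 4.0], [3.0, 2.0]])
row_max = tf.reduce_max(x, axis=1)  # [4., 3.], fine for real dtypes

# A complex input is rejected when the op is constructed, e.g.:
# tf.reduce_max(tf.constant([1 + 2j]))  # TypeError: complex128 not allowed
```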
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 244702d13b..1d0d9a52a1 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -98,6 +98,7 @@ See the @{$python/nn} guide.
@@fixed_unigram_candidate_sampler
@@compute_accidental_hits
@@quantized_conv2d
+@@quantized_relu
@@quantized_relu_x
@@quantized_max_pool
@@quantized_avg_pool
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 47cc4da7f2..d0d5ed07ce 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -987,7 +987,7 @@ def _compute_sampled_logits(weights,
class biases.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
- the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits_v2`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
@@ -1012,7 +1012,7 @@ def _compute_sampled_logits(weights,
out_logits: `Tensor` object with shape
`[batch_size, num_true + num_sampled]`, for passing to either
`nn.sigmoid_cross_entropy_with_logits` (NCE) or
- `nn.softmax_cross_entropy_with_logits` (sampled softmax).
+ `nn.softmax_cross_entropy_with_logits_v2` (sampled softmax).
out_labels: A Tensor object with the same shape as `out_logits`.
"""
@@ -1285,7 +1285,7 @@ def sampled_softmax_loss(weights,
logits = tf.matmul(inputs, tf.transpose(weights))
logits = tf.nn.bias_add(logits, biases)
labels_one_hot = tf.one_hot(labels, n_classes)
- loss = tf.nn.softmax_cross_entropy_with_logits(
+ loss = tf.nn.softmax_cross_entropy_with_logits_v2(
labels=labels_one_hot,
logits=logits)
```
@@ -1303,7 +1303,7 @@ def sampled_softmax_loss(weights,
biases: A `Tensor` of shape `[num_classes]`. The class biases.
labels: A `Tensor` of type `int64` and shape `[batch_size,
num_true]`. The target classes. Note that this format differs from
- the `labels` argument of `nn.softmax_cross_entropy_with_logits`.
+ the `labels` argument of `nn.softmax_cross_entropy_with_logits_v2`.
inputs: A `Tensor` of shape `[batch_size, dim]`. The forward
activations of the input network.
num_sampled: An `int`. The number of classes to randomly sample per batch.
@@ -1340,7 +1340,8 @@ def sampled_softmax_loss(weights,
partition_strategy=partition_strategy,
name=name,
seed=seed)
- sampled_losses = nn_ops.softmax_cross_entropy_with_logits(
+ labels = array_ops.stop_gradient(labels, name="labels_stop_gradient")
+ sampled_losses = nn_ops.softmax_cross_entropy_with_logits_v2(
labels=labels, logits=logits)
# sampled_losses is a [batch_size] tensor.
return sampled_losses
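`softmax_cross_entropy_with_logits_v2` backpropagates into its labels by default, so the sampled one-hot labels are wrapped in `stop_gradient` to preserve the original behavior. A sketch of the train/eval pairing the docstring above describes, with sizes assumed for illustration:

```python
import tensorflow as tf

batch, dim, n_classes, n_sampled = 32, 128, 10000, 64
weights = tf.get_variable("w", [n_classes, dim])
biases = tf.get_variable("b", [n_classes])
inputs = tf.random_normal([batch, dim])
labels = tf.random_uniform([batch, 1], maxval=n_classes, dtype=tf.int64)

# Training: cheap sampled loss over n_sampled candidate classes.
train_loss = tf.nn.sampled_softmax_loss(
    weights, biases, labels, inputs,
    num_sampled=n_sampled, num_classes=n_classes)

# Evaluation: the full softmax from the docstring above.
logits = tf.nn.bias_add(tf.matmul(inputs, tf.transpose(weights)), biases)
eval_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
    labels=tf.one_hot(tf.squeeze(labels, 1), n_classes), logits=logits)
```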
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index bb454b3c3a..cd07550d2e 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1155,7 +1155,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
Returns:
A `Tensor` with the same type as `value`.
- Output shape with `'VALID`` padding is:
+ Output shape with `'VALID'` padding is:
[batch, height - 2 * (filter_width - 1),
width - 2 * (filter_height - 1), out_channels].
@@ -1458,10 +1458,10 @@ def conv3d_transpose(
if isinstance(output_shape, (list, np.ndarray)):
# output_shape's shape should be == [5] if reached this point.
- if not filter.get_shape()[3].is_compatible_with(output_shape[4]):
+ if not filter.get_shape()[3].is_compatible_with(output_shape[axis]):
raise ValueError(
"output_shape does not match filter's output channels, "
- "{} != {}".format(output_shape[4],
+ "{} != {}".format(output_shape[axis],
filter.get_shape()[3]))
if padding != "VALID" and padding != "SAME":
@@ -1986,7 +1986,7 @@ def sparse_softmax_cross_entropy_with_logits(
must provide a single specific index for the true class for each row of
`logits` (each minibatch entry). For soft softmax classification with
a probability distribution for each entry, see
- `softmax_cross_entropy_with_logits`.
+ `softmax_cross_entropy_with_logits_v2`.
**WARNING:** This op expects unscaled logits, since it performs a `softmax`
on `logits` internally for efficiency. Do not call this op with the
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 9251e9802c..86dc053c0f 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -617,9 +617,9 @@ class BasicLSTMCell(LayerRNNCell):
Args:
inputs: `2-D` tensor with shape `[batch_size, input_size]`.
state: An `LSTMStateTuple` of state tensors, each shaped
- `[batch_size, self.state_size]`, if `state_is_tuple` has been set to
+ `[batch_size, num_units]`, if `state_is_tuple` has been set to
`True`. Otherwise, a `Tensor` shaped
- `[batch_size, 2 * self.state_size]`.
+ `[batch_size, 2 * num_units]`.
Returns:
A pair containing the new hidden state, and the new state (either a