-rw-r--r--  tensorflow/core/ops/nn_ops.cc                     4
-rw-r--r--  tensorflow/python/kernel_tests/xent_op_test.py    2
-rw-r--r--  tensorflow/python/ops/linalg_ops.py             199
-rw-r--r--  tensorflow/python/ops/nn_ops.py                 191
4 files changed, 58 insertions(+), 338 deletions(-)
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index fe78541d08..cc374278e7 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1263,7 +1263,7 @@ REGISTER_OP("Softmax")
.Output("softmax: T")
.Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
- return shape_inference::UnchangedShapeWithRank(c, 2);
+ return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})
.Doc(R"doc(
Computes softmax activations.
@@ -1283,7 +1283,7 @@ REGISTER_OP("LogSoftmax")
.Output("logsoftmax: T")
.Attr("T: {half, float, double}")
.SetShapeFn([](InferenceContext* c) {
- return shape_inference::UnchangedShapeWithRank(c, 2);
+ return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})
.Doc(R"doc(
Computes log softmax activations.
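The two hunks above relax the C++ shape functions for Softmax and LogSoftmax from exactly rank 2 to rank at least 1, matching the Python-side registration that this change removes from nn_ops.py. A minimal sketch of the effect at graph-construction time, assuming the graph-mode Python API of this era (placeholders, get_shape); it exercises shape inference only, not the kernel:

```python
import tensorflow as tf

# Shape inference for Softmax now only requires rank >= 1, so a 1-D
# logits vector passes graph construction instead of being rejected by
# the old UnchangedShapeWithRank(c, 2) shape function.
logits = tf.placeholder(tf.float32, shape=[10])
probs = tf.nn.softmax(logits)
print(probs.get_shape())  # expected: (10,)
```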
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 70b3bdcbb4..3c843d27c8 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -70,7 +70,7 @@ class XentTest(tf.test.TestCase):
[[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(dtype)
np_labels = np.array(
[[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(dtype)
- self.assertRaisesRegexp(ValueError, "must have rank 2",
+ self.assertRaisesRegexp(ValueError, "must be rank 2",
gen_nn_ops._softmax_cross_entropy_with_logits,
np_features, np_labels)
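The test regex changes because the error now comes from the C++ shape-inference code, whose wording is "must be rank 2" rather than the old Python shape function's "must have rank 2". A small stand-alone reproduction of what the updated assertion checks, assuming the generated raw op is called directly as in the test:

```python
import numpy as np
from tensorflow.python.ops import gen_nn_ops

# Rank-3 inputs should be rejected at graph construction; with the C++
# shape function the ValueError text now contains "must be rank 2".
features = np.ones((2, 1, 4), dtype=np.float32)
labels = np.ones((2, 1, 4), dtype=np.float32)
try:
    gen_nn_ops._softmax_cross_entropy_with_logits(features, labels)
except ValueError as e:
    print(e)
```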
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index fa8b88bf06..bd753c12ec 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_linalg_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
@@ -27,183 +27,26 @@ from tensorflow.python.ops.gen_linalg_ops import *
# pylint: enable=wildcard-import
-def _UnchangedSquareHelper(input_shape):
- """Helper for {Batch}UnchangedSquare."""
- # The matrices in the batch must be square.
- input_shape[-1].assert_is_compatible_with(input_shape[-2])
- return [input_shape]
-
-
-@ops.RegisterShape("Cholesky")
-@ops.RegisterShape("CholeskyGrad")
-@ops.RegisterShape("MatrixInverse")
-def _UnchangedSquare(op):
- """Shape function for matrix ops with output equal to input shape."""
- return _UnchangedSquareHelper(op.inputs[0].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchCholesky")
-@ops.RegisterShape("BatchCholeskyGrad")
-@ops.RegisterShape("BatchMatrixInverse")
-def _BatchUnchangedSquare(op):
- """Shape function for batch matrix ops with output equal to input shape."""
- return _UnchangedSquareHelper(op.inputs[0].get_shape().with_rank_at_least(2))
-
-
-@ops.RegisterShape("MatrixDeterminant")
-def _MatrixDeterminantShape(op):
- """Shape function for determinant op."""
- input_shape = op.inputs[0].get_shape().with_rank(2)
- # The matrix must be square.
- input_shape[0].assert_is_compatible_with(input_shape[1])
- if input_shape.ndims is not None:
- return [tensor_shape.scalar()]
- else:
- return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("BatchMatrixDeterminant")
-def _BatchMatrixDeterminantShape(op):
- """Shape function for batch determinant op."""
- input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
- # The matrices in the batch must be square.
- input_shape[-1].assert_is_compatible_with(input_shape[-2])
- if input_shape.ndims is not None:
- return [input_shape[:-2]]
- else:
- return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("SelfAdjointEig")
-def _SelfAdjointEigShape(op):
- """Shape function for self-adjoint eigensolver op."""
- input_shape = op.inputs[0].get_shape().with_rank(2)
- # The matrix must be square.
- input_shape[0].assert_is_compatible_with(input_shape[1])
- d = input_shape.dims[0]
- out_shape = tensor_shape.TensorShape([d + 1, d])
- return [out_shape]
-
-
-@ops.RegisterShape("BatchSelfAdjointEig")
-def _BatchSelfAdjointEigShape(op):
- """Shape function for batch self-adjoint eigensolver op."""
- input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
- # The matrices in the batch must be square.
- input_shape[-1].assert_is_compatible_with(input_shape[-2])
- dlist = input_shape.dims
- dlist[-2] += 1
- out_shape = tensor_shape.TensorShape(dlist)
- return [out_shape]
-
-
-def _SelfAdjointEigV2ShapeHelper(op, input_shape):
- """Shape inference helper for {Batch}SelfAdjointEigV2."""
- batch_shape = input_shape[:-2]
- n = input_shape[-1].merge_with(input_shape[-2])
- compute_v = op.get_attr("compute_v")
- if compute_v:
- return [batch_shape.concatenate([n]), batch_shape.concatenate([n, n])]
- else:
- return [batch_shape.concatenate([n]), [0]]
-
-
-@ops.RegisterShape("SelfAdjointEigV2")
-def _SelfAdjointEigShapeV2(op):
- """Shape function for SelfAdjointEigV2."""
- return _SelfAdjointEigV2ShapeHelper(op, op.inputs[0].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchSelfAdjointEigV2")
-def _BatchSelfAdjointEigV2Shape(op):
- """Shape function for BatchSelfAdjointEigV2."""
- return _SelfAdjointEigV2ShapeHelper(
- op, op.inputs[0].get_shape().with_rank_at_least(2))
-
-
-def _SvdShapeHelper(input_shape, op):
- """Shape inference helper for {Batch}SVD op."""
- unknown = tensor_shape.unknown_shape()
- if input_shape.ndims is not None:
- return [unknown, unknown, unknown]
- compute_uv = op.get_attr("compute_uv")
- full_matrices = op.get_attr("full_matrices")
- m = input_shape[-2]
- n = input_shape[-1]
- p = min(m, n)
- batch_shape = input_shape[:-2]
- s_shape = batch_shape.concatenate([p])
- if compute_uv:
- if full_matrices:
- u_shape = batch_shape.concatenate([m, m])
- v_shape = batch_shape.concatenate([n, n])
- else:
- u_shape = batch_shape.concatenate([m, p])
- v_shape = batch_shape.concatenate([n, p])
- else:
- u_shape = [0]
- v_shape = [0]
- return [s_shape, u_shape, v_shape]
-
-
-@ops.RegisterShape("Svd")
-def _SvdShape(op):
- """Shape function for SVD op."""
- return _SvdShapeHelper(op.inputs[0].get_shape().with_rank(2), op)
-
-
-@ops.RegisterShape("BatchSvd")
-def _BatchSvdShape(op):
- """Shape function for batch SVD op."""
- return _SvdShapeHelper(op.inputs[0].get_shape().with_rank_at_least(2), op)
-
-
-def _SquareMatrixSolveShapeHelper(lhs_shape, rhs_shape):
- """Shape inference helper function for square matrix solver ops."""
- # The matrix must be square.
- lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
- # The matrix and right-hand side must have the same number of rows.
- lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
- return [rhs_shape]
-
-
-@ops.RegisterShape("MatrixSolve")
-@ops.RegisterShape("MatrixTriangularSolve")
-def _SquareMatrixSolveShape(op):
- """Shape function for square matrix solver ops."""
- return _SquareMatrixSolveShapeHelper(op.inputs[0].get_shape().with_rank(2),
- op.inputs[1].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchMatrixSolve")
-@ops.RegisterShape("BatchMatrixTriangularSolve")
-def _BatchSquareMatrixSolveShape(op):
- """Shape function for batch square matrix solver ops."""
- return _SquareMatrixSolveShapeHelper(
- op.inputs[0].get_shape().with_rank_at_least(2),
- op.inputs[1].get_shape().with_rank_at_least(2))
-
-
-def _MatrixSolveLsShapeHelper(lhs_shape, rhs_shape):
- """Shape inference helper function for least squares matrix solver ops."""
- # The matrices and right-hand sides must have the same number of rows.
- lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
- return [lhs_shape[:-2].concatenate([lhs_shape[-1], rhs_shape[-1]])]
-
-
-@ops.RegisterShape("MatrixSolveLs")
-def _MatrixSolveLsShape(op):
- """Shape function for least-squares matrix solver op."""
- return _MatrixSolveLsShapeHelper(op.inputs[0].get_shape().with_rank(2),
- op.inputs[1].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchMatrixSolveLs")
-def _BatchMatrixSolveLsShape(op):
- """Shape function for batch least-squares matrix solver op."""
- return _MatrixSolveLsShapeHelper(
- op.inputs[0].get_shape().with_rank_at_least(2),
- op.inputs[1].get_shape().with_rank_at_least(2))
+ops.RegisterShape("Cholesky")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("CholeskyGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixInverse")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchCholesky")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchCholeskyGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixInverse")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixDeterminant")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixDeterminant")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SelfAdjointEig")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchSelfAdjointEig")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SelfAdjointEigV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchSelfAdjointEigV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Svd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchSvd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixTriangularSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixTriangularSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixSolveLs")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixSolveLs")(common_shapes.call_cpp_shape_fn)
# Names below are lower_case.
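All of the handwritten Python shape functions in linalg_ops.py are replaced by registrations of common_shapes.call_cpp_shape_fn, which delegates shape inference to the shape function declared with SetShapeFn(...) next to each REGISTER_OP in C++. A rough sketch of the behavior this should preserve, assuming the graph-mode API of this era; the square-matrix check formerly done by _UnchangedSquare is now expected to come back from the C++ shape function:

```python
import tensorflow as tf

# Cholesky's inferred output shape should still equal the (square)
# input shape, now computed in C++ rather than by _UnchangedSquare.
a = tf.placeholder(tf.float32, shape=[4, 4])
chol = tf.cholesky(a)
print(chol.get_shape())  # expected: (4, 4)

# A non-square input should still fail at graph construction.
b = tf.placeholder(tf.float32, shape=[4, 3])
try:
    tf.cholesky(b)
except ValueError as e:
    print(e)
```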
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f5f5aedc01..e14ca1a559 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -731,25 +731,10 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
return cost
-@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
-def _SparseSoftmaxCrossEntropyWithLogitsShape(op):
- """Shape function for SparseSoftmaxCrossEntropyWithLogits op."""
- logits_shape = op.inputs[0].get_shape()
- input_shape = logits_shape.with_rank(2)
- batch_size = input_shape[0]
- # labels_shape
- op.inputs[1].get_shape().merge_with(tensor_shape.vector(batch_size))
- return [tensor_shape.vector(batch_size.value), input_shape]
-
-
-@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
-def _SoftmaxCrossEntropyWithLogitsShape(op):
- """Shape function for SoftmaxCrossEntropyWithLogits op."""
- logits_shape = op.inputs[0].get_shape()
- labels_shape = op.inputs[1].get_shape()
- input_shape = logits_shape.merge_with(labels_shape).with_rank(2)
- batch_size = input_shape[0]
- return [tensor_shape.vector(batch_size.value), input_shape]
+ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")(
+ common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SoftmaxCrossEntropyWithLogits")(
+ common_shapes.call_cpp_shape_fn)
def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
@@ -812,58 +797,22 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
name=name)
-ops.RegisterShape("Relu")(common_shapes.unchanged_shape)
-ops.RegisterShape("Relu6")(common_shapes.unchanged_shape)
-ops.RegisterShape("Elu")(common_shapes.unchanged_shape)
-ops.RegisterShape("Softplus")(common_shapes.unchanged_shape)
-ops.RegisterShape("Softsign")(common_shapes.unchanged_shape)
-
-
-@ops.RegisterShape("ReluGrad")
-@ops.RegisterShape("Relu6Grad")
-@ops.RegisterShape("EluGrad")
-@ops.RegisterShape("SoftplusGrad")
-@ops.RegisterShape("SoftsignGrad")
-def _BinaryElementwiseShape(op):
- """Returns same shape as both inputs to op.
-
- Args:
- op: Input operation.
-
- Returns:
- Shape of both inputs to `op`.
- """
- return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
-
-
-ops.RegisterShape("L2Loss")(common_shapes.scalar_shape)
-
-ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4))
-
-
-@ops.RegisterShape("LRNGrad")
-def _LRNGradShape(op):
- """Shape function for LRNGrad op."""
- in_grads_shape = op.inputs[0].get_shape().with_rank(4)
- in_image_shape = op.inputs[1].get_shape().with_rank(4)
- out_image_shape = op.inputs[2].get_shape().with_rank(4)
- return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
-
-
-ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank_at_least(
- 1))
-
-ops.RegisterShape("LogSoftmax")(
- common_shapes.unchanged_shape_with_rank_at_least(1))
-
-
-@ops.RegisterShape("InTopK")
-def _InTopKShape(op):
- """Shape function for InTopK op."""
- predictions_shape = op.inputs[0].get_shape().with_rank(2)
- targets_shape = op.inputs[1].get_shape().with_rank(1)
- batch_size = predictions_shape[0].merge_with(targets_shape[0])
- return [tensor_shape.vector(batch_size.value)]
+ops.RegisterShape("Relu")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Relu6")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Elu")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Softplus")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Softsign")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReluGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Relu6Grad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("EluGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SoftplusGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SoftsignGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("L2Loss")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("LRN")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("TopK")
@@ -883,36 +832,10 @@ def _TopKShape(op):
return [output_shape, output_shape]
-@ops.RegisterShape("BatchNormWithGlobalNormalization")
-def _BatchNormShape(op):
- """Shape function for BatchNormWithGlobalNormalization op."""
- input_shape = op.inputs[0].get_shape().with_rank(4)
- mean_shape = op.inputs[1].get_shape().with_rank(1)
- var_shape = op.inputs[2].get_shape().with_rank(1)
- beta_shape = op.inputs[3].get_shape().with_rank(1)
- gamma_shape = op.inputs[4].get_shape().with_rank(1)
- mean_shape[0].merge_with(input_shape[3])
- var_shape[0].merge_with(input_shape[3])
- beta_shape[0].merge_with(input_shape[3])
- gamma_shape[0].merge_with(input_shape[3])
- return [input_shape]
-
-
-@ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")
-def _BatchNormGradShape(op):
- """Shape function for BatchNormWithGlobalNormalizationGrad op."""
- input_shape = op.inputs[0].get_shape().with_rank(4)
- mean_shape = op.inputs[1].get_shape().with_rank(1)
- var_shape = op.inputs[2].get_shape().with_rank(1)
- beta_shape = op.inputs[3].get_shape().with_rank(1)
- out_backprop_shape = op.inputs[4].get_shape().with_rank(4)
- input_shape = input_shape.merge_with(out_backprop_shape)
- vector_dim = input_shape[3]
- vector_dim = vector_dim.merge_with(mean_shape[0])
- vector_dim = vector_dim.merge_with(var_shape[0])
- vector_dim = vector_dim.merge_with(beta_shape[0])
- return [input_shape] + ([tensor_shape.vector(vector_dim)] * 4)
-
+ops.RegisterShape("BatchNormWithGlobalNormalization")(
+ common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")(
+ common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv2D")(common_shapes.conv2d_shape)
ops.RegisterShape("DepthwiseConv2dNative")(
@@ -1022,12 +945,8 @@ def _DepthwiseConv2dNativeBackpropInputShape(op):
return [tensor_shape.unknown_shape(ndims=4)]
-@ops.RegisterShape("MaxPoolGrad")
-@ops.RegisterShape("MaxPoolGradWithArgmax")
-def _MaxPoolGradShape(op):
- """Shape function for the MaxPoolGrad op."""
- orig_input_shape = op.inputs[0].get_shape().with_rank(4)
- return [orig_input_shape]
+ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPoolGradWithArgmax")(common_shapes.call_cpp_shape_fn)
@ops.RegisterStatistics("Conv2D", "flops")
@@ -1150,46 +1069,12 @@ def _Pool3DShape(op):
channels])]
-@ops.RegisterShape("Conv3DBackpropFilter")
-def _Conv3DBackpropFilterShape(op):
- """Shape function for the Conv3DBackpropFilter op."""
- filter_shape = op.inputs[1].get_shape()
- return [filter_shape.with_rank(5)]
-
-
-@ops.RegisterShape("Conv3DBackpropInput")
-def _Conv3DBackpropInputShape(op):
- """Shape function for the Conv3DBackpropInput op."""
- input_shape = op.inputs[0].get_shape()
- return [input_shape.with_rank(5)]
-
-
-@ops.RegisterShape("Conv3DBackpropFilterV2")
-def _Conv3DBackpropFilterShapeV2(op):
- """Shape function for the Conv3DBackpropFilterV2 op."""
- filter_shape = tensor_util.constant_value(op.inputs[1])
- return [tensor_shape.TensorShape(filter_shape).with_rank(5)]
-
-
-@ops.RegisterShape("Conv3DBackpropInputV2")
-def _Conv3DBackpropInputShapeV2(op):
- """Shape function for the Conv3DBackpropInputV2 op."""
- input_shape = tensor_util.constant_value(op.inputs[0])
- return [tensor_shape.TensorShape(input_shape).with_rank(5)]
-
-
-@ops.RegisterShape("AvgPool3DGrad")
-def _AvgPool3DGradShape(op):
- """Shape function for the AvgPool3DGrad op."""
- orig_input_shape = tensor_util.constant_value(op.inputs[0])
- return [tensor_shape.TensorShape(orig_input_shape).with_rank(5)]
-
-
-@ops.RegisterShape("MaxPool3DGrad")
-def _MaxPool3DGradShape(op):
- """Shape function for the MaxPoolGrad op."""
- orig_input_shape = op.inputs[0].get_shape().with_rank(5)
- return [orig_input_shape]
+ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Conv3DBackpropInputV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("AvgPool3DGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPool3DGrad")(common_shapes.call_cpp_shape_fn)
@ops.RegisterStatistics("BiasAdd", "flops")
@@ -1431,16 +1316,8 @@ def _Dilation2DShape(op):
return [tensor_shape.TensorShape(output_shape)]
-@ops.RegisterShape("Dilation2DBackpropInput")
-def _Dilation2DBackpropInputShape(op):
- """Shape function for Dilation2DBackpropInput op."""
- return [op.inputs[0].get_shape()]
-
-
-@ops.RegisterShape("Dilation2DBackpropFilter")
-def _Dilation2DBackpropFilterShape(op):
- """Shape function for Dilation2DBackpropFilter op."""
- return [op.inputs[1].get_shape()]
+ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)
@ops.RegisterStatistics("Dilation2D", "flops")
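The deleted Dilation2D backprop shape functions simply returned the forward input and filter shapes; with call_cpp_shape_fn the same shapes should now be inferred by the C++ shape functions. A hedged check through the registered gradient, assuming the dilation2d signature of this period (input, filter, strides, rates, padding):

```python
import tensorflow as tf

# The gradients produced by Dilation2DBackpropInput/Filter should keep
# the shapes of the forward input and filter, as the deleted Python
# shape functions did.
x = tf.placeholder(tf.float32, shape=[1, 8, 8, 3])
f = tf.placeholder(tf.float32, shape=[2, 2, 3])
y = tf.nn.dilation2d(x, f, strides=[1, 1, 1, 1],
                     rates=[1, 1, 1, 1], padding="SAME")
dx, df = tf.gradients(y, [x, f])
print(dx.get_shape(), df.get_shape())  # expected: (1, 8, 8, 3) (2, 2, 3)
```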