 tensorflow/core/ops/nn_ops.cc                  |   4
 tensorflow/python/kernel_tests/xent_op_test.py |   2
 tensorflow/python/ops/linalg_ops.py            | 199
 tensorflow/python/ops/nn_ops.py                | 191
 4 files changed, 58 insertions(+), 338 deletions(-)
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index fe78541d08..cc374278e7 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -1263,7 +1263,7 @@ REGISTER_OP("Softmax")
     .Output("softmax: T")
     .Attr("T: {half, float, double}")
     .SetShapeFn([](InferenceContext* c) {
-      return shape_inference::UnchangedShapeWithRank(c, 2);
+      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
     })
     .Doc(R"doc(
 Computes softmax activations.
@@ -1283,7 +1283,7 @@ REGISTER_OP("LogSoftmax")
     .Output("logsoftmax: T")
     .Attr("T: {half, float, double}")
     .SetShapeFn([](InferenceContext* c) {
-      return shape_inference::UnchangedShapeWithRank(c, 2);
+      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
     })
     .Doc(R"doc(
 Computes log softmax activations.
diff --git a/tensorflow/python/kernel_tests/xent_op_test.py b/tensorflow/python/kernel_tests/xent_op_test.py
index 70b3bdcbb4..3c843d27c8 100644
--- a/tensorflow/python/kernel_tests/xent_op_test.py
+++ b/tensorflow/python/kernel_tests/xent_op_test.py
@@ -70,7 +70,7 @@ class XentTest(tf.test.TestCase):
         [[[1., 1., 1., 1.]], [[1., 2., 3., 4.]]]).astype(dtype)
     np_labels = np.array(
         [[[0., 0., 0., 1.]], [[0., .5, .5, 0.]]]).astype(dtype)
-    self.assertRaisesRegexp(ValueError, "must have rank 2",
+    self.assertRaisesRegexp(ValueError, "must be rank 2",
                             gen_nn_ops._softmax_cross_entropy_with_logits,
                             np_features, np_labels)
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index fa8b88bf06..bd753c12ec 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.framework import common_shapes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import gen_linalg_ops
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
@@ -27,183 +27,26 @@ from tensorflow.python.ops.gen_linalg_ops import *
 # pylint: enable=wildcard-import
 
 
-def _UnchangedSquareHelper(input_shape):
-  """Helper for {Batch}UnchangedSquare."""
-  # The matrices in the batch must be square.
-  input_shape[-1].assert_is_compatible_with(input_shape[-2])
-  return [input_shape]
-
-
-@ops.RegisterShape("Cholesky")
-@ops.RegisterShape("CholeskyGrad")
-@ops.RegisterShape("MatrixInverse")
-def _UnchangedSquare(op):
-  """Shape function for matrix ops with output equal to input shape."""
-  return _UnchangedSquareHelper(op.inputs[0].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchCholesky")
-@ops.RegisterShape("BatchCholeskyGrad")
-@ops.RegisterShape("BatchMatrixInverse")
-def _BatchUnchangedSquare(op):
-  """Shape function for batch matrix ops with output equal to input shape."""
-  return _UnchangedSquareHelper(op.inputs[0].get_shape().with_rank_at_least(2))
-
-
-@ops.RegisterShape("MatrixDeterminant")
-def _MatrixDeterminantShape(op):
-  """Shape function for determinant op."""
-  input_shape = op.inputs[0].get_shape().with_rank(2)
-  # The matrix must be square.
-  input_shape[0].assert_is_compatible_with(input_shape[1])
-  if input_shape.ndims is not None:
-    return [tensor_shape.scalar()]
-  else:
-    return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("BatchMatrixDeterminant")
-def _BatchMatrixDeterminantShape(op):
-  """Shape function for batch determinant op."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  # The matrices in the batch must be square.
-  input_shape[-1].assert_is_compatible_with(input_shape[-2])
-  if input_shape.ndims is not None:
-    return [input_shape[:-2]]
-  else:
-    return [tensor_shape.unknown_shape()]
-
-
-@ops.RegisterShape("SelfAdjointEig")
-def _SelfAdjointEigShape(op):
-  """Shape function for self-adjoint eigensolver op."""
-  input_shape = op.inputs[0].get_shape().with_rank(2)
-  # The matrix must be square.
-  input_shape[0].assert_is_compatible_with(input_shape[1])
-  d = input_shape.dims[0]
-  out_shape = tensor_shape.TensorShape([d + 1, d])
-  return [out_shape]
-
-
-@ops.RegisterShape("BatchSelfAdjointEig")
-def _BatchSelfAdjointEigShape(op):
-  """Shape function for batch self-adjoint eigensolver op."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  # The matrices in the batch must be square.
-  input_shape[-1].assert_is_compatible_with(input_shape[-2])
-  dlist = input_shape.dims
-  dlist[-2] += 1
-  out_shape = tensor_shape.TensorShape(dlist)
-  return [out_shape]
-
-
-def _SelfAdjointEigV2ShapeHelper(op, input_shape):
-  """Shape inference helper for {Batch}SelfAdjointEigV2."""
-  batch_shape = input_shape[:-2]
-  n = input_shape[-1].merge_with(input_shape[-2])
-  compute_v = op.get_attr("compute_v")
-  if compute_v:
-    return [batch_shape.concatenate([n]), batch_shape.concatenate([n, n])]
-  else:
-    return [batch_shape.concatenate([n]), [0]]
-
-
-@ops.RegisterShape("SelfAdjointEigV2")
-def _SelfAdjointEigShapeV2(op):
-  """Shape function for SelfAdjointEigV2."""
-  return _SelfAdjointEigV2ShapeHelper(op, op.inputs[0].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchSelfAdjointEigV2")
-def _BatchSelfAdjointEigV2Shape(op):
-  """Shape function for BatchSelfAdjointEigV2."""
-  return _SelfAdjointEigV2ShapeHelper(
-      op, op.inputs[0].get_shape().with_rank_at_least(2))
-
-
-def _SvdShapeHelper(input_shape, op):
-  """Shape inference helper for {Batch}SVD op."""
-  unknown = tensor_shape.unknown_shape()
-  if input_shape.ndims is not None:
-    return [unknown, unknown, unknown]
-  compute_uv = op.get_attr("compute_uv")
-  full_matrices = op.get_attr("full_matrices")
-  m = input_shape[-2]
-  n = input_shape[-1]
-  p = min(m, n)
-  batch_shape = input_shape[:-2]
-  s_shape = batch_shape.concatenate([p])
-  if compute_uv:
-    if full_matrices:
-      u_shape = batch_shape.concatenate([m, m])
-      v_shape = batch_shape.concatenate([n, n])
-    else:
-      u_shape = batch_shape.concatenate([m, p])
-      v_shape = batch_shape.concatenate([n, p])
-  else:
-    u_shape = [0]
-    v_shape = [0]
-  return [s_shape, u_shape, v_shape]
-
-
-@ops.RegisterShape("Svd")
-def _SvdShape(op):
-  """Shape function for SVD op."""
-  return _SvdShapeHelper(op.inputs[0].get_shape().with_rank(2), op)
-
-
-@ops.RegisterShape("BatchSvd")
-def _BatchSvdShape(op):
-  """Shape function for batch SVD op."""
-  return _SvdShapeHelper(op.inputs[0].get_shape().with_rank_at_least(2), op)
-
-
-def _SquareMatrixSolveShapeHelper(lhs_shape, rhs_shape):
-  """Shape inference helper function for square matrix solver ops."""
-  # The matrix must be square.
-  lhs_shape[-1].assert_is_compatible_with(lhs_shape[-2])
-  # The matrix and right-hand side must have the same number of rows.
-  lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
-  return [rhs_shape]
-
-
-@ops.RegisterShape("MatrixSolve")
-@ops.RegisterShape("MatrixTriangularSolve")
-def _SquareMatrixSolveShape(op):
-  """Shape function for square matrix solver ops."""
-  return _SquareMatrixSolveShapeHelper(op.inputs[0].get_shape().with_rank(2),
-                                       op.inputs[1].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchMatrixSolve")
-@ops.RegisterShape("BatchMatrixTriangularSolve")
-def _BatchSquareMatrixSolveShape(op):
-  """Shape function for batch square matrix solver ops."""
-  return _SquareMatrixSolveShapeHelper(
-      op.inputs[0].get_shape().with_rank_at_least(2),
-      op.inputs[1].get_shape().with_rank_at_least(2))
-
-
-def _MatrixSolveLsShapeHelper(lhs_shape, rhs_shape):
-  """Shape inference helper function for least squares matrix solver ops."""
-  # The matrices and right-hand sides must have the same number of rows.
-  lhs_shape[-2].assert_is_compatible_with(rhs_shape[-2])
-  return [lhs_shape[:-2].concatenate([lhs_shape[-1], rhs_shape[-1]])]
-
-
-@ops.RegisterShape("MatrixSolveLs")
-def _MatrixSolveLsShape(op):
-  """Shape function for least-squares matrix solver op."""
-  return _MatrixSolveLsShapeHelper(op.inputs[0].get_shape().with_rank(2),
-                                   op.inputs[1].get_shape().with_rank(2))
-
-
-@ops.RegisterShape("BatchMatrixSolveLs")
-def _BatchMatrixSolveLsShape(op):
-  """Shape function for batch least-squares matrix solver op."""
-  return _MatrixSolveLsShapeHelper(
-      op.inputs[0].get_shape().with_rank_at_least(2),
-      op.inputs[1].get_shape().with_rank_at_least(2))
+ops.RegisterShape("Cholesky")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("CholeskyGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixInverse")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchCholesky")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchCholeskyGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixInverse")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixDeterminant")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixDeterminant")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SelfAdjointEig")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchSelfAdjointEig")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SelfAdjointEigV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchSelfAdjointEigV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Svd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchSvd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixTriangularSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixTriangularSolve")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MatrixSolveLs")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchMatrixSolveLs")(common_shapes.call_cpp_shape_fn)
 
 
 # Names below are lower_case.
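The linalg change above replaces each handwritten Python shape function with common_shapes.call_cpp_shape_fn, which delegates to the C++ shape function declared via the op's REGISTER_OP(...).SetShapeFn(...). A minimal sketch of the unchanged graph-construction behavior, assuming a TF 0.x-era graph API (the placeholder names below are illustrative, not part of the commit):

    # Sketch, assuming a TF 0.x-era graph API: the square-matrix check that
    # _UnchangedSquareHelper used to perform in Python still fires at graph
    # construction, now from the C++ shape function for Cholesky.
    import tensorflow as tf

    square = tf.placeholder(tf.float32, shape=[4, 4])
    print(tf.cholesky(square).get_shape())  # (4, 4): shape passes through unchanged

    non_square = tf.placeholder(tf.float32, shape=[3, 4])
    try:
      tf.cholesky(non_square)  # non-square input: shape inference rejects it
    except ValueError as e:
      print("caught:", e)
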
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index f5f5aedc01..e14ca1a559 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -731,25 +731,10 @@ def sparse_softmax_cross_entropy_with_logits(logits, labels, name=None):
     return cost
 
 
-@ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")
-def _SparseSoftmaxCrossEntropyWithLogitsShape(op):
-  """Shape function for SparseSoftmaxCrossEntropyWithLogits op."""
-  logits_shape = op.inputs[0].get_shape()
-  input_shape = logits_shape.with_rank(2)
-  batch_size = input_shape[0]
-  # labels_shape
-  op.inputs[1].get_shape().merge_with(tensor_shape.vector(batch_size))
-  return [tensor_shape.vector(batch_size.value), input_shape]
-
-
-@ops.RegisterShape("SoftmaxCrossEntropyWithLogits")
-def _SoftmaxCrossEntropyWithLogitsShape(op):
-  """Shape function for SoftmaxCrossEntropyWithLogits op."""
-  logits_shape = op.inputs[0].get_shape()
-  labels_shape = op.inputs[1].get_shape()
-  input_shape = logits_shape.merge_with(labels_shape).with_rank(2)
-  batch_size = input_shape[0]
-  return [tensor_shape.vector(batch_size.value), input_shape]
+ops.RegisterShape("SparseSoftmaxCrossEntropyWithLogits")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SoftmaxCrossEntropyWithLogits")(
+    common_shapes.call_cpp_shape_fn)
 
 
 def avg_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
@@ -812,58 +797,22 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
                         name=name)
 
 
-ops.RegisterShape("Relu")(common_shapes.unchanged_shape)
-ops.RegisterShape("Relu6")(common_shapes.unchanged_shape)
-ops.RegisterShape("Elu")(common_shapes.unchanged_shape)
-ops.RegisterShape("Softplus")(common_shapes.unchanged_shape)
-ops.RegisterShape("Softsign")(common_shapes.unchanged_shape)
-
-
-@ops.RegisterShape("ReluGrad")
-@ops.RegisterShape("Relu6Grad")
-@ops.RegisterShape("EluGrad")
-@ops.RegisterShape("SoftplusGrad")
-@ops.RegisterShape("SoftsignGrad")
-def _BinaryElementwiseShape(op):
-  """Returns same shape as both inputs to op.
-
-  Args:
-    op: Input operation.
-
-  Returns:
-    Shape of both inputs to `op`.
-  """
-  return [op.inputs[0].get_shape().merge_with(op.inputs[1].get_shape())]
-
-
-ops.RegisterShape("L2Loss")(common_shapes.scalar_shape)
-
-ops.RegisterShape("LRN")(common_shapes.unchanged_shape_with_rank(4))
-
-
-@ops.RegisterShape("LRNGrad")
-def _LRNGradShape(op):
-  """Shape function for LRNGrad op."""
-  in_grads_shape = op.inputs[0].get_shape().with_rank(4)
-  in_image_shape = op.inputs[1].get_shape().with_rank(4)
-  out_image_shape = op.inputs[2].get_shape().with_rank(4)
-  return [in_grads_shape.merge_with(in_image_shape).merge_with(out_image_shape)]
-
-
-ops.RegisterShape("Softmax")(common_shapes.unchanged_shape_with_rank_at_least(
-    1))
-
-ops.RegisterShape("LogSoftmax")(
-    common_shapes.unchanged_shape_with_rank_at_least(1))
-
-
-@ops.RegisterShape("InTopK")
-def _InTopKShape(op):
-  """Shape function for InTopK op."""
-  predictions_shape = op.inputs[0].get_shape().with_rank(2)
-  targets_shape = op.inputs[1].get_shape().with_rank(1)
-  batch_size = predictions_shape[0].merge_with(targets_shape[0])
-  return [tensor_shape.vector(batch_size.value)]
+ops.RegisterShape("Relu")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Relu6")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Elu")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Softplus")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Softsign")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("ReluGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Relu6Grad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("EluGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SoftplusGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("SoftsignGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("L2Loss")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("LRN")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterShape("TopK")
@@ -883,36 +832,10 @@ def _TopKShape(op):
   return [output_shape, output_shape]
 
 
-@ops.RegisterShape("BatchNormWithGlobalNormalization")
-def _BatchNormShape(op):
-  """Shape function for BatchNormWithGlobalNormalization op."""
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  mean_shape = op.inputs[1].get_shape().with_rank(1)
-  var_shape = op.inputs[2].get_shape().with_rank(1)
-  beta_shape = op.inputs[3].get_shape().with_rank(1)
-  gamma_shape = op.inputs[4].get_shape().with_rank(1)
-  mean_shape[0].merge_with(input_shape[3])
-  var_shape[0].merge_with(input_shape[3])
-  beta_shape[0].merge_with(input_shape[3])
-  gamma_shape[0].merge_with(input_shape[3])
-  return [input_shape]
-
-
-@ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")
-def _BatchNormGradShape(op):
-  """Shape function for BatchNormWithGlobalNormalizationGrad op."""
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  mean_shape = op.inputs[1].get_shape().with_rank(1)
-  var_shape = op.inputs[2].get_shape().with_rank(1)
-  beta_shape = op.inputs[3].get_shape().with_rank(1)
-  out_backprop_shape = op.inputs[4].get_shape().with_rank(4)
-  input_shape = input_shape.merge_with(out_backprop_shape)
-  vector_dim = input_shape[3]
-  vector_dim = vector_dim.merge_with(mean_shape[0])
-  vector_dim = vector_dim.merge_with(var_shape[0])
-  vector_dim = vector_dim.merge_with(beta_shape[0])
-  return [input_shape] + ([tensor_shape.vector(vector_dim)] * 4)
-
+ops.RegisterShape("BatchNormWithGlobalNormalization")(
+    common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BatchNormWithGlobalNormalizationGrad")(
+    common_shapes.call_cpp_shape_fn)
 
 
 ops.RegisterShape("Conv2D")(common_shapes.conv2d_shape)
 ops.RegisterShape("DepthwiseConv2dNative")(
     common_shapes.depthwise_conv2d_native_shape)
@@ -1022,12 +945,8 @@ def _DepthwiseConv2dNativeBackpropInputShape(op):
     return [tensor_shape.unknown_shape(ndims=4)]
 
 
-@ops.RegisterShape("MaxPoolGrad")
-@ops.RegisterShape("MaxPoolGradWithArgmax")
-def _MaxPoolGradShape(op):
-  """Shape function for the MaxPoolGrad op."""
-  orig_input_shape = op.inputs[0].get_shape().with_rank(4)
-  return [orig_input_shape]
+ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPoolGradWithArgmax")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterStatistics("Conv2D", "flops")
@@ -1150,46 +1069,12 @@ def _Pool3DShape(op):
                                   channels])]
 
 
-@ops.RegisterShape("Conv3DBackpropFilter")
-def _Conv3DBackpropFilterShape(op):
-  """Shape function for the Conv3DBackpropFilter op."""
-  filter_shape = op.inputs[1].get_shape()
-  return [filter_shape.with_rank(5)]
-
-
-@ops.RegisterShape("Conv3DBackpropInput")
-def _Conv3DBackpropInputShape(op):
-  """Shape function for the Conv3DBackpropInput op."""
-  input_shape = op.inputs[0].get_shape()
-  return [input_shape.with_rank(5)]
-
-
-@ops.RegisterShape("Conv3DBackpropFilterV2")
-def _Conv3DBackpropFilterShapeV2(op):
-  """Shape function for the Conv3DBackpropFilterV2 op."""
-  filter_shape = tensor_util.constant_value(op.inputs[1])
-  return [tensor_shape.TensorShape(filter_shape).with_rank(5)]
-
-
-@ops.RegisterShape("Conv3DBackpropInputV2")
-def _Conv3DBackpropInputShapeV2(op):
-  """Shape function for the Conv3DBackpropInputV2 op."""
-  input_shape = tensor_util.constant_value(op.inputs[0])
-  return [tensor_shape.TensorShape(input_shape).with_rank(5)]
-
-
-@ops.RegisterShape("AvgPool3DGrad")
-def _AvgPool3DGradShape(op):
-  """Shape function for the AvgPool3DGrad op."""
-  orig_input_shape = tensor_util.constant_value(op.inputs[0])
-  return [tensor_shape.TensorShape(orig_input_shape).with_rank(5)]
-
-
-@ops.RegisterShape("MaxPool3DGrad")
-def _MaxPool3DGradShape(op):
-  """Shape function for the MaxPoolGrad op."""
-  orig_input_shape = op.inputs[0].get_shape().with_rank(5)
-  return [orig_input_shape]
+ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Conv3DBackpropInputV2")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("AvgPool3DGrad")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPool3DGrad")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterStatistics("BiasAdd", "flops")
@@ -1431,16 +1316,8 @@ def _Dilation2DShape(op):
   return [tensor_shape.TensorShape(output_shape)]
 
 
-@ops.RegisterShape("Dilation2DBackpropInput")
-def _Dilation2DBackpropInputShape(op):
-  """Shape function for Dilation2DBackpropInput op."""
-  return [op.inputs[0].get_shape()]
-
-
-@ops.RegisterShape("Dilation2DBackpropFilter")
-def _Dilation2DBackpropFilterShape(op):
-  """Shape function for Dilation2DBackpropFilter op."""
-  return [op.inputs[1].get_shape()]
+ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterStatistics("Dilation2D", "flops")