diff options
Diffstat (limited to 'tensorflow/core/kernels/linalg_ops_common.cc')
-rw-r--r-- | tensorflow/core/kernels/linalg_ops_common.cc | 46 |
1 files changed, 22 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc index 93342a7a24..df8047564b 100644 --- a/tensorflow/core/kernels/linalg_ops_common.cc +++ b/tensorflow/core/kernels/linalg_ops_common.cc @@ -2,29 +2,23 @@ namespace tensorflow { -void LinearAlgebraOpBase::Compute(OpKernelContext* context) { +void UnaryLinearAlgebraOpBase::Compute(OpKernelContext* context) { const Tensor& in = context->input(0); - const int input_rank = GetInputMatrixRank(); - OP_REQUIRES( - context, input_rank == 2, - errors::InvalidArgument("Only matrix inputs are supported so far.")); + const int input_rank = in.dims(); if (SupportsBatchOperation()) { - OP_REQUIRES(context, in.dims() > input_rank, - errors::InvalidArgument("Input tensor must have rank >= %d", - input_rank + 1)); + OP_REQUIRES(context, input_rank >= 2, + errors::InvalidArgument("Input tensor must have rank >= 2")); } else { - OP_REQUIRES(context, in.dims() == input_rank, - errors::InvalidArgument("Input tensor must have rank == %d", - input_rank)); + OP_REQUIRES(context, input_rank == 2, + errors::InvalidArgument("Input tensor must have rank == 2")); } // If the tensor rank is greater than input_rank, we consider the inner-most // dimensions as matrices, and loop over all the other outer // dimensions to compute the results. - // TODO(kalakris): Only matrix inputs are currently supported. - const int row_dimension = in.dims() - 2; - const int col_dimension = in.dims() - 1; + const int row_dimension = input_rank - 2; + const int col_dimension = input_rank - 1; const int64 num_rows = in.dim_size(row_dimension); const int64 num_cols = in.dim_size(col_dimension); const TensorShape input_matrix_shape = TensorShape({num_rows, num_cols}); @@ -36,16 +30,19 @@ void LinearAlgebraOpBase::Compute(OpKernelContext* context) { int num_matrices = 1; // The output has the shape of all the outer dimensions of the input // except for the last two, plus the output_matrix_shape (if the output - // is not scalar). This still assumes that each input matrix is - // 2-dimensional, in accordance with the TODO above. + // is not scalar). This assumes that each input matrix is + // 2-dimensional. TensorShape output_shape; - if (in.dims() == 2) { + if (input_rank == 2) { output_shape = output_matrix_shape; } else { - for (int dim = 0; dim <= in.dims() - 3; ++dim) { + // Add the common outer dimensions. + for (int dim = 0; dim < input_rank - 2; ++dim) { num_matrices *= in.dim_size(dim); output_shape.AddDim(in.dim_size(dim)); } + // Add the inner dimensions that depend on the operation implemented by the + // derived class. for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) { output_shape.AddDim(output_matrix_shape.dim_size(dim)); } @@ -68,7 +65,7 @@ void LinearAlgebraOpBase::Compute(OpKernelContext* context) { } template <typename Scalar, bool SupportsBatchOperationT> -void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix( +void UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix( OpKernelContext* context, int64 matrix_index, const Tensor& in, const TensorShape& input_matrix_shape, Tensor* out, const TensorShape& output_matrix_shape) { @@ -90,10 +87,11 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix( ComputeMatrix(context, input, &output); } -// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use. -template class LinearAlgebraOp<float, false>; -template class LinearAlgebraOp<float, true>; -template class LinearAlgebraOp<double, false>; -template class LinearAlgebraOp<double, true>; +// Explicitly instantiate UnaryLinearAlgebraOp for the scalar types we expect to +// use. +template class UnaryLinearAlgebraOp<float, false>; +template class UnaryLinearAlgebraOp<float, true>; +template class UnaryLinearAlgebraOp<double, false>; +template class UnaryLinearAlgebraOp<double, true>; } // namespace tensorflow |