aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/linalg_ops_common.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/linalg_ops_common.cc')
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.cc46
1 files changed, 22 insertions, 24 deletions
diff --git a/tensorflow/core/kernels/linalg_ops_common.cc b/tensorflow/core/kernels/linalg_ops_common.cc
index 93342a7a24..df8047564b 100644
--- a/tensorflow/core/kernels/linalg_ops_common.cc
+++ b/tensorflow/core/kernels/linalg_ops_common.cc
@@ -2,29 +2,23 @@
namespace tensorflow {
-void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
+void UnaryLinearAlgebraOpBase::Compute(OpKernelContext* context) {
const Tensor& in = context->input(0);
- const int input_rank = GetInputMatrixRank();
- OP_REQUIRES(
- context, input_rank == 2,
- errors::InvalidArgument("Only matrix inputs are supported so far."));
+ const int input_rank = in.dims();
if (SupportsBatchOperation()) {
- OP_REQUIRES(context, in.dims() > input_rank,
- errors::InvalidArgument("Input tensor must have rank >= %d",
- input_rank + 1));
+ OP_REQUIRES(context, input_rank >= 2,
+ errors::InvalidArgument("Input tensor must have rank >= 2"));
} else {
- OP_REQUIRES(context, in.dims() == input_rank,
- errors::InvalidArgument("Input tensor must have rank == %d",
- input_rank));
+ OP_REQUIRES(context, input_rank == 2,
+ errors::InvalidArgument("Input tensor must have rank == 2"));
}
// If the tensor rank is greater than input_rank, we consider the inner-most
// dimensions as matrices, and loop over all the other outer
// dimensions to compute the results.
- // TODO(kalakris): Only matrix inputs are currently supported.
- const int row_dimension = in.dims() - 2;
- const int col_dimension = in.dims() - 1;
+ const int row_dimension = input_rank - 2;
+ const int col_dimension = input_rank - 1;
const int64 num_rows = in.dim_size(row_dimension);
const int64 num_cols = in.dim_size(col_dimension);
const TensorShape input_matrix_shape = TensorShape({num_rows, num_cols});
@@ -36,16 +30,19 @@ void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
int num_matrices = 1;
// The output has the shape of all the outer dimensions of the input
// except for the last two, plus the output_matrix_shape (if the output
- // is not scalar). This still assumes that each input matrix is
- // 2-dimensional, in accordance with the TODO above.
+ // is not scalar). This assumes that each input matrix is
+ // 2-dimensional.
TensorShape output_shape;
- if (in.dims() == 2) {
+ if (input_rank == 2) {
output_shape = output_matrix_shape;
} else {
- for (int dim = 0; dim <= in.dims() - 3; ++dim) {
+ // Add the common outer dimensions.
+ for (int dim = 0; dim < input_rank - 2; ++dim) {
num_matrices *= in.dim_size(dim);
output_shape.AddDim(in.dim_size(dim));
}
+ // Add the inner dimensions that depend on the operation implemented by the
+ // derived class.
for (int dim = 0; dim < output_matrix_shape.dims(); ++dim) {
output_shape.AddDim(output_matrix_shape.dim_size(dim));
}
@@ -68,7 +65,7 @@ void LinearAlgebraOpBase::Compute(OpKernelContext* context) {
}
template <typename Scalar, bool SupportsBatchOperationT>
-void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
+void UnaryLinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
OpKernelContext* context, int64 matrix_index, const Tensor& in,
const TensorShape& input_matrix_shape, Tensor* out,
const TensorShape& output_matrix_shape) {
@@ -90,10 +87,11 @@ void LinearAlgebraOp<Scalar, SupportsBatchOperationT>::ComputeMatrix(
ComputeMatrix(context, input, &output);
}
-// Explicitly instantiate LinearAlgebraOp for the scalar types we expect to use.
-template class LinearAlgebraOp<float, false>;
-template class LinearAlgebraOp<float, true>;
-template class LinearAlgebraOp<double, false>;
-template class LinearAlgebraOp<double, true>;
+// Explicitly instantiate UnaryLinearAlgebraOp for the scalar types we expect to
+// use.
+template class UnaryLinearAlgebraOp<float, false>;
+template class UnaryLinearAlgebraOp<float, true>;
+template class UnaryLinearAlgebraOp<double, false>;
+template class UnaryLinearAlgebraOp<double, true>;
} // namespace tensorflow