#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/port.h"
#include "tensorflow/core/public/tensor.h"
#include "tensorflow/core/public/tensor_shape.h"
#include "tensorflow/core/util/work_sharder.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
namespace tensorflow {
// A base class to support linear algebra functionality, similar to the
// numpy.linalg module. Supports batch computation on several matrices at once,
// sharding the computations across different threads if necessary.
//
// TODO(kalakris): This needs to be expanded to support binary inputs, and
// multiple outputs.
class LinearAlgebraOpBase : public OpKernel {
public:
explicit LinearAlgebraOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
~LinearAlgebraOpBase() override {}
// Return the expected rank of the input.
// Currently fixed at 2 (matrix inputs only).
// TODO(kalakris): This should be a virtual function to support vector inputs.
int GetInputMatrixRank() { return 2; }
// Return the output shape of each individual matrix operation. Must be
// rank 0, 1, or 2. Scalar outputs are rank 0.
virtual TensorShape GetOutputMatrixShape(
const TensorShape& input_matrix_shape) = 0;
// Return the cost per matrix operation. Cost per unit is assumed to be
// roughly 1ns, based on comments in core/util/work_sharder.cc.
// Used by Compute() to decide how to shard batched work across threads.
virtual int64 GetCostPerUnit(const TensorShape& input_matrix_shape) = 0;
// If SupportsBatchOperation() returns false, this Op will only accept rank 2
// (if the supported input type is a matrix). If it returns true, the Op will
// accept inputs of rank >= 3, and repeatedly execute the operation on all
// matrices in the innermost two dimensions.
virtual bool SupportsBatchOperation() = 0;
// Perform the actual computation on an input matrix, and store the results
// in the output. This will be called repeatedly for a single call to
// Compute(), if multiple matrices exist in the input Tensor.
//
// This function should only compute the results for a single input matrix.
// The 'matrix_index' parameter specifies the index of the matrix to be used
// from the input, and the index of the matrix to be written to in the output.
// The input matrix is in row major order, and is located at the memory
// address
// in.flat<Scalar>().data() +
// matrix_index * input_matrix_shape.num_elements().
// The output matrix is in row major order, and is located at the memory
// address
// out->flat<Scalar>().data() +
// matrix_index * output_matrix_shape.num_elements().
// The LinearAlgebraOp<Scalar> class below has functionality which performs
// this mapping and presents an interface based on the Eigen::MatrixBase API.
virtual void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
const Tensor& in,
const TensorShape& input_matrix_shape, Tensor* out,
const TensorShape& output_matrix_shape) = 0;
// OpKernel entry point: validates input rank, allocates the output, and
// dispatches ComputeMatrix() for each matrix, possibly sharded across
// threads (implementation lives in linalg_ops_common.cc).
void Compute(OpKernelContext* context) override;
};
// A base class for linear algebra ops templated on the scalar type.
//
// This base class encapsulates the functionality of mapping the input and
// output tensors using Eigen::Map, so that the Eigen::MatrixBase API may be
// directly used by derived classes.
// SupportsBatchOperationT is a bool template argument which if set to true
// will allow the Op to process batches of matrices (rank >= 3); if set to
// false the Op will only accept rank 2 inputs.
template <typename Scalar, bool SupportsBatchOperationT>
class LinearAlgebraOp : public LinearAlgebraOpBase {
public:
explicit LinearAlgebraOp(OpKernelConstruction* context)
: LinearAlgebraOpBase(context) {}
// Read-only Eigen view over a row-major matrix stored in a Tensor buffer.
using ConstMatrixMap =
Eigen::Map<const Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>>;
// Mutable Eigen view over a row-major matrix stored in a Tensor buffer.
using MatrixMap = Eigen::Map<
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
// Perform the actual computation on the input matrix, and store the results
// in the output. This will be called repeatedly for a single call to
// Compute(), if multiple matrices exist in the input Tensor.
// Derived classes implement their linear-algebra operation here in terms of
// the Eigen::MatrixBase API, without dealing with Tensor buffer offsets.
virtual void ComputeMatrix(OpKernelContext* context,
const ConstMatrixMap& input,
MatrixMap* output) = 0;
// Batch support is fixed at compile time by the template argument.
bool SupportsBatchOperation() final { return SupportsBatchOperationT; }
// A concrete implementation of LinearAlgebraOpBase::ComputeMatrix().
// Maps the matrix at 'matrix_index' in the input/output Tensors to Eigen
// views and forwards to the Eigen-based ComputeMatrix() overload above.
void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
const Tensor& in, const TensorShape& input_matrix_shape,
Tensor* out, const TensorShape& output_matrix_shape) final;
};
// Declare that LinearAlgebraOp is explicitly instantiated in
// linalg_ops_common.cc for float and double.
extern template class LinearAlgebraOp<float, false>;
extern template class LinearAlgebraOp<float, true>;
extern template class LinearAlgebraOp<double, false>;
extern template class LinearAlgebraOp<double, true>;
} // namespace tensorflow
// Registers OpClass as the CPU kernel for op 'OpName', constrained to the
// given Scalar type via the "T" type attribute.
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_KERNEL_BUILDER( \
Name(OpName).Device(DEVICE_CPU).TypeConstraint<Scalar>("T"), OpClass)
#endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_