aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/linalg_ops_common.h
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-11-18 10:47:35 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-11-18 10:47:35 -0800
commitab34d55ce7618e52069a2e1c9e51aac5a1ea81c3 (patch)
tree9c79427b45ff6501e8374ceb7b4fc3bdb2828e15 /tensorflow/core/kernels/linalg_ops_common.h
parent9eb88d56ab6a9a361662d73a258593d8fbf10b62 (diff)
TensorFlow: more features, performance improvements, and doc fixes.
Changes: - Add Split/Concat() methods to TensorUtil (meant for convenience, not speed) by Chris. - Changes to linear algebra ops interface by Rasmus - Tests for tensorboard by Daniel - Fix bug in histogram calculation by Cassandra - Added tool for backwards compatibility of OpDefs. Tool Checks in history of opdefs and their changes, checks for backwards-incompatible changes. All done by @josh11b - Fix some protobuf example proto docs by Oliver - Add derivative of MatrixDeterminant by @yaroslavvb - Add a priority queue queue by @ebrevdo - Doc and typo fixes by Aurelien and @dave-andersen - Speed improvements to ConvBackwardFilter by @andydavis - Improve speed of Alexnet on TitanX by @zheng-xq - Add some host memory annotations to some GPU kernels by Yuan. - Add support for doubles in histogram summary by @jmchen-g Base CL: 108158338
Diffstat (limited to 'tensorflow/core/kernels/linalg_ops_common.h')
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.h44
1 files changed, 19 insertions, 25 deletions
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index 80d485a8aa..b0ec94d90b 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -1,6 +1,10 @@
#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+// Classes to support linear algebra functionality, similar to the numpy.linalg
+// module. Supports batch computation on several matrices at once, sharding the
+// computations across different threads if necessary.
+
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -16,21 +20,12 @@
namespace tensorflow {
-// A base class to support linear algebra functionality, similar to the
-// numpy.linalg module. Supports batch computation on several matrices at once,
-// sharding the computations across different threads if necessary.
-//
-// TODO(kalakris): This needs to be expanded to support binary inputs, and
-// multiple outputs.
-class LinearAlgebraOpBase : public OpKernel {
+// Base class for unary linear algebra operators.
+class UnaryLinearAlgebraOpBase : public OpKernel {
public:
- explicit LinearAlgebraOpBase(OpKernelConstruction* context)
+ explicit UnaryLinearAlgebraOpBase(OpKernelConstruction* context)
: OpKernel(context) {}
- ~LinearAlgebraOpBase() override {}
-
- // Return the expected rank of the input.
- // TODO(kalakris): This should be a virtual function to support vector inputs.
- int GetInputMatrixRank() { return 2; }
+ ~UnaryLinearAlgebraOpBase() override {}
// Return the output shape of each individual matrix operation. Must be
// rank 0, 1, or 2. Scalar outputs are rank 0.
@@ -62,7 +57,8 @@ class LinearAlgebraOpBase : public OpKernel {
// address
// out->flat<Scalar>().data() +
// matrix_index * output_matrix_shape.num_elements().
- // The LinearAlgebraOp<Scalar> class below has functionality which performs
+ // The UnaryLinearAlgebraOp<Scalar> class below has functionality which
+ // performs
// this mapping and presents an interface based on the Eigen::MatrixBase API.
virtual void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
const Tensor& in,
@@ -72,8 +68,6 @@ class LinearAlgebraOpBase : public OpKernel {
void Compute(OpKernelContext* context) override;
};
-// A base class for linear algebra ops templated on the scalar type.
-//
// This base class encapsulates the functionality of mapping the input and
// output tensors using Eigen::Map, so that the Eigen::MatrixBase API may be
// directly used by derived classes.
@@ -81,10 +75,10 @@ class LinearAlgebraOpBase : public OpKernel {
// will allow the Op to process batches of matrices (rank >= 3); if set to
// false the Op will only accept rank 2 inputs.
template <typename Scalar, bool SupportsBatchOperationT>
-class LinearAlgebraOp : public LinearAlgebraOpBase {
+class UnaryLinearAlgebraOp : public UnaryLinearAlgebraOpBase {
public:
- explicit LinearAlgebraOp(OpKernelConstruction* context)
- : LinearAlgebraOpBase(context) {}
+ explicit UnaryLinearAlgebraOp(OpKernelConstruction* context)
+ : UnaryLinearAlgebraOpBase(context) {}
using Matrix =
Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
@@ -100,18 +94,18 @@ class LinearAlgebraOp : public LinearAlgebraOpBase {
bool SupportsBatchOperation() final { return SupportsBatchOperationT; }
- // A concrete implementation of LinearAlgebraOpBase::ComputeMatrix().
+ // A concrete implementation of UnaryLinearAlgebraOpBase::ComputeMatrix().
void ComputeMatrix(OpKernelContext* context, int64 matrix_index,
const Tensor& in, const TensorShape& input_matrix_shape,
Tensor* out, const TensorShape& output_matrix_shape) final;
};
-// Declare that LinearAlgebraOp is explicitly instantiated in
+// Declare that UnaryLinearAlgebraOp is explicitly instantiated in
// linalg_ops_common.cc for float and double.
-extern template class LinearAlgebraOp<float, false>;
-extern template class LinearAlgebraOp<float, true>;
-extern template class LinearAlgebraOp<double, false>;
-extern template class LinearAlgebraOp<double, true>;
+extern template class UnaryLinearAlgebraOp<float, false>;
+extern template class UnaryLinearAlgebraOp<float, true>;
+extern template class UnaryLinearAlgebraOp<double, false>;
+extern template class UnaryLinearAlgebraOp<double, true>;
} // namespace tensorflow