aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/kernels/BUILD7
-rw-r--r--tensorflow/core/kernels/matrix_exponential_op.cc59
-rw-r--r--tensorflow/core/ops/linalg_ops.cc27
-rw-r--r--tensorflow/python/kernel_tests/BUILD12
-rw-r--r--tensorflow/python/kernel_tests/matrix_exponential_op_test.py196
-rw-r--r--tensorflow/python/ops/hidden_ops.txt1
-rw-r--r--tensorflow/python/ops/linalg/linalg_impl.py1
-rw-r--r--tensorflow/python/ops/math_ops.py1
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.pbtxt4
-rw-r--r--third_party/eigen.BUILD1
-rw-r--r--third_party/eigen3/BUILD1
-rw-r--r--third_party/eigen3/unsupported/Eigen/MatrixFunctions1
12 files changed, 311 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1cb7c97be4..34cd51ba66 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2255,6 +2255,7 @@ cc_library(
":cholesky_grad",
":cholesky_op",
":determinant_op",
+ ":matrix_exponential_op",
":matrix_inverse_op",
":matrix_solve_ls_op",
":matrix_solve_op",
@@ -2322,6 +2323,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "matrix_exponential_op",
+ prefix = "matrix_exponential_op",
+ deps = LINALG_DEPS,
+)
+
+tf_kernel_library(
name = "self_adjoint_eig_op",
prefix = "self_adjoint_eig_op",
deps = LINALG_DEPS,
diff --git a/tensorflow/core/kernels/matrix_exponential_op.cc b/tensorflow/core/kernels/matrix_exponential_op.cc
new file mode 100644
index 0000000000..4cc3f32f7e
--- /dev/null
+++ b/tensorflow/core/kernels/matrix_exponential_op.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/linalg_ops.cc.
+
+#include "third_party/eigen3/Eigen/Core"
+#include "third_party/eigen3/unsupported/Eigen/MatrixFunctions"
+#include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/linalg_ops_common.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
+
+
+namespace tensorflow {
+
+template <class Scalar>
+class MatrixExponentialOp : public LinearAlgebraOp<Scalar> {
+ public:
+ INHERIT_LINALG_TYPEDEFS(Scalar);
+
+ explicit MatrixExponentialOp(OpKernelConstruction* context) : Base(context) {}
+
+ void ComputeMatrix(OpKernelContext* context, const ConstMatrixMaps& inputs,
+ MatrixMaps* outputs) final {
+ const ConstMatrixMap& input = inputs[0];
+ if (input.rows() == 0) return;
+ using Matrix = Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+ Matrix tmp = input;
+ outputs->at(0) = tmp.exp();
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(MatrixExponentialOp);
+};
+
+REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<float>), float);
+REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<double>), double);
+REGISTER_LINALG_OP("MatrixExponential",
+ (MatrixExponentialOp<complex64>), complex64);
+REGISTER_LINALG_OP("MatrixExponential",
+ (MatrixExponentialOp<complex128>), complex128);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 4851619f83..53e2360d23 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -282,6 +282,33 @@ Equivalent to np.linalg.inv
@end_compatibility
)doc");
+REGISTER_OP("MatrixExponential")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("T: {double, float, complex64, complex128}")
+ .SetShapeFn(BatchUnchangedSquareShapeFn)
+ .Doc(R"doc(
+Computes the matrix exponential of one or more square matrices:
+
+exp(A) = \sum_{n=0}^\infty A^n/n!
+
+The exponential is computed using a combination of the scaling and squaring
+method and the Pade approximation. Details can be founds in:
+Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
+revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor of the same shape as the input
+containing the exponential for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+
+@compatibility(scipy)
+Equivalent to scipy.linalg.expm
+@end_compatibility
+)doc");
+
REGISTER_OP("Cholesky")
.Input("input: T")
.Output("output: T")
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index edfeeae8b4..7fa504e85e 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -505,6 +505,18 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "matrix_exponential_op_test",
+ size = "small",
+ srcs = ["matrix_exponential_op_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:linalg_ops",
+ ],
+)
+
cuda_py_test(
name = "matrix_inverse_op_test",
size = "small",
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
new file mode 100644
index 0000000000..c5a7a3ba99
--- /dev/null
+++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py
@@ -0,0 +1,196 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.ops.gen_linalg_ops.matrix_exponential."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import math
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_linalg_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def np_expm(x):
+ """Slow but accurate Taylor series matrix exponential."""
+ y = np.zeros(x.shape, dtype=x.dtype)
+ xn = np.eye(x.shape[0], dtype=x.dtype)
+ for n in range(40):
+ y += xn / float(math.factorial(n))
+ xn = np.dot(xn, x)
+ return y
+
+
+class ExponentialOpTest(test.TestCase):
+
+ def _verifyExponential(self, x, np_type):
+ # TODO(pfau): add matrix logarithm and test that it is inverse of expm.
+ inp = x.astype(np_type)
+ with self.test_session(use_gpu=True):
+ # Verify that x^{-1} * x == Identity matrix.
+ tf_ans = gen_linalg_ops._matrix_exponential(inp)
+ if x.size == 0:
+ np_ans = np.empty(x.shape, dtype=np_type)
+ else:
+ if x.ndim > 2:
+ np_ans = np.zeros(inp.shape, dtype=np_type)
+ for i in itertools.product(*[range(x) for x in inp.shape[:-2]]):
+ np_ans[i] = np_expm(inp[i])
+ else:
+ np_ans = np_expm(inp)
+ out = tf_ans.eval()
+ self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-3)
+
+ def _verifyExponentialReal(self, x):
+ for np_type in [np.float32, np.float64]:
+ self._verifyExponential(x, np_type)
+
+ def _verifyExponentialComplex(self, x):
+ for np_type in [np.complex64, np.complex128]:
+ self._verifyExponential(x, np_type)
+
+ def _makeBatch(self, matrix1, matrix2):
+ matrix_batch = np.concatenate(
+ [np.expand_dims(matrix1, 0),
+ np.expand_dims(matrix2, 0)])
+ matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
+ return matrix_batch
+
+ def testNonsymmetric(self):
+ # 2x2 matrices
+ matrix1 = np.array([[1., 2.], [3., 4.]])
+ matrix2 = np.array([[1., 3.], [3., 5.]])
+ self._verifyExponentialReal(matrix1)
+ self._verifyExponentialReal(matrix2)
+ # A multidimensional batch of 2x2 matrices
+ self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
+ # Complex
+ matrix1 = matrix1.astype(np.complex64)
+ matrix1 += 1j * matrix1
+ matrix2 = matrix2.astype(np.complex64)
+ matrix2 += 1j * matrix2
+ self._verifyExponentialComplex(matrix1)
+ self._verifyExponentialComplex(matrix2)
+ # Complex batch
+ self._verifyExponentialComplex(self._makeBatch(matrix1, matrix2))
+
+ def testSymmetricPositiveDefinite(self):
+ # 2x2 matrices
+ matrix1 = np.array([[2., 1.], [1., 2.]])
+ matrix2 = np.array([[3., -1.], [-1., 3.]])
+ self._verifyExponentialReal(matrix1)
+ self._verifyExponentialReal(matrix2)
+ # A multidimensional batch of 2x2 matrices
+ self._verifyExponentialReal(self._makeBatch(matrix1, matrix2))
+ # Complex
+ matrix1 = matrix1.astype(np.complex64)
+ matrix1 += 1j * matrix1
+ matrix2 = matrix2.astype(np.complex64)
+ matrix2 += 1j * matrix2
+ self._verifyExponentialComplex(matrix1)
+ self._verifyExponentialComplex(matrix2)
+ # Complex batch
+ self._verifyExponentialComplex(self._makeBatch(matrix1, matrix2))
+
+ def testNonSquareMatrix(self):
+ # When the exponential of a non-square matrix is attempted we should return
+ # an error
+ with self.assertRaises(ValueError):
+ gen_linalg_ops._matrix_exponential(np.array([[1., 2., 3.], [3., 4., 5.]]))
+
+ def testWrongDimensions(self):
+ # The input to the inverse should be at least a 2-dimensional tensor.
+ tensor3 = constant_op.constant([1., 2.])
+ with self.assertRaises(ValueError):
+ gen_linalg_ops._matrix_exponential(tensor3)
+
+ def testEmpty(self):
+ self._verifyExponentialReal(np.empty([0, 2, 2]))
+ self._verifyExponentialReal(np.empty([2, 0, 0]))
+
+ def testRandomSmallAndLarge(self):
+ np.random.seed(42)
+ for dtype in np.float32, np.float64, np.complex64, np.complex128:
+ for batch_dims in [(), (1,), (3,), (2, 2)]:
+ for size in 8, 31, 32:
+ shape = batch_dims + (size, size)
+ matrix = np.random.uniform(
+ low=-1.0, high=1.0,
+ size=np.prod(shape)).reshape(shape).astype(dtype)
+ self._verifyExponentialReal(matrix)
+
+ def testConcurrentExecutesWithoutError(self):
+ with self.test_session(use_gpu=True) as sess:
+ matrix1 = random_ops.random_normal([5, 5], seed=42)
+ matrix2 = random_ops.random_normal([5, 5], seed=42)
+ expm1 = gen_linalg_ops._matrix_exponential(matrix1)
+ expm2 = gen_linalg_ops._matrix_exponential(matrix2)
+ expm = sess.run([expm1, expm2])
+ self.assertAllEqual(expm[0], expm[1])
+
+
+class MatrixExponentialBenchmark(test.Benchmark):
+
+ shapes = [
+ (4, 4),
+ (10, 10),
+ (16, 16),
+ (101, 101),
+ (256, 256),
+ (1000, 1000),
+ (1024, 1024),
+ (2048, 2048),
+ (513, 4, 4),
+ (513, 16, 16),
+ (513, 256, 256),
+ ]
+
+ def _GenerateMatrix(self, shape):
+ batch_shape = shape[:-2]
+ shape = shape[-2:]
+ assert shape[0] == shape[1]
+ n = shape[0]
+ matrix = np.ones(shape).astype(np.float32) / (
+ 2.0 * n) + np.diag(np.ones(n).astype(np.float32))
+ return variables.Variable(np.tile(matrix, batch_shape + (1, 1)))
+
+ def benchmarkMatrixExponentialOp(self):
+ for shape in self.shapes:
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/cpu:0"):
+ matrix = self._GenerateMatrix(shape)
+ expm = gen_linalg_ops._matrix_exponential(matrix)
+ variables.global_variables_initializer().run()
+ self.run_op_benchmark(
+ sess,
+ control_flow_ops.group(expm),
+ min_iters=25,
+ name="matrix_exponential_cpu_{shape}".format(
+ shape=shape))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt
index 732ab8f15a..a0fff9e16c 100644
--- a/tensorflow/python/ops/hidden_ops.txt
+++ b/tensorflow/python/ops/hidden_ops.txt
@@ -223,6 +223,7 @@ BatchSelfAdjointEig
BatchSelfAdjointEigV2
BatchSvd
LogMatrixDeterminant
+MatrixExponential
MatrixSolveLs
SelfAdjointEig
SelfAdjointEigV2
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index 04a15e3e5b..bf15f0e2e5 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -38,6 +38,7 @@ diag_part = array_ops.matrix_diag_part
eigh = linalg_ops.self_adjoint_eig
eigvalsh = linalg_ops.self_adjoint_eigvals
einsum = special_math_ops.einsum
+expm = gen_linalg_ops._matrix_exponential
eye = linalg_ops.eye
inv = linalg_ops.matrix_inverse
lstsq = linalg_ops.matrix_solve_ls
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index d38abb5eb9..886b2048f9 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -89,6 +89,7 @@ See the @{$python/math_ops} guide.
@@matrix_inverse
@@cholesky
@@cholesky_solve
+@@matrix_exponential
@@matrix_solve
@@matrix_triangular_solve
@@matrix_solve_ls
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 0d62585ff4..9fd38a29b7 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -73,6 +73,10 @@ tf_module {
argspec: "args=[\'equation\'], varargs=inputs, keywords=kwargs, defaults=None"
}
member_method {
+ name: "expm"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
diff --git a/third_party/eigen.BUILD b/third_party/eigen.BUILD
index 0157cbcddf..07bb6645eb 100644
--- a/third_party/eigen.BUILD
+++ b/third_party/eigen.BUILD
@@ -36,6 +36,7 @@ EIGEN_FILES = [
"unsupported/Eigen/src/KroneckerProduct/**",
"unsupported/Eigen/MatrixFunctions",
"unsupported/Eigen/SpecialFunctions",
+ "unsupported/Eigen/src/MatrixFunctions/**",
"unsupported/Eigen/src/SpecialFunctions/**",
]
diff --git a/third_party/eigen3/BUILD b/third_party/eigen3/BUILD
index ad87477b7a..f5f3418527 100644
--- a/third_party/eigen3/BUILD
+++ b/third_party/eigen3/BUILD
@@ -26,6 +26,7 @@ cc_library(
"Eigen/Eigenvalues",
"Eigen/QR",
"Eigen/SVD",
+ "unsupported/Eigen/MatrixFunctions",
"unsupported/Eigen/SpecialFunctions",
"unsupported/Eigen/CXX11/ThreadPool",
"unsupported/Eigen/CXX11/Tensor",
diff --git a/third_party/eigen3/unsupported/Eigen/MatrixFunctions b/third_party/eigen3/unsupported/Eigen/MatrixFunctions
new file mode 100644
index 0000000000..314b325f8c
--- /dev/null
+++ b/third_party/eigen3/unsupported/Eigen/MatrixFunctions
@@ -0,0 +1 @@
+#include "unsupported/Eigen/MatrixFunctions"