diff options
author | 2017-10-24 10:39:16 -0700 | |
---|---|---|
committer | 2017-10-24 10:42:59 -0700 | |
commit | 1bbec9e4e9c5d3fbbc2fa2b58841435e86dbf76a (patch) | |
tree | bd1f15556bd8157f6e012c12d5477851eb827c9a /tensorflow/core/kernels/determinant_op.h | |
parent | 377dd3d0d51f93f22eadfd18f4186c27d8506d69 (diff) |
* Add GPU implementation of LogDeterminant op.
* Switch GPU implementation of Determinant to use the more numerically stable kernel as well.
* Change behavior for Determinant on matrices with (numerically) infinite determinants to match the behavior of numpy.linalg.det: Return inf for matrix with infinite determinant.
* Misc. cleanup in code working around missing support for complex in the NVCC compiler.
PiperOrigin-RevId: 173277377
Diffstat (limited to 'tensorflow/core/kernels/determinant_op.h')
-rw-r--r-- | tensorflow/core/kernels/determinant_op.h | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/determinant_op.h b/tensorflow/core/kernels/determinant_op.h new file mode 100644 index 0000000000..e931e328e4 --- /dev/null +++ b/tensorflow/core/kernels/determinant_op.h @@ -0,0 +1,47 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ +#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ + +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace functor { + +// Helper functor to compute Determinant from a partially pivoted LU +// factorization. +template <typename Device, typename Scalar> +struct DeterminantFromPivotedLUFunctor { + void operator()(const Device& device, + typename TTypes<Scalar, 3>::ConstTensor lu_factor, + const int* pivots, typename TTypes<Scalar, 1>::Tensor output, + int* info); +}; + +// Helper functor to compute sign and log of the absolute value of the +// determinant from a partially pivoted LU factorization. +template <typename Device, typename Scalar> +struct LogDeterminantFromPivotedLUFunctor { + void operator()(const Device& device, + typename TTypes<Scalar, 3>::ConstTensor lu_factor, + const int* pivots, typename TTypes<Scalar, 1>::Tensor sign, + typename TTypes<Scalar, 1>::Tensor log_abs_det); +}; + +} // namespace functor +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_DETERMINANT_OP_H_ |