diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-06-03 18:22:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-03 18:26:15 -0700 |
commit | 7ffc3573255ecdc27b63c6d7dfefa8345225593d (patch) | |
tree | bdb68064054858fcb7656f2a8b6bc04281464558 /tensorflow/core/kernels/matrix_set_diag_op.h | |
parent | aad2e3daff8fcd29ed8e5071d4c37a7f94a0421c (diff) |
Add support for bools in matrix_diag, matrix_diag_part, matrix_set_diag, matrix_band_part.
PiperOrigin-RevId: 157939272
Diffstat (limited to 'tensorflow/core/kernels/matrix_set_diag_op.h')
-rw-r--r-- | tensorflow/core/kernels/matrix_set_diag_op.h | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h index 8ba2f3756a..63e5650bf0 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.h +++ b/tensorflow/core/kernels/matrix_set_diag_op.h @@ -71,6 +71,23 @@ struct MatrixSetDiag { } }; +template <typename Device> +struct MatrixSetDiag<Device, bool> { + EIGEN_ALWAYS_INLINE static void Compute(const Device& d, + TTypes<bool, 3>::ConstTensor input, + TTypes<bool, 2>::ConstTensor diag, + TTypes<bool>::Scalar scratch, + TTypes<bool, 3>::Tensor output) { + output.device(d) = input; + generator::OverwriteDiagGenerator<bool> generator(diag, output); + // Use all() to force the generation to aggregate to the scalar + // output scratch. This in turn forces each element of the + // generator to execute. The side effect of the execution is to + // update the diagonal components of output with diag. + scratch.device(d) = diag.generate(generator).all(); + } +}; + } // namespace functor } // namespace tensorflow |