aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/matrix_set_diag_op.h
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-06-03 18:22:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-03 18:26:15 -0700
commit7ffc3573255ecdc27b63c6d7dfefa8345225593d (patch)
treebdb68064054858fcb7656f2a8b6bc04281464558 /tensorflow/core/kernels/matrix_set_diag_op.h
parentaad2e3daff8fcd29ed8e5071d4c37a7f94a0421c (diff)
Add support for bools in matrix_diag, matrix_diag_part, matrix_set_diag, matrix_band_part.
PiperOrigin-RevId: 157939272
Diffstat (limited to 'tensorflow/core/kernels/matrix_set_diag_op.h')
-rw-r--r--tensorflow/core/kernels/matrix_set_diag_op.h17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h
index 8ba2f3756a..63e5650bf0 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.h
+++ b/tensorflow/core/kernels/matrix_set_diag_op.h
@@ -71,6 +71,23 @@ struct MatrixSetDiag {
}
};
+template <typename Device>
+struct MatrixSetDiag<Device, bool> {
+ EIGEN_ALWAYS_INLINE static void Compute(const Device& d,
+ TTypes<bool, 3>::ConstTensor input,
+ TTypes<bool, 2>::ConstTensor diag,
+ TTypes<bool>::Scalar scratch,
+ TTypes<bool, 3>::Tensor output) {
+ output.device(d) = input;
+ generator::OverwriteDiagGenerator<bool> generator(diag, output);
+ // Use all() to force the generation to aggregate to the scalar
+ // output scratch. This in turn forces each element of the
+ // generator to execute. The side effect of the execution is to
+ // update the diagonal components of output with diag.
+ scratch.device(d) = diag.generate(generator).all();
+ }
+};
+
} // namespace functor
} // namespace tensorflow