aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/matrix_set_diag_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-29 12:28:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-29 13:35:30 -0700
commit2a5a96976deba6b6bac007f237bd2665c0da184e (patch)
tree9305e22b8c933f69d024fe5eaeac922079b8c013 /tensorflow/core/kernels/matrix_set_diag_op.cc
parentb7d5df182b7394ab17c11ccc949ce07812920bd9 (diff)
Update trace, matrix_set_diag, matrix_diag_part and their gradients to work for rectangular matrices.
Generalize trace to work like numpy.trace(x, axis1=-2, axis2=-1), including for rank > 2. Fix bad doc string for matrix_band_part. Change: 134700928
Diffstat (limited to 'tensorflow/core/kernels/matrix_set_diag_op.cc')
-rw-r--r--tensorflow/core/kernels/matrix_set_diag_op.cc27
1 files changed, 12 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index 2b8a61121c..952da7d8df 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -59,21 +59,18 @@ class MatrixSetDiagOp : public OpKernel {
"input must be at least 2-dim, received shape: ",
input.shape().DebugString()));
- // Check to make sure the last two dimensions have the same value
- const int64 k = input_shape.dim_size(rank - 1);
- OP_REQUIRES(
- context, k == input_shape.dim_size(rank - 2),
- errors::InvalidArgument(
- "input's last two dimensions must be equal, received shape: ",
- input.shape().DebugString()));
-
- TensorShape input_shape_but_one = input_shape;
- input_shape_but_one.RemoveDim(rank - 1);
-
- OP_REQUIRES(context, input_shape_but_one == diag_shape,
+ // Check to make sure the last dimension of diag is equal to the smaller of
+ // the last two dimensions of input.
+ const int64 min_dim = std::min(input_shape.dim_size(rank - 1),
+ input_shape.dim_size(rank - 2));
+ TensorShape expected_diag_shape = input_shape;
+ expected_diag_shape.RemoveDim(rank - 1);
+ expected_diag_shape.RemoveDim(rank - 2);
+ expected_diag_shape.AddDim(min_dim);
+ OP_REQUIRES(context, expected_diag_shape == diag_shape,
errors::InvalidArgument(
- "must have diagonal.shape == input.shape[:-1], but "
- "received input shape: ",
+ "must have diagonal.shape == input.shape[:-2] + "
+ "min(input.shape[-2:]), but received input shape: ",
input_shape.DebugString(), " and diagonal shape: ",
diag_shape.DebugString()));
@@ -127,7 +124,7 @@ struct MatrixSetDiag<CPUDevice, T> {
typename TTypes<T, 3>::Tensor output) {
output.device(d) = input;
for (int64 r = 0; r < output.dimension(0); ++r) {
- for (int64 d = 0; d < output.dimension(1); ++d) {
+ for (int64 d = 0; d < diag.dimension(1); ++d) {
output(r, d, d) = diag(r, d);
}
}