diff options
author | 2016-09-29 12:28:50 -0800 | |
---|---|---|
committer | 2016-09-29 13:35:30 -0700 | |
commit | 2a5a96976deba6b6bac007f237bd2665c0da184e (patch) | |
tree | 9305e22b8c933f69d024fe5eaeac922079b8c013 /tensorflow/core/kernels/matrix_set_diag_op.cc | |
parent | b7d5df182b7394ab17c11ccc949ce07812920bd9 (diff) |
Update trace, matrix_set_diag, matrix_diag_part and their gradients to work for rectangular matrices.
Generalize trace to work like numpy.trace(x, axis1=-2, axis2=-1), including for rank > 2.
Fix bad doc string for matrix_band_part.
Change: 134700928
Diffstat (limited to 'tensorflow/core/kernels/matrix_set_diag_op.cc')
-rw-r--r-- | tensorflow/core/kernels/matrix_set_diag_op.cc | 27 |
1 files changed, 12 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc index 2b8a61121c..952da7d8df 100644 --- a/tensorflow/core/kernels/matrix_set_diag_op.cc +++ b/tensorflow/core/kernels/matrix_set_diag_op.cc @@ -59,21 +59,18 @@ class MatrixSetDiagOp : public OpKernel { "input must be at least 2-dim, received shape: ", input.shape().DebugString())); - // Check to make sure the last two dimensions have the same value - const int64 k = input_shape.dim_size(rank - 1); - OP_REQUIRES( - context, k == input_shape.dim_size(rank - 2), - errors::InvalidArgument( - "input's last two dimensions must be equal, received shape: ", - input.shape().DebugString())); - - TensorShape input_shape_but_one = input_shape; - input_shape_but_one.RemoveDim(rank - 1); - - OP_REQUIRES(context, input_shape_but_one == diag_shape, + // Check to make sure the last dimension of diag is equal to the smaller of + // the last two dimensions of input. + const int64 min_dim = std::min(input_shape.dim_size(rank - 1), + input_shape.dim_size(rank - 2)); + TensorShape expected_diag_shape = input_shape; + expected_diag_shape.RemoveDim(rank - 1); + expected_diag_shape.RemoveDim(rank - 2); + expected_diag_shape.AddDim(min_dim); + OP_REQUIRES(context, expected_diag_shape == diag_shape, errors::InvalidArgument( - "must have diagonal.shape == input.shape[:-1], but " - "received input shape: ", + "must have diagonal.shape == input.shape[:-2] + " + "min(input.shape[-2:]), but received input shape: ", input_shape.DebugString(), " and diagonal shape: ", diag_shape.DebugString())); @@ -127,7 +124,7 @@ struct MatrixSetDiag<CPUDevice, T> { typename TTypes<T, 3>::Tensor output) { output.device(d) = input; for (int64 r = 0; r < output.dimension(0); ++r) { - for (int64 d = 0; d < output.dimension(1); ++d) { + for (int64 d = 0; d < diag.dimension(1); ++d) { output(r, d, d) = diag(r, d); } } |