diff options
Diffstat (limited to 'tensorflow/core/kernels/sparse_slice_grad_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_slice_grad_op.cc | 126 |
1 file changed, 126 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sparse_slice_grad_op.cc b/tensorflow/core/kernels/sparse_slice_grad_op.cc new file mode 100644 index 0000000000..90a39ed818 --- /dev/null +++ b/tensorflow/core/kernels/sparse_slice_grad_op.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +template <typename T> +class SparseSliceGradOp : public OpKernel { + public: + explicit SparseSliceGradOp(OpKernelConstruction *ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext *ctx) override { + const Tensor *backprop_val_grad, *input_indices, *output_indices, *input_start; + OP_REQUIRES_OK(ctx, ctx->input("backprop_val_grad", &backprop_val_grad)); + OP_REQUIRES_OK(ctx, ctx->input("input_indices", &input_indices)); + OP_REQUIRES_OK(ctx, ctx->input("input_start", &input_start)); + OP_REQUIRES_OK(ctx, ctx->input("output_indices", &output_indices)); + + OP_REQUIRES(ctx, + TensorShapeUtils::IsMatrix(input_indices->shape()) && + TensorShapeUtils::IsMatrix(output_indices->shape()), + errors::InvalidArgument( + "Input and 
output indices should be matrices " + "but received shapes: ", + input_indices->shape().DebugString(), " and ", + output_indices->shape().DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVector(backprop_val_grad->shape()), + errors::InvalidArgument( + "Input backprop_val_grad should be a vector but received shape: ", + backprop_val_grad->shape().DebugString())); + OP_REQUIRES( + ctx, + input_indices->dim_size(1) == output_indices->dim_size(1), + errors::InvalidArgument("The input and output should have the same " + "ndims: got: ", input_indices->dim_size(1), " and ", + output_indices->dim_size(1))); + OP_REQUIRES( + ctx, output_indices->dim_size(0) <= input_indices->dim_size(0), + errors::InvalidArgument("# rows of output_indices should be not greater " + "than of input_indices, got ", + output_indices->dim_size(0), " and ", + input_indices->dim_size(0))); + OP_REQUIRES( + ctx, backprop_val_grad->NumElements() == output_indices->dim_size(0), + errors::InvalidArgument("# elements of backprop_val_grad and # rows of " + "output_indices should match (#nnz of sum): got ", + backprop_val_grad->NumElements(), " and ", + output_indices->dim_size(0))); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(input_start->shape()), + errors::InvalidArgument( + "The input_start should be a vector but received shape ", + input_start->shape().DebugString())); + + const int num_dims = input_indices->dim_size(1); + OP_REQUIRES(ctx, num_dims == input_start->NumElements(), + errors::InvalidArgument( + "Expected input_start to be a vector of length ", num_dims, + " but got length ", input_start->NumElements())); + + const int64 input_nnz = input_indices->dim_size(0); + + Tensor *val_grad; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({input_nnz}), &val_grad)); + + T *val_grad_flat = val_grad->flat<T>().data(); + const T *backprop_val_grad_flat = backprop_val_grad->flat<T>().data(); + memset(val_grad_flat, 0, sizeof(T) * input_nnz); + + // Fill gradients for position 
where indices of input and output are same. + const auto input_indices_mat = input_indices->matrix<int64>(); + const auto output_indices_mat = output_indices->matrix<int64>(); + const auto input_start_flat = input_start->flat<int64>(); + int64 j = 0; + for (int64 i = 0; i < input_nnz && j < backprop_val_grad->NumElements(); + ++i) { + bool is_same = true; + for (int d = 0; d < num_dims; ++d) { + const int64 a = input_indices_mat(i, d); + const int64 b = output_indices_mat(j, d); + const int64 offset = input_start_flat(d); + if (a != b + offset) { + is_same = false; + break; + } + } + if (is_same) { + val_grad_flat[i] = backprop_val_grad_flat[j]; + ++j; + } + } + OP_REQUIRES( + ctx, backprop_val_grad->NumElements() == j, + errors::Internal("Elements of backprop_val_grad aren't all propagated. " + "Num elements:", backprop_val_grad->NumElements(), + ", used: ", j)); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSliceGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SparseSliceGradOp<type>) + +TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS +} // namespace tensorflow |