diff options
author | 2015-12-11 11:29:48 -0800 | |
---|---|---|
committer | 2015-12-11 11:29:48 -0800 | |
commit | 0a21a38d4ef5b66177f407f74f14dd7b72232b36 (patch) | |
tree | 0f83d13d946ef1b42576c067a30ef7da87874a71 /tensorflow/core/kernels/sparse_split_op.cc | |
parent | bc624aa8d9460dca794fde6d5534f1d3e8054016 (diff) |
TensorFlow: merge changes from internal
Change 110010103
Implementing SparseSplitOp.
The op takes a sparse tensor (list, values and shape), split_dim and num_splits and produces a list of num_splits tensors where the shape of each tensor is the shape of the original tensor except split_dim = shape[split_dim +num_split - 1 / num_split]. in case if shape[split_dim] is not an integer multiple of num_split an extra one dimension get added to the slices starting from 0.
For example if the input shape is a [2, 10] split_dim = 1, num_split = 3
output shapes will be [[2, 4], [2, 4], [2, 2]].
The Op register shape to [Unknown, dim] for indices tensors and [Unknown] for the values tensor because shape can't be inferred without evaluate input tensors.
Base CL: 110012853
Diffstat (limited to 'tensorflow/core/kernels/sparse_split_op.cc')
-rw-r--r-- | tensorflow/core/kernels/sparse_split_op.cc | 80 |
1 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc new file mode 100644 index 0000000000..f935375b8d --- /dev/null +++ b/tensorflow/core/kernels/sparse_split_op.cc @@ -0,0 +1,80 @@ +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/util/sparse/sparse_tensor.h" + +namespace tensorflow { + +template <typename T> +class SparseSplitOp : public OpKernel { + public: + explicit SparseSplitOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_)); + } + + void Compute(OpKernelContext* context) override { + const int32 split_dim = context->input(0).scalar<int>()(); + const Tensor& input_indices = context->input(1); + const Tensor& input_values = context->input(2); + const Tensor& input_shape = context->input(3); + + OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()), + errors::InvalidArgument( + "Input indices should be a matrix but recived shape ", + input_indices.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()), + errors::InvalidArgument( + "Input values should be a vector but received shape ", + input_indices.shape().ShortDebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape.shape()), + errors::InvalidArgument( + "Input shape should be a vector but received shape ", + input_shape.shape().ShortDebugString())); + + OP_REQUIRES(context, input_shape.dim_size(0) && + split_dim < input_shape.vec<int64>().size(), + errors::InvalidArgument( + "Input split_dim should be between 0 and rank (", + input_shape.vec<int64>().size(), "), got ", split_dim)); + + OP_REQUIRES(context, num_split_ >= 1 && + num_split_ <= input_shape.vec<int64>()(split_dim), + errors::InvalidArgument("Input num_split should be between 1 " + "and the splitting dimension size (", + input_shape.vec<int64>()(split_dim), + "), got ", num_split_)); + + sparse::SparseTensor sparse_tensor(input_indices, input_values, + TensorShape(input_shape.vec<int64>())); + const std::vector<sparse::SparseTensor> outputs = + sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_); + + for (int slice_index = 0; slice_index < num_split_; ++slice_index) { + context->set_output(slice_index, outputs[slice_index].indices()); + context->set_output(slice_index + num_split_, + outputs[slice_index].values()); + Tensor* shape = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output( + slice_index + 2 * num_split_, + {outputs[slice_index].shape().dims()}, &shape)); + for (int dim = 0; dim < outputs[slice_index].shape().dims(); ++dim) { + shape->vec<int64>()(dim) = outputs[slice_index].shape().dim_size(dim); + } + } + } + + private: + int num_split_; +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("SparseSplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ + SparseSplitOp<type>) + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +} // namespace tensorflow |