aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/sparse_split_op.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-12-11 11:29:48 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-11 11:29:48 -0800
commit0a21a38d4ef5b66177f407f74f14dd7b72232b36 (patch)
tree0f83d13d946ef1b42576c067a30ef7da87874a71 /tensorflow/core/kernels/sparse_split_op.cc
parentbc624aa8d9460dca794fde6d5534f1d3e8054016 (diff)
TensorFlow: merge changes from internal
Change 110010103 Implementing SparseSplitOp. The op takes a sparse tensor (list, values and shape), split_dim and num_splits and produces a list of num_splits tensors where the shape of each tensor is the shape of the original tensor except split_dim = shape[split_dim +num_split - 1 / num_split]. in case if shape[split_dim] is not an integer multiple of num_split an extra one dimension get added to the slices starting from 0. For example if the input shape is a [2, 10] split_dim = 1, num_split = 3 output shapes will be [[2, 4], [2, 4], [2, 2]]. The Op register shape to [Unknown, dim] for indices tensors and [Unknown] for the values tensor because shape can't be inferred without evaluate input tensors. Base CL: 110012853
Diffstat (limited to 'tensorflow/core/kernels/sparse_split_op.cc')
-rw-r--r--tensorflow/core/kernels/sparse_split_op.cc80
1 files changed, 80 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/sparse_split_op.cc b/tensorflow/core/kernels/sparse_split_op.cc
new file mode 100644
index 0000000000..f935375b8d
--- /dev/null
+++ b/tensorflow/core/kernels/sparse_split_op.cc
@@ -0,0 +1,80 @@
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/sparse/sparse_tensor.h"
+
+namespace tensorflow {
+
+template <typename T>
+class SparseSplitOp : public OpKernel {
+ public:
+ explicit SparseSplitOp(OpKernelConstruction* context) : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("num_split", &num_split_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const int32 split_dim = context->input(0).scalar<int>()();
+ const Tensor& input_indices = context->input(1);
+ const Tensor& input_values = context->input(2);
+ const Tensor& input_shape = context->input(3);
+
+ OP_REQUIRES(context, TensorShapeUtils::IsMatrix(input_indices.shape()),
+ errors::InvalidArgument(
+ "Input indices should be a matrix but recived shape ",
+ input_indices.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_values.shape()),
+ errors::InvalidArgument(
+ "Input values should be a vector but received shape ",
+ input_indices.shape().ShortDebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsVector(input_shape.shape()),
+ errors::InvalidArgument(
+ "Input shape should be a vector but received shape ",
+ input_shape.shape().ShortDebugString()));
+
+ OP_REQUIRES(context, input_shape.dim_size(0) &&
+ split_dim < input_shape.vec<int64>().size(),
+ errors::InvalidArgument(
+ "Input split_dim should be between 0 and rank (",
+ input_shape.vec<int64>().size(), "), got ", split_dim));
+
+ OP_REQUIRES(context, num_split_ >= 1 &&
+ num_split_ <= input_shape.vec<int64>()(split_dim),
+ errors::InvalidArgument("Input num_split should be between 1 "
+ "and the splitting dimension size (",
+ input_shape.vec<int64>()(split_dim),
+ "), got ", num_split_));
+
+ sparse::SparseTensor sparse_tensor(input_indices, input_values,
+ TensorShape(input_shape.vec<int64>()));
+ const std::vector<sparse::SparseTensor> outputs =
+ sparse::SparseTensor::Split<T>(sparse_tensor, split_dim, num_split_);
+
+ for (int slice_index = 0; slice_index < num_split_; ++slice_index) {
+ context->set_output(slice_index, outputs[slice_index].indices());
+ context->set_output(slice_index + num_split_,
+ outputs[slice_index].values());
+ Tensor* shape = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(
+ slice_index + 2 * num_split_,
+ {outputs[slice_index].shape().dims()}, &shape));
+ for (int dim = 0; dim < outputs[slice_index].shape().dims(); ++dim) {
+ shape->vec<int64>()(dim) = outputs[slice_index].shape().dim_size(dim);
+ }
+ }
+ }
+
+ private:
+ int num_split_;
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSplit").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
+ SparseSplitOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+} // namespace tensorflow