aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-06 12:02:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 12:06:12 -0800
commit91c75ecc66c630f541a2215844b2012b9f5e6df6 (patch)
treef53b18b1c7bd6d3d945d5ec2abab1095525d08e3
parent93aeebad51f29e6d90d091be6e28986079805d3a (diff)
Allow SparseSegmentReduction ops to have missing segment IDs.
PiperOrigin-RevId: 178131721
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt36
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt38
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt57
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc146
-rw-r--r--tensorflow/core/ops/math_ops.cc151
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py135
-rw-r--r--tensorflow/python/ops/math_grad.py25
-rw-r--r--tensorflow/python/ops/math_ops.py153
-rw-r--r--tensorflow/tools/api/golden/tensorflow.pbtxt6
9 files changed, 702 insertions, 45 deletions
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
new file mode 100644
index 0000000000..d6e1054003
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
@@ -0,0 +1,36 @@
+op {
+ graph_op_name: "SparseSegmentMeanWithNumSegments"
+ in_arg {
+ name: "indices"
+ description: <<END
+A 1-D tensor. Has same rank as `segment_ids`.
+END
+ }
+ in_arg {
+ name: "segment_ids"
+ description: <<END
+A 1-D tensor. Values should be sorted and can be repeated.
+END
+ }
+ in_arg {
+ name: "num_segments"
+ description: <<END
+The number of rows in the output. Must be greater than the largest segment ID.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Has same shape as data, except for dimension 0 which has size
+`num_segments`.
+END
+ }
+ summary: "Computes the mean along sparse segments of a tensor."
+ description: <<END
+Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
+missing, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
new file mode 100644
index 0000000000..9ba98b8191
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "SparseSegmentSqrtNWithNumSegments"
+ in_arg {
+ name: "indices"
+ description: <<END
+A 1-D tensor. Has same rank as `segment_ids`.
+END
+ }
+ in_arg {
+ name: "segment_ids"
+ description: <<END
+A 1-D tensor. Values should be sorted and can be repeated.
+END
+ }
+ in_arg {
+ name: "num_segments"
+ description: <<END
+The number of rows in the output. Must be greater than the largest segment ID.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Has same shape as data, except for dimension 0 which
+has size `num_segments`.
+END
+ }
+ summary: "Computes the sum along sparse segments of a tensor divided by the sqrt of N."
+ description: <<END
+N is the size of the segment being reduced.
+
+Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
+missing, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
new file mode 100644
index 0000000000..3aeaba38e9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
@@ -0,0 +1,57 @@
+op {
+ graph_op_name: "SparseSegmentSumWithNumSegments"
+ in_arg {
+ name: "indices"
+ description: <<END
+A 1-D tensor. Has same rank as `segment_ids`.
+END
+ }
+ in_arg {
+ name: "segment_ids"
+ description: <<END
+A 1-D tensor. Values should be sorted and can be repeated.
+END
+ }
+ in_arg {
+ name: "num_segments"
+ description: <<END
+The number of rows in the output. Must be greater than the largest segment ID.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Has same shape as data, except for dimension 0 which
+has size `num_segments`.
+END
+ }
+ summary: "Computes the sum along sparse segments of a tensor."
+ description: <<END
+Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+missing, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+For example:
+
+```python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+
+tf.sparse_segment_sum_with_num_segments(
+ c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+# => [[0 0 0 0]
+# [0 0 0 0]
+# [0 0 0 0]]
+
+tf.sparse_segment_sum_with_num_segments(c,
+ tf.constant([0, 1]),
+ tf.constant([0, 2]),
+ num_segments=4)
+# => [[ 1 2 3 4]
+# [ 0 0 0 0]
+# [-1 -2 -3 -4]
+# [ 0 0 0 0]]
+```
+END
+}
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 2334e50f1d..3ef1cd1e06 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -553,10 +553,11 @@ class SparseSegmentReductionOpBase : public OpKernel {
public:
explicit SparseSegmentReductionOpBase(OpKernelConstruction* context,
bool is_mean, bool is_sqrtn,
- T default_value)
+ bool has_num_segments, T default_value)
: OpKernel(context),
is_mean_(is_mean),
is_sqrtn_(is_sqrtn),
+ has_num_segments_(has_num_segments),
default_value_(default_value) {}
void Compute(OpKernelContext* context) override {
@@ -564,6 +565,19 @@ class SparseSegmentReductionOpBase : public OpKernel {
const Tensor& indices = context->input(1);
const Tensor& segment_ids = context->input(2);
+ Index output_rows = -1;
+ if (has_num_segments_) {
+ const Tensor& num_segments = context->input(3);
+
+ OP_REQUIRES(
+ context, num_segments.shape().dims() == 0,
+ errors::InvalidArgument("num_segments should be a scalar, not shape ",
+ num_segments.shape().DebugString()));
+ output_rows = internal::SubtleMustCopy(num_segments.scalar<int32>()());
+ OP_REQUIRES(context, output_rows >= 0,
+ errors::InvalidArgument("segment ids must be >= 0"));
+ }
+
OP_REQUIRES(context, TensorShapeUtils::IsVector(indices.shape()),
errors::InvalidArgument("indices should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsVector(segment_ids.shape()),
@@ -581,10 +595,17 @@ class SparseSegmentReductionOpBase : public OpKernel {
const auto segment_vec = segment_ids.vec<OutputRow>();
// Note that the current implementation assumes that segment_vec values are
// sorted.
- const OutputRow output_rows =
+ const OutputRow last_segment_id_plus_one =
num_indices > 0
? internal::SubtleMustCopy(segment_vec(num_indices - 1)) + 1
: 0;
+ if (has_num_segments_) {
+ OP_REQUIRES(
+ context, output_rows >= last_segment_id_plus_one,
+ errors::InvalidArgument("segment ids must be < num_segments"));
+ } else {
+ output_rows = last_segment_id_plus_one;
+ }
OP_REQUIRES(context, output_rows >= 0,
errors::InvalidArgument("segment ids must be >= 0"));
@@ -646,11 +667,20 @@ class SparseSegmentReductionOpBase : public OpKernel {
indices_vec(start + bad_offset), " out of range [0, ",
input_flat.dimension(0), ")"));
- if (end >= num_indices) break;
start = end;
++end;
uninitialized_index = out_index + 1;
out_index = next_index;
+ if (end > num_indices) break;
+ }
+
+ // Fill the gap at the end with the default value.
+ if (uninitialized_index < output_rows) {
+ Eigen::DSizes<Eigen::DenseIndex, 2> gap_slice_shape(
+ output_rows - uninitialized_index, num_col);
+ Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, Eigen::Unaligned>
+ gap_slice(&output_flat(uninitialized_index, 0), gap_slice_shape);
+ gap_slice.setConstant(default_value_);
}
}
@@ -786,6 +816,7 @@ class SparseSegmentReductionOpBase : public OpKernel {
const bool is_mean_;
const bool is_sqrtn_;
+ const bool has_num_segments_;
const T default_value_;
};
@@ -794,9 +825,20 @@ class SparseSegmentReductionMeanOp
: public SparseSegmentReductionOpBase<Device, T> {
public:
explicit SparseSegmentReductionMeanOp(OpKernelConstruction* context)
- : SparseSegmentReductionOpBase<Device, T>(context, true /*is_mean*/,
- false /*is_sqrtn*/,
- T(0) /* default_value */) {}
+ : SparseSegmentReductionOpBase<Device, T>(
+ context, true /*is_mean*/, false /*is_sqrtn*/,
+ false /* has_num_segments */, T(0) /* default_value */) {}
+};
+
+template <typename Device, class T>
+class SparseSegmentReductionMeanWithNumSegmentsOp
+ : public SparseSegmentReductionOpBase<Device, T> {
+ public:
+ explicit SparseSegmentReductionMeanWithNumSegmentsOp(
+ OpKernelConstruction* context)
+ : SparseSegmentReductionOpBase<Device, T>(
+ context, true /*is_mean*/, false /*is_sqrtn*/,
+ true /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
@@ -804,9 +846,20 @@ class SparseSegmentReductionSqrtNOp
: public SparseSegmentReductionOpBase<Device, T> {
public:
explicit SparseSegmentReductionSqrtNOp(OpKernelConstruction* context)
- : SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/,
- true /*is_sqrtn*/,
- T(0) /* default_value */) {}
+ : SparseSegmentReductionOpBase<Device, T>(
+ context, false /*is_mean*/, true /*is_sqrtn*/,
+ false /* has_num_segments */, T(0) /* default_value */) {}
+};
+
+template <typename Device, class T>
+class SparseSegmentReductionSqrtNWithNumSegmentsOp
+ : public SparseSegmentReductionOpBase<Device, T> {
+ public:
+ explicit SparseSegmentReductionSqrtNWithNumSegmentsOp(
+ OpKernelConstruction* context)
+ : SparseSegmentReductionOpBase<Device, T>(
+ context, false /*is_mean*/, true /*is_sqrtn*/,
+ true /* has_num_segments */, T(0) /* default_value */) {}
};
template <typename Device, class T>
@@ -814,37 +867,65 @@ class SparseSegmentReductionSumOp
: public SparseSegmentReductionOpBase<Device, T> {
public:
explicit SparseSegmentReductionSumOp(OpKernelConstruction* context)
- : SparseSegmentReductionOpBase<Device, T>(context, false /*is_mean*/,
- false /*is_sqrtn*/,
- T(0) /* default_value */) {}
+ : SparseSegmentReductionOpBase<Device, T>(
+ context, false /*is_mean*/, false /*is_sqrtn*/,
+ false /* has_num_segments */, T(0) /* default_value */) {}
};
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
- REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- SparseSegmentReductionSumOp<CPUDevice, type>);
+template <typename Device, class T>
+class SparseSegmentReductionSumWithNumSegmentsOp
+ : public SparseSegmentReductionOpBase<Device, T> {
+ public:
+ explicit SparseSegmentReductionSumWithNumSegmentsOp(
+ OpKernelConstruction* context)
+ : SparseSegmentReductionOpBase<Device, T>(
+ context, false /*is_mean*/, false /*is_sqrtn*/,
+ true /* has_num_segments */, T(0) /* default_value */) {}
+};
+#define REGISTER_CPU_SPARSE_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("SparseSegmentSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ SparseSegmentReductionSumOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSegmentSumWithNumSegments") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ SparseSegmentReductionSumWithNumSegmentsOp<CPUDevice, type>);
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_SPARSE_KERNELS);
#undef REGISTER_CPU_SPARSE_KERNELS
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
- REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- SparseSegmentReductionMeanOp<CPUDevice, type>);
+#define REGISTER_CPU_SPARSE_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("SparseSegmentMean") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ SparseSegmentReductionMeanOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSegmentMeanWithNumSegments") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ SparseSegmentReductionMeanWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS
-#define REGISTER_CPU_SPARSE_KERNELS(type) \
- REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .TypeConstraint<int32>("Tidx"), \
- SparseSegmentReductionSqrtNOp<CPUDevice, type>);
+#define REGISTER_CPU_SPARSE_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("SparseSegmentSqrtN") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ SparseSegmentReductionSqrtNOp<CPUDevice, type>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("SparseSegmentSqrtNWithNumSegments") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx"), \
+ SparseSegmentReductionSqrtNWithNumSegmentsOp<CPUDevice, type>);
REGISTER_CPU_SPARSE_KERNELS(float);
REGISTER_CPU_SPARSE_KERNELS(double);
#undef REGISTER_CPU_SPARSE_KERNELS
@@ -889,9 +970,10 @@ class SparseSegmentGradOpBase : public OpKernel {
// Note that similar to SparseSegmentMean, we assume that segment_vec is
// already sorted and has non-negative values.
- const SegmentId num_segments =
+ const SegmentId num_segments = input.dim_size(0);
+ const SegmentId last_segment_id_plus_one =
internal::SubtleMustCopy(segment_vec(N - 1)) + 1;
- OP_REQUIRES(context, input.dim_size(0) == num_segments,
+ OP_REQUIRES(context, last_segment_id_plus_one <= num_segments,
errors::InvalidArgument("Invalid number of segments"));
// Compute scaling factors for input.
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 45ebfa203b..8ea170ba14 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1632,6 +1632,45 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
+ ShapeHandle data_shape;
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
+
+ ShapeHandle indices_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
+
+ ShapeHandle segment_ids_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
+
+ ShapeHandle num_segments_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
+
+ // indices and segment_ids should merge cleanly.
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
+
+ ShapeHandle subshape;
+ TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
+
+ ShapeHandle out;
+ const Tensor* dim0 = c->input_tensor(3);
+ if (dim0 == nullptr) {
+ // We don't have the value at inference time, so the output
+ // shape is unknown.
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
+ subshape, &out));
+ } else {
+ auto dim0_value = dim0->scalar<int32>()();
+ if (dim0_value < 0) {
+ return errors::InvalidArgument(
+ "Cannot specify a negative value for num_segments");
+ }
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
+ }
+ c->set_output(0, out);
+ return Status::OK();
+}
+
Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
ShapeHandle s_data = c->input(0);
ShapeHandle s_segment_ids = c->input(1);
@@ -1890,6 +1929,7 @@ output: Has same shape as data, except for dimension 0 which
has size `num_segments`.
)doc");
+
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
.Input("indices: Tidx")
@@ -1938,6 +1978,56 @@ output: Has same shape as data, except for dimension 0 which
has size `k`, the number of segments.
)doc");
+REGISTER_OP("SparseSegmentSumWithNumSegments")
+ .Input("data: T")
+ .Input("indices: Tidx")
+ .Input("segment_ids: int32")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn)
+ .Doc(R"doc(
+Computes the sum along sparse segments of a tensor.
+
+Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
+missing, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+For example:
+
+```python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+
+tf.sparse_segment_sum_with_num_segments(
+ c, tf.constant([0, 1]), tf.constant([0, 0]), num_segments=3)
+# => [[0 0 0 0]
+# [0 0 0 0]
+# [0 0 0 0]]
+
+tf.sparse_segment_sum_with_num_segments(c,
+ tf.constant([0, 1]),
+ tf.constant([0, 2]),
+ num_segments=4)
+# => [[ 1 2 3 4]
+# [ 0 0 0 0]
+# [-1 -2 -3 -4]
+# [ 0 0 0 0]]
+```
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+num_segments: The number of rows in the output. Must be greater than the largest segment ID.
+
+output: Has same shape as data, except for dimension 0 which
+ has size `num_segments`.
+)doc");
+
REGISTER_OP("SparseSegmentMean")
.Input("data: T")
.Input("indices: Tidx")
@@ -1964,6 +2054,35 @@ output: Has same shape as data, except for dimension 0 which
)doc");
+REGISTER_OP("SparseSegmentMeanWithNumSegments")
+ .Input("data: T")
+ .Input("indices: Tidx")
+ .Input("segment_ids: int32")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn)
+ .Doc(R"doc(
+Computes the mean along sparse segments of a tensor.
+
+Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
+missing, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+num_segments: The number of rows in the output. Must be greater than the largest segment ID.
+
+output: Has same shape as data, except for dimension 0 which has size
+ `num_segments`.
+)doc");
+
REGISTER_OP("SparseSegmentMeanGrad")
.Input("grad: T")
.Input("indices: Tidx")
@@ -2010,6 +2129,38 @@ output: Has same shape as data, except for dimension 0 which
)doc");
+REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
+ .Input("data: T")
+ .Input("indices: Tidx")
+ .Input("segment_ids: int32")
+ .Input("num_segments: Tnumsegments")
+ .Output("output: T")
+ .Attr("T: {float, double}")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
+ .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn)
+ .Doc(R"doc(
+Computes the sum along sparse segments of a tensor divided by the sqrt of N.
+
+N is the size of the segment being reduced.
+
+Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
+missing, the `output` tensor at that position will be zeroed.
+
+Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
+segments.
+
+indices: A 1-D tensor. Has same rank as `segment_ids`.
+
+segment_ids: A 1-D tensor. Values should be sorted and can be repeated.
+
+num_segments: The number of rows in the output. Must be greater than the largest segment ID.
+
+output: Has same shape as data, except for dimension 0 which
+ has size `num_segments`.
+
+)doc");
+
REGISTER_OP("SparseSegmentSqrtNGrad")
.Input("grad: T")
.Input("indices: Tidx")
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index fd58cdb170..5a54f448d0 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -46,13 +46,13 @@ class SegmentReductionHelper(test.TestCase):
return constant_op.constant(
np_values, shape=input_shape, dtype=dtype), np_values
- def _segmentReduce(self, indices, x, op1, op2=None, num_out_rows=None):
+ def _segmentReduce(self, indices, x, op1, op2=None, num_segments=None):
if not x.size:
return np.array([])
indices = np.asarray(indices)
- if num_out_rows is None:
- num_out_rows = indices[-1] + 1
- output = [None] * num_out_rows
+ if num_segments is None:
+ num_segments = indices[-1] + 1
+ output = [None] * num_segments
slice_shape = x.shape[indices.ndim:]
x_flat = x.reshape((indices.size,) + slice_shape)
for i, index in enumerate(indices.ravel()):
@@ -259,7 +259,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
with self.test_session(use_gpu=True):
tf_x, np_x = self._input(shape, dtype=dtype)
np_ans = self._segmentReduce(
- indices, np_x, np.add, op2=None, num_out_rows=num_segments)
+ indices, np_x, np.add, op2=None, num_segments=num_segments)
s = math_ops.unsorted_segment_sum(
data=tf_x, segment_ids=indices, num_segments=num_segments)
tf_ans = s.eval()
@@ -278,7 +278,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
num_segments_constant = constant_op.constant(
num_segments, dtype=dtype)
np_ans = self._segmentReduce(
- indices, np_x, np.add, op2=None, num_out_rows=num_segments)
+ indices, np_x, np.add, op2=None, num_segments=num_segments)
s = math_ops.unsorted_segment_sum(
data=tf_x,
segment_ids=indices,
@@ -397,7 +397,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
with self.test_session(use_gpu=True):
tf_x, np_x = self._input(shape, dtype=dtype)
np_ans = self._segmentReduce(
- indices, np_x, np.add, op2=None, num_out_rows=num_segments)
+ indices, np_x, np.add, op2=None, num_segments=num_segments)
# Replace np_ans[8] with 0 for the value
np_ans[8:] = 0
# Replace 8 with -1 in indices
@@ -417,8 +417,15 @@ class SparseSegmentReductionHelper(SegmentReductionHelper):
return (constant_op.constant(
indices, dtype=dtypes_lib.int32), indices, a, b)
- def _sparseSegmentReduce(self, x, indices, segment_indices, op1, op2=None):
- return self._segmentReduce(segment_indices, x[indices], op1, op2)
+ def _sparseSegmentReduce(self,
+ x,
+ indices,
+ segment_indices,
+ op1,
+ op2=None,
+ num_segments=None):
+ return self._segmentReduce(
+ segment_indices, x[indices], op1, op2, num_segments=num_segments)
class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
@@ -475,6 +482,31 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
tf_ans = s.eval()
self.assertAllClose(np_ans, tf_ans)
+ def testWithNumSegments(self):
+ tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
+ ops_list = [(np.add, None, math_ops.sparse_segment_sum_with_num_segments),
+ (self._mean_cum_op, self._mean_reduce_op,
+ math_ops.sparse_segment_mean_with_num_segments)]
+ segment_indices = [0, 2, 2, 2]
+ tf_indices = [8, 3, 0, 9]
+ num_segments = 5
+ with self.test_session(use_gpu=False):
+ for np_op1, np_op2, tf_op in ops_list:
+ np_ans = self._sparseSegmentReduce(
+ np_x,
+ tf_indices,
+ segment_indices,
+ np_op1,
+ np_op2,
+ num_segments=num_segments)
+ s = tf_op(
+ data=tf_x,
+ indices=tf_indices,
+ segment_ids=segment_indices,
+ num_segments=num_segments)
+ tf_ans = s.eval()
+ self.assertAllClose(np_ans, tf_ans)
+
def testSegmentIdsGreaterThanZero(self):
tf_x, np_x = self._input([10, 4], dtype=dtypes_lib.float32)
ops_list = [(np.add, None, math_ops.sparse_segment_sum), (
@@ -583,6 +615,63 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
with self.assertRaisesOpError("segment ids must be >= 0"):
s.eval()
+ def testSegmentWithNumSegmentsValid(self):
+ # Baseline for the test*WithNumSegmentsInvalid* methods below.
+ tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
+ ops_list = [
+ math_ops.sparse_segment_sum_with_num_segments,
+ math_ops.sparse_segment_mean_with_num_segments,
+ ]
+ num_segments = 5
+ segment_indices = [0, 1, 3, 3]
+ tf_indices = [8, 3, 0, 9]
+ with self.test_session(use_gpu=False):
+ for tf_op in ops_list:
+ s = tf_op(
+ data=tf_x,
+ indices=tf_indices,
+ segment_ids=segment_indices,
+ num_segments=num_segments)
+ s.eval()
+
+ def testSegmentWithNumSegmentsInvalid1(self):
+ tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
+ ops_list = [
+ math_ops.sparse_segment_sum_with_num_segments,
+ math_ops.sparse_segment_mean_with_num_segments,
+ ]
+ num_segments = 5
+ segment_indices = [0, 1, 3, 5]
+ tf_indices = [8, 3, 0, 9]
+ with self.test_session(use_gpu=False):
+ for tf_op in ops_list:
+ s = tf_op(
+ data=tf_x,
+ indices=tf_indices,
+ segment_ids=segment_indices,
+ num_segments=num_segments)
+ with self.assertRaisesOpError("segment ids must be < num_segments"):
+ s.eval()
+
+ def testSegmentWithNumSegmentsInvalid2(self):
+ tf_x, _ = self._input([10, 4], dtype=dtypes_lib.float32)
+ ops_list = [
+ math_ops.sparse_segment_sum_with_num_segments,
+ math_ops.sparse_segment_mean_with_num_segments,
+ ]
+ num_segments = -2
+ segment_indices = [0, 1, 3, 3]
+ tf_indices = [8, 3, 0, 9]
+ with self.test_session(use_gpu=False):
+ for tf_op in ops_list:
+ with self.assertRaisesRegexp(
+ ValueError, "Cannot specify a negative value for num_segments"):
+ tf_op(
+ data=tf_x,
+ indices=tf_indices,
+ segment_ids=segment_indices,
+ num_segments=num_segments)
+
def testGradient(self):
shape = [10, 4]
@@ -601,6 +690,32 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
delta=1)
self.assertAllClose(jacob_t, jacob_n)
+ def testGradientWithEmptySegmentsAtEnd(self):
+ shape = [10, 4]
+
+ num_segments = 5
+ segment_indices = [0, 1, 2, 2]
+ num_indices = len(segment_indices)
+ for tf_op in [
+ math_ops.sparse_segment_sum_with_num_segments,
+ math_ops.sparse_segment_mean_with_num_segments,
+ ]:
+ with self.test_session():
+ tf_indices, _, tf_x, np_x = self._sparse_input(
+ shape, num_indices, dtype=dtypes_lib.float64)
+ s = tf_op(
+ data=tf_x,
+ indices=tf_indices,
+ segment_ids=segment_indices,
+ num_segments=num_segments)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s, [5, 4],
+ x_init_value=np_x.astype(np.double),
+ delta=1)
+ self.assertAllClose(jacob_t, jacob_n)
+
def testGradientValid(self):
# Baseline for the testGradient*Invalid* methods below.
tf_x, _ = self._input([3, 4], dtype=dtypes_lib.float32)
@@ -646,7 +761,7 @@ class SparseSegmentReductionOpTest(SparseSegmentReductionHelper):
ops_list = [
math_ops.sparse_segment_mean_grad, math_ops.sparse_segment_sqrt_n_grad
]
- segment_indices = [0, 1, 1, 1] # 2 segments
+ segment_indices = [0, 1, 1, 4] # 5 segments
tf_indices = [8, 3, 0, 9]
with self.test_session(use_gpu=False):
for tf_op in ops_list:
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 38fe093ba7..0239396ae3 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -184,6 +184,15 @@ def _SparseSegmentSumGrad(op, grad):
None)
+@ops.RegisterGradient("SparseSegmentSumWithNumSegments")
+def _SparseSegmentSumWithNumSegmentsGrad(op, grad):
+ """Gradient for SparseSegmentSumWithNumSegments."""
+ input_rows = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.unsorted_segment_sum(
+ array_ops.gather(grad, op.inputs[2]), op.inputs[1], input_rows), None,
+ None, None)
+
+
@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
"""Gradient for SparseSegmentMean."""
@@ -192,6 +201,14 @@ def _SparseSegmentMeanGrad(op, grad):
dim0), None, None)
+@ops.RegisterGradient("SparseSegmentMeanWithNumSegments")
+def _SparseSegmentMeanWithNumSegmentsGrad(op, grad):
+ """Gradient for SparseSegmentMeanWithNumSegments."""
+ dim0 = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.sparse_segment_mean_grad(grad, op.inputs[1], op.inputs[2],
+ dim0), None, None, None)
+
+
@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
"""Gradient for SparseSegmentSqrtN."""
@@ -200,6 +217,14 @@ def _SparseSegmentSqrtNGrad(op, grad):
dim0), None, None)
+@ops.RegisterGradient("SparseSegmentSqrtNWithNumSegments")
+def _SparseSegmentSqrtNWithNumSegmentsGrad(op, grad):
+ """Gradient for SparseSegmentSqrtNWithNumSegments."""
+ dim0 = array_ops.shape(op.inputs[0])[0]
+ return (math_ops.sparse_segment_sqrt_n_grad(grad, op.inputs[1], op.inputs[2],
+ dim0), None, None, None)
+
+
def _SegmentMinOrMaxGrad(op, grad, is_sorted):
"""Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code."""
zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index f9538be6c9..6af36343d5 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -2495,6 +2495,159 @@ def reduced_shape(input_shape, axes):
]) # [1, 1]
+def sparse_segment_sum(data, indices, segment_ids, name=None,
+ num_segments=None):
+ r"""Computes the sum along sparse segments of a tensor.
+
+ Read @{$math_ops#segmentation$the section on segmentation} for an explanation
+ of segments.
+
+ Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
+ dimension, selecting a subset of dimension 0, specified by `indices`.
+ `segment_ids` is allowed to have missing ids, in which case the output will
+ be zeros at those indices. In those cases `num_segments` is used to determine
+ the size of the output.
+
+ For example:
+
+ ```python
+ c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+
+ # Select two rows, one segment.
+ tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+ # => [[0 0 0 0]]
+
+ # Select two rows, two segments.
+ tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+ # => [[ 1 2 3 4]
+ # [-1 -2 -3 -4]]
+
+ # With missing segment ids.
+ tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
+ num_segments=4)
+ # => [[ 1 2 3 4]
+ # [ 0 0 0 0]
+ # [-1 -2 -3 -4]
+ # [ 0 0 0 0]]
+
+ # Select all rows, two segments.
+ tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+ # => [[0 0 0 0]
+ # [5 6 7 8]]
+
+ # Which is equivalent to:
+ tf.segment_sum(c, tf.constant([0, 0, 1]))
+ ```
+
+ Args:
+ data: A `Tensor` with data that will be assembled in the output.
+ indices: A 1-D `Tensor` with indices into `data`. Has same rank as
+ `segment_ids`.
+ segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
+ Values should be sorted and can be repeated.
+ name: A name for the operation (optional).
+ num_segments: An optional int32 scalar. Indicates the size of the output
+ `Tensor`.
+
+ Returns:
+ A `Tensor` with the same shape as data, except for dimension 0 which
+ has size `k`, the number of segments specified via `num_segments` or
+ inferred from the last element in `segment_ids`.
+ """
+ if num_segments is not None:
+ return gen_math_ops.sparse_segment_sum_with_num_segments(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=name)
+ else:
+ return gen_math_ops.sparse_segment_sum(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids,
+ name=name)
+
+
+def sparse_segment_mean(data, indices, segment_ids, name=None,
+ num_segments=None):
+ r"""Computes the mean along sparse segments of a tensor.
+
+ Read @{$math_ops#segmentation$the section on segmentation} for an explanation
+ of segments.
+
+ Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
+ dimension, selecting a subset of dimension 0, specified by `indices`.
+ `segment_ids` is allowed to have missing ids, in which case the output will
+ be zeros at those indices. In those cases `num_segments` is used to determine
+ the size of the output.
+
+ Args:
+ data: A `Tensor` with data that will be assembled in the output.
+ indices: A 1-D `Tensor` with indices into `data`. Has same rank as
+ `segment_ids`.
+ segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
+ Values should be sorted and can be repeated.
+ name: A name for the operation (optional).
+ num_segments: An optional int32 scalar. Indicates the size of the output
+ `Tensor`.
+
+ Returns:
+ A `Tensor` with the same shape as data, except for dimension 0 which
+ has size `k`, the number of segments specified via `num_segments` or
+ inferred from the last element in `segment_ids`.
+ """
+ if num_segments is not None:
+ return gen_math_ops.sparse_segment_mean_with_num_segments(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=name)
+ else:
+ return gen_math_ops.sparse_segment_mean(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids,
+ name=name)
+
+
+def sparse_segment_sqrt_n(data, indices, segment_ids, name=None,
+ num_segments=None):
+ r"""Computes the sum along sparse segments of a tensor divided by the sqrt(N).
+
+ `N` is the size of the segment being reduced.
+
+ Args:
+ data: A `Tensor` with data that will be assembled in the output.
+ indices: A 1-D `Tensor` with indices into `data`. Has same rank as
+ `segment_ids`.
+ segment_ids: A 1-D `Tensor` with indices into the output `Tensor`.
+ Values should be sorted and can be repeated.
+ name: A name for the operation (optional).
+ num_segments: An optional int32 scalar. Indicates the size of the output
+ `Tensor`.
+
+ Returns:
+ A `Tensor` with the same shape as data, except for dimension 0 which
+ has size `k`, the number of segments specified via `num_segments` or
+ inferred from the last element in `segment_ids`.
+ """
+ if num_segments is not None:
+ return gen_math_ops.sparse_segment_sqrt_n_with_num_segments(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids,
+ num_segments=num_segments,
+ name=name)
+ else:
+ return gen_math_ops.sparse_segment_sqrt_n(
+ data=data,
+ indices=indices,
+ segment_ids=segment_ids,
+ name=name)
+
+
def tensordot(a, b, axes, name=None):
r"""Tensor contraction of a and b along specified axes.
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index b12cf5a864..4b33aa218c 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1842,15 +1842,15 @@ tf_module {
}
member_method {
name: "sparse_segment_mean"
- argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "sparse_segment_sqrt_n"
- argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "sparse_segment_sum"
- argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "sparse_slice"