aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Nikolaas Steenbergen <nikolaas.steenbergen@googlemail.com>2017-02-04 00:15:21 +0100
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2017-02-03 15:15:21 -0800
commit94f2229c9e4b4324a324330c8f419276eda7e503 (patch)
treed458b69c476dfd42a214032f2d1e277fa4cd0d18 /tensorflow
parent084b37a00f3cf2cc89d433528ca63ec1d3b5b313 (diff)
TF-549 Adds unsorted segment max Op (#6975)
* TF-549 Adds unsorted segment max Op * Cosmetic change * Add todo comment
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc113
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h30
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc4
-rw-r--r--tensorflow/core/ops/math_grad.cc1
-rw-r--r--tensorflow/core/ops/math_ops.cc98
-rw-r--r--tensorflow/core/ops/ops.pbtxt53
-rw-r--r--tensorflow/python/kernel_tests/segment_reduction_ops_test.py32
-rw-r--r--tensorflow/python/ops/math_grad.py30
-rw-r--r--tensorflow/python/ops/math_ops.py1
9 files changed, 297 insertions, 65 deletions
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index fee16cdb78..5bd4362801 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -220,13 +220,15 @@ REGISTER_COMPLEX_CPU_KERNELS_ALL(complex128);
namespace functor {
// UnsortedSegmentSumFunctor implementation for CPUDevice.
+// todo: Remove duplicate code in UnsortedSegmentSumFunctor and UnsortedSegmentMaxFunctor.
template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
+struct UnsortedSegmentSumFunctor<CPUDevice, T, Index>
+ : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
- typename TTypes<T, 2>::Tensor output) {
+ typename TTypes<T, 2>::Tensor output) override {
output.setZero();
if (data_size == 0) {
return;
@@ -243,16 +245,44 @@ struct UnsortedSegmentSumFunctor<CPUDevice, T, Index> {
}
}
};
-
+// UnsortedSegmentMaxFunctor implementation for CPUDevice.
+template <typename T, typename Index>
+struct UnsortedSegmentMaxFunctor<CPUDevice, T, Index>
+ : UnsortedSegmentBaseFunctor<CPUDevice, T, Index> {
+ void operator()(OpKernelContext* ctx, const CPUDevice& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output) override {
+ output.setConstant(std::numeric_limits<T>::min());
+ if (data_size == 0) {
+ return;
+ }
+ const int64 N = segment_ids.dimension(0);
+ auto data_flat = typename TTypes<T, 2>::ConstTensor(data, N, data_size / N);
+ for (int64 i = 0; i < N; ++i) {
+ Index j = internal::SubtleMustCopy(segment_ids(i));
+ OP_REQUIRES(ctx, FastBoundsCheck(j, output_rows),
+ errors::InvalidArgument(
+ "segment_ids", SliceDebugString(segment_ids_shape, i),
+ " = ", j, " is out of range [0, ", output_rows, ")"));
+ output.template chip<0>(j) =
+ data_flat.template chip<0>(i).cwiseMax(output.template chip<0>(j));
+ }
+ }
+};
} // namespace functor
-// Similar to SegmentReductionOp but can handle unsorted segment definitions and
-// specifying size of output.
+// Base class for SegmentReductionOps that can handle unsorted segment
+// definitions
+// and specifying the size of the output in addition to a reduction function
template <typename Device, class T, class Index>
-class UnsortedSegmentSumOp : public OpKernel {
+class UnsortedSegmentBaseOp : public OpKernel {
public:
- explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
- : OpKernel(context) {}
+ explicit UnsortedSegmentBaseOp(
+ OpKernelConstruction* context,
+ functor::UnsortedSegmentBaseFunctor<Device, T, Index>& functor)
+ : OpKernel(context), reduction_functor_(functor) {}
void Compute(OpKernelContext* context) override {
const Tensor& data = context->input(0);
@@ -288,27 +318,70 @@ class UnsortedSegmentSumOp : public OpKernel {
auto output_flat = output->flat_outer_dims<T>();
auto data_ptr = data.template flat<T>().data();
- functor::UnsortedSegmentSumFunctor<Device, T, Index>()(
- context, context->template eigen_device<Device>(), output_rows,
- segment_ids.shape(), segment_flat, data.NumElements(), data_ptr,
- output_flat);
+ reduction_functor_(context, context->template eigen_device<Device>(),
+ output_rows, segment_ids.shape(), segment_flat,
+ data.NumElements(), data_ptr, output_flat);
}
+ private:
+ functor::UnsortedSegmentBaseFunctor<Device, T, Index>& reduction_functor_;
};
-#define REGISTER_CPU_UNSORTED_KERNELS(type, index_type) \
+template <typename Device, class T, class Index>
+class UnsortedSegmentSumOp : public UnsortedSegmentBaseOp<Device, T, Index> {
+ public:
+ explicit UnsortedSegmentSumOp(OpKernelConstruction* context)
+ : UnsortedSegmentBaseOp<Device, T, Index>(
+ context,
+ sum_functor_) {}
+ private:
+ functor::UnsortedSegmentSumFunctor<Device, T, Index> sum_functor_;
+};
+
+template <typename Device, class T, class Index>
+class UnsortedSegmentMaxOp : public UnsortedSegmentBaseOp<Device, T, Index> {
+ public:
+ explicit UnsortedSegmentMaxOp(OpKernelConstruction* context)
+ : UnsortedSegmentBaseOp<Device, T, Index>(
+ context,
+ max_functor_) {}
+ private:
+ functor::UnsortedSegmentMaxFunctor<Device, T, Index> max_functor_;
+};
+
+#define REGISTER_REAL_CPU_UNSORTED_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ UnsortedSegmentSumOp<CPUDevice, type, index_type>); \
+ REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ UnsortedSegmentMaxOp<CPUDevice, type, index_type>);
+
+#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, index_type) \
REGISTER_KERNEL_BUILDER(Name("UnsortedSegmentSum") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.TypeConstraint<index_type>("Tindices"), \
UnsortedSegmentSumOp<CPUDevice, type, index_type>);
-#define REGISTER_CPU_UNSORTED_KERNELS_ALL(type) \
- REGISTER_CPU_UNSORTED_KERNELS(type, int32); \
- REGISTER_CPU_UNSORTED_KERNELS(type, int64);
-
-TF_CALL_NUMBER_TYPES(REGISTER_CPU_UNSORTED_KERNELS_ALL);
-#undef REGISTER_CPU_UNSORTED_KERNELS
-#undef REGISTER_CPU_UNSORTED_KERNELS_ALL
+#define REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL(type) \
+ REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int32); \
+ REGISTER_REAL_CPU_UNSORTED_KERNELS(type, int64)
+
+#define REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(type) \
+ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int32); \
+ REGISTER_COMPLEX_CPU_UNSORTED_KERNELS(type, int64)
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL);
+REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex64);
+REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL(complex128);
+#undef REGISTER_REAL_CPU_UNSORTED_KERNELS
+#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS
+#undef REGISTER_COMPLEX_CPU_UNSORTED_KERNELS_ALL
+#undef REGISTER_REAL_CPU_UNSORTED_KERNELS_ALL
#if GOOGLE_CUDA
#define REGISTER_GPU_UNSORTED_KERNELS(type, index_type) \
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 8ed990a1e0..ee09c213b7 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -26,6 +26,17 @@ namespace tensorflow {
class OpKernelContext;
namespace functor {
+// BaseFunctor for definition of UnsorteSegmentReductionOp
+// for usage without templates.
+template <typename Device, typename T, typename Index>
+struct UnsortedSegmentBaseFunctor{
+ virtual ~UnsortedSegmentBaseFunctor(){}
+ virtual void operator()(OpKernelContext* ctx, const Device& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output){};
+};
// Functor for UnsortedSegmentSumOp.
// 'output_rows': the number of output segments (unique segment ids in
@@ -37,7 +48,7 @@ namespace functor {
// 'data': input data tensor.
// 'output': output reshaped to {output_rows, output.size/output_rows}
template <typename Device, typename T, typename Index>
-struct UnsortedSegmentSumFunctor {
+struct UnsortedSegmentSumFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
void operator()(OpKernelContext* ctx, const Device& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
@@ -45,6 +56,23 @@ struct UnsortedSegmentSumFunctor {
typename TTypes<T, 2>::Tensor output);
};
+// Functor for UnsortedSegmentMaxOp.
+// 'output_rows': the number of output segments (unique segment ids in
+// 'segment_ids').
+// 'segment_ids_shape': shape of 'segment_ids' tensor.
+// 'segment_ids': unsorted map from input to output segment ids at which to
+// perform segment sum operation.
+// 'data_size': size of input data tensor.
+// 'data': input data tensor.
+// 'output': output reshaped to {output_rows, output.size/output_rows}
+template <typename Device, typename T, typename Index>
+struct UnsortedSegmentMaxFunctor: public UnsortedSegmentBaseFunctor<Device, T, Index> {
+ void operator()(OpKernelContext* ctx, const Device& d,
+ const Index output_rows, const TensorShape& segment_ids_shape,
+ typename TTypes<Index>::ConstFlat segment_ids,
+ const Index data_size, const T* data,
+ typename TTypes<T, 2>::Tensor output);
+};
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
index e0659c8bc9..5f53f098aa 100644
--- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -56,12 +56,12 @@ namespace functor {
// UnsortedSegmentSumFunctor implementation for GPUDevice.
template <typename T, typename Index>
-struct UnsortedSegmentSumFunctor<GPUDevice, T, Index> {
+struct UnsortedSegmentSumFunctor<GPUDevice, T, Index>: UnsortedSegmentBaseFunctor<GPUDevice, T, Index> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
const Index output_rows, const TensorShape& segment_ids_shape,
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
- typename TTypes<T, 2>::Tensor output) {
+ typename TTypes<T, 2>::Tensor output) override {
if (output.size() == 0) {
return;
}
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 15bcd53322..a530d286f7 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -588,6 +588,7 @@ REGISTER_OP_GRADIENT("Mean", MeanGrad);
// REGISTER_OP_GRADIENT("SegmentMin", SegmentMinGrad);
// REGISTER_OP_GRADIENT("SegmentMax", SegmentMaxGrad);
// REGISTER_OP_GRADIENT("UnsortedSegmentSum", UnsortedSegmentSumGrad);
+// REGISTER_OP_GRADIENT("UnsortedSegmentMax", UnsortedSegmentMaxGrad);
Status MinMaxGradHelper(const string& op, const AttrSlice& attrs,
FunctionDef* g) {
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 92e2823fb2..00876bc18c 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -1342,6 +1342,36 @@ Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status UnsortedSegmentReductionShapeFn(InferenceContext* c) {
+ ShapeHandle s_data = c->input(0);
+ ShapeHandle s_segment_ids = c->input(1);
+ ShapeHandle s_num_segments = c->input(2);
+ TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
+
+ ShapeHandle out;
+
+ // Leading dimensions of data must be compatible with dimensions of
+ // <s_segment_ids>.
+ if (c->RankKnown(s_segment_ids)) {
+ TF_RETURN_IF_ERROR(
+ c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
+
+ // Get the value of the num_segments input tensor.
+ DimensionHandle num_segments_dim;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
+
+ // Output is {segment_id_rank} + s_data[segment_id_rank:].
+ ShapeHandle s_data_suffix;
+ TF_RETURN_IF_ERROR(
+ c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
+ } else {
+ out = c->UnknownShape();
+ }
+ c->set_output(0, out);
+ return Status::OK();
+}
} // namespace
REGISTER_OP("SegmentSum")
@@ -1495,36 +1525,7 @@ REGISTER_OP("UnsortedSegmentSum")
.Output("output: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle s_data = c->input(0);
- ShapeHandle s_segment_ids = c->input(1);
- ShapeHandle s_num_segments = c->input(2);
- TF_RETURN_IF_ERROR(c->WithRank(s_num_segments, 0, &s_num_segments));
-
- ShapeHandle out;
-
- // Leading dimensions of data must be compatible with dimensions of
- // <s_segment_ids>.
- if (c->RankKnown(s_segment_ids)) {
- TF_RETURN_IF_ERROR(
- c->MergePrefix(s_data, s_segment_ids, &s_data, &s_segment_ids));
-
- // Get the value of the num_segments input tensor.
- DimensionHandle num_segments_dim;
- TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &num_segments_dim));
-
- // Output is {segment_id_rank} + s_data[segment_id_rank:].
- ShapeHandle s_data_suffix;
- TF_RETURN_IF_ERROR(
- c->Subshape(s_data, c->Rank(s_segment_ids), &s_data_suffix));
- TF_RETURN_IF_ERROR(
- c->Concatenate(c->Vector(num_segments_dim), s_data_suffix, &out));
- } else {
- out = c->UnknownShape();
- }
- c->set_output(0, out);
- return Status::OK();
- })
+ .SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along segments of a tensor.
@@ -1554,6 +1555,43 @@ output: Has same shape as data, except for the first `segment_ids.rank`
)doc");
+
+REGISTER_OP("UnsortedSegmentMax")
+ .Input("data: T")
+ .Input("segment_ids: Tindices")
+ .Input("num_segments: int32")
+ .Output("output: T")
+ .Attr("T: realnumbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(UnsortedSegmentReductionShapeFn)
+ .Doc(R"doc(
+Computes the Max along segments of a tensor.
+
+Read [the section on
+Segmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation
+of segments.
+
+This operator is similar to the [unsorted segment sum operator](../../api_docs/python/math_ops.md#UnsortedSegmentSum).
+Instead of computing the sum over segments, it computes the maximum
+such that:
+
+\\(output_i = \max_j data_j\\) where max is over `j` such
+that `segment_ids[j] == i`.
+
+If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for specific numeric type,
+ `output[i] = numeric_limits<T>::min()`.
+
+<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../../images/UnsortedSegmentSum.png" alt>
+</div>
+
+segment_ids: A 1-D tensor whose rank is equal to the rank of `data`'s
+first dimension.
+
+output: Has same shape as data, except for dimension 0 which
+has size `num_segments`.
+
+)doc");
REGISTER_OP("SparseSegmentSum")
.Input("data: T")
.Input("indices: Tidx")
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 805ba2c662..73635d3417 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -25194,6 +25194,59 @@ op {
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
}
op {
+ name: "UnsortedSegmentSum"
+ input_arg {
+ name: "data"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "segment_ids"
+ description: "A tensor whose shape is a prefix of `data.shape`."
+ type_attr: "Tindices"
+ }
+ input_arg {
+ name: "num_segments"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT64
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ summary: "Computes the max along segments of a tensor."
+ description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
+}
+op {
name: "Unstage"
output_arg {
name: "values"
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index d7e3b3e79b..485530d405 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -49,12 +49,21 @@ class SegmentReductionHelper(test.TestCase):
slice_shape = x.shape[indices.ndim:]
x_flat = x.reshape((indices.size,) + slice_shape)
for i, index in enumerate(indices.ravel()):
- if output[index] is not None:
+ if (output[index] is not None) and op1 == np.max:
+ for j in range(0, output[index].shape[0]):
+ output[index][j] = op1([output[index][j], x_flat[i][j]])
+ elif output[index] is not None:
output[index] = op1(output[index], x_flat[i])
else:
output[index] = x_flat[i]
# zero initialize values that are still uncalcuated.
- output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ # output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ if not op1 == np.max:
+ output = [o if o is not None else np.zeros(slice_shape) for o in output]
+ else:
+ zeroslice = np.zeros(slice_shape)
+ zeroslice.fill(dtype.min)
+ output = [o if o is not None else zeroslice for o in output]
if op2 is not None:
output = [op2(o) for o in output]
output = [o.reshape(slice_shape) for o in output]
@@ -245,7 +254,7 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
self._assertAllClose(indices, np_ans, tf_ans)
self.assertShapeEqual(np_ans, s)
- def testGradient(self):
+ def testGradientSegmentSum(self):
num_cols = 2
indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
num_segments = max(indices_flat) + 3
@@ -318,6 +327,23 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
unsorted = math_ops.unsorted_segment_sum(data, segment_ids, 2)
self.assertAllEqual(unsorted.eval(), np.zeros((2, 0), dtype=dtype))
+ def testGradientSegmentMax(self):
+ num_cols = 2
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = max(indices_flat) + 3
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (num_cols,)
+ with self.test_session():
+ tf_x, np_x = self._input(shape, dtype=dtypes_lib.float64)
+ s = math_ops.unsorted_segment_max(data=tf_x, segment_ids=indices,
+ num_segments=num_segments)
+ jacob_t, jacob_n = gradient_checker.compute_gradient(
+ tf_x,
+ shape,
+ s,
+ [num_segments, num_cols],
+ x_init_value=np_x.astype(np.double), delta=1)
+ self.assertAllClose(jacob_t, jacob_n, rtol=1e-3, atol=1e-3)
class UnsortedSegmentSumGpuTest(UnsortedSegmentSumTest):
use_gpu = True
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index e0232b35f0..b1e3e9e749 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -188,35 +188,42 @@ def _SparseSegmentSqrtNGrad(op, grad):
dim0), None, None)
-def _SegmentMinOrMaxGrad(op, grad):
- """Gradient for SegmentMin and SegmentMax. Both share the same code."""
- zeros = array_ops.zeros(
- array_ops.shape(op.inputs[0]), dtype=op.inputs[0].dtype)
+def _SegmentMinOrMaxGrad(op, grad, is_sorted):
+ """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code."""
+ zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype)
# Get the number of selected (minimum or maximum) elements in each segment.
gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
- num_selected = math_ops.segment_sum(
- math_ops.cast(is_selected, grad.dtype), op.inputs[1])
+ if is_sorted:
+ num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
+ op.inputs[1])
+ else:
+ num_selected = math_ops.unsorted_segment_sum(math_ops.cast(is_selected, grad.dtype),
+ op.inputs[1], op.inputs[2])
# Compute the gradient for each segment. The gradient for the ith segment is
# divided evenly among the selected elements in that segment.
weighted_grads = math_ops.div(grad, num_selected)
gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])
- return array_ops.where(is_selected, gathered_grads, zeros), None
+ if is_sorted:
+ return array_ops.where(is_selected, gathered_grads, zeros), None
+ else:
+ return array_ops.where(is_selected, gathered_grads, zeros), None, None
@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
"""Gradient for SegmentMin."""
- return _SegmentMinOrMaxGrad(op, grad)
+ return _SegmentMinOrMaxGrad(op, grad, True)
@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
"""Gradient for SegmentMax."""
- return _SegmentMinOrMaxGrad(op, grad)
+ return _SegmentMinOrMaxGrad(op, grad, True)
@ops.RegisterGradient("UnsortedSegmentSum")
@@ -225,6 +232,11 @@ def _UnsortedSegmentSumGrad(op, grad):
return array_ops.gather(grad, op.inputs[1]), None, None
+@ops.RegisterGradient("UnsortedSegmentMax")
+def _UnsortedSegmentMaxGrad(op, grad):
+ return _SegmentMinOrMaxGrad(op, grad, False)
+
+
@ops.RegisterGradient("Abs")
def _AbsGrad(op, grad):
x = op.inputs[0]
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 11e7d8382f..1378bf8eca 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -196,6 +196,7 @@ tf.segment_sum(c, tf.constant([0, 0, 1]))
@@segment_mean
@@unsorted_segment_sum
+@@unsorted_segment_max
@@sparse_segment_sum
@@sparse_segment_mean