aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/reduce_slice_ops
diff options
context:
space:
mode:
authorGravatar Andrew Harp <andrewharp@google.com>2017-08-21 12:10:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-21 12:14:49 -0700
commit6e3e7d18f42cb4237ce6dbe2ffd0f9f158c36daf (patch)
treebe8357ec40fb227cdbb1e329bdfda0bd9e6f46a2 /tensorflow/contrib/reduce_slice_ops
parent2ba7ec1680f6462122a11f1748f04e1747d8d2e0 (diff)
Merge changes from github.
END_PUBLIC --- Commit 575bd01d4 authored by Vijay Vasudevan<vrv@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Remove /replica:0 declaration in device functions and allow them to be freely bound based on cluster names present. When more than one value matches, it will choose the first lexicographically available device that matches the specification, which in practice will do pretty much the same thing as hardcoding /replica:0. PiperOrigin-RevId: 165766815 --- Commit d685bbc54 authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Benchmarks with backprop enabled (and removes overhead). Before: np.array([[3]]) took 1.50us (30000 iterations) Tensor([[3]]) took 16.30us (30000 iterations) MatMul [2, 2]: np.dot took 0.61us (30000 iterations) MatMul [2, 2]: tf.matmul took 60.53us (30000 iterations) MatMul [2, 2]: gen_math_ops.mat_mul took 25.72us (30000 iterations) MatMul [2, 2]: TFE_Py_Execute took 2.82us (30000 iterations) MatMul [2, 2]: defun(tf.matmul) took 45.70us (30000 iterations) MatMul [100, 784]: np.dot took 383.32us (1000 iterations) MatMul [100, 784]: tf.matmul took 350.35us (1000 iterations) MatMul [100, 784]: gen_math_ops.mat_mul took 315.97us (1000 iterations) MatMul [100, 784]: TFE_Py_Execute took 249.42us (1000 iterations) MatMul [100, 784]: defun(tf.matmul) took 280.95us (1000 iterations) If backprop is enabled: np.array([[3]]) took 0.83us (30000 iterations) Tensor([[3]]) took 15.21us (30000 iterations) MatMul [2, 2]: np.dot took 0.63us (30000 iterations) MatMul [2, 2]: tf.matmul took 76.31us (30000 iterations) MatMul [2, 2]: gen_math_ops.mat_mul took 38.66us (30000 iterations) MatMul [2, 2]: TFE_Py_Execute took 2.31us (30000 iterations) MatMul [2, 2]: defun(tf.matmul) took 51.96us (30000 iterations) MatMul [100, 784]: np.dot took 378.34us (1000 iterations) MatMul [100, 784]: tf.matmul took 352.09us (1000 iterations) MatMul [100, 784]: gen_math_ops.mat_mul took 364.28us (1000 iterations) MatMul [100, 784]: TFE_Py_Execute took 350.68us (1000 iterations) MatMul [100, 784]: defun(tf.matmul) took 377.19us (1000 iterations) After: np.array([[3]]) took 0.86us (30000 iterations) Tensor([[3]]) took 15.19us (30000 iterations) MatMul [2, 2]: np.dot took 0.60us (30000 iterations) MatMul [2, 2]: tf.matmul took 64.51us (30000 iterations) MatMul [2, 2]: gen_math_ops.mat_mul took 28.34us (30000 iterations) MatMul [2, 2]: TFE_Py_Execute took 2.38us (30000 iterations) MatMul [2, 2]: defun(tf.matmul) took 48.50us (30000 iterations) MatMul [100, 784]: np.dot took 475.27us (1000 iterations) MatMul [100, 784]: tf.matmul took 399.50us (1000 iterations) MatMul [100, 784]: gen_math_ops.mat_mul took 307.80us (1000 iterations) MatMul [100, 784]: TFE_Py_Execute took 272.83us (1000 iterations) MatMul [100, 784]: defun(tf.matmul) took 350.06us (1000 iterations) PiperOrigin-RevId: 165765641 --- Commit d902babbd authored by David Majnemer<majnemer@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Algebraic simplifier incorrectly transformed convolutions into bitcasts PiperOrigin-RevId: 165765575 --- Commit 8e78e10ef authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: disable test temporarily PiperOrigin-RevId: 165763204 --- Commit a271c37db authored by Benoit Steiner<bsteiner@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Small improvements to the arithmetic optimizer PiperOrigin-RevId: 165760972 --- Commit b6409594d authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Convert some tests to cover both eager and graph. PiperOrigin-RevId: 165760364 --- Commit 5ead76420 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Reduce XLA compile time by ~7% for a convolutional image model: * Added CompactPointerSet<T>, which is optimized for set size <= 1. * Changed expensive CHECKs to DCHECKS in buffer_assignment.cc * Reserve space in DFS state array before starting DFS. * Use unsigned arithmetic in DFS state maintenance. * HloInstruction: - Moved frequently used fields to start for better cache locality. - Use InlinedVector instead of vector for operand array. - Use InlinedVector instead of vector for DFS stack. * Pre-compute "is array" and "is tuple" for LogicalBuffer. * PointsToSet: - Combine two ShapeTrees into one. - Use CompactPointerSet instead of std::set to hold sources. - Use CompactPointerSet instead of std::set to hold flattened buffers. * ShapeTree: use unique_ptr instead of optional for shape storage (reduces size and destruction overhead). * Add proper const qualifiers to some FlatSet iterator methods. Co-author=jeff PiperOrigin-RevId: 165759117 --- Commit a0544b0b8 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make TPU symbols more easily accessible from contrib. PiperOrigin-RevId: 165753322 --- Commit cdc08afbb authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Slightly relax numeric tolerance for sinlge precision tests of matrix_solve_ls (and tighten it for double precision). PiperOrigin-RevId: 165750936 --- Commit eebcc861a authored by Jianwei Xie<xiejw@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fixed the race condition between multi eval step increments. PiperOrigin-RevId: 165750595 --- Commit bbc0b8471 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 165748384 --- Commit 65f87c967 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Change device string in RecvNodeDescriptor in VirtualScheduler from const reference to const as the RecvNodeDescriptor (and cached_recv_nodes map) outlives device string from the NodeDef. PiperOrigin-RevId: 165748244 --- Commit 57b0276cf authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Update ops-related pbtxt files. PiperOrigin-RevId: 165747467 --- Commit 64e54423b authored by Derek Murray<mrry@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [tf.contrib.data] Fix nested dictionary handling in dataset elements. Backports recent changes to the core version of the nest.py library. Fixes #12372. PiperOrigin-RevId: 165746517 --- Commit 378463ae8 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Make tf.eye accept Python integer shapes and avoid generating unnecessary shape handling ops. Clean up test and add tests with placeholders. PiperOrigin-RevId: 165746090 --- Commit 109ecf823 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add support for complex in matrix_solve_ls_op. Split into separate files for each data type to speed up build. PiperOrigin-RevId: 165744539 --- Commit 51441302d authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Internal change. PiperOrigin-RevId: 165737455 --- Commit d0cb32c2a authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Docstring for ResourceVariable. PiperOrigin-RevId: 165735441 --- Commit 32f4c5b6e authored by Chris Leary<leary@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Add IsFinite op in tf2xla. PiperOrigin-RevId: 165734702 --- Commit 5f5c3eb0a authored by Mark Daoust<markdaoust@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Move "supervisor.md" from programmer's guide to api_guides. PiperOrigin-RevId: 165732026 --- Commit d001b58de authored by Derek Murray<mrry@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [tf.contrib.data] Fix handling of multi-output tf.py_func() in Dataset.map(). If the `map_func` returns a list of tensors, the current code will attempt to stack it into a single tensor and raise an unintuitive error. Some multi-output ops (such as `tf.py_func()`) return lists of typically-not-stackable tensors. This change treats lists returned from `map_func` as tuples; users who were relying on this auto-stacking behavior should manually call `tf.stack()` (or `tf.convert_to_tensor()`) on the list being returned. Fixes #12396. PiperOrigin-RevId: 165731970 --- Commit e6c60fb36 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix flakyness, sometimes the op takes ms to run. PiperOrigin-RevId: 165728705 --- Commit 360bff8ae authored by Ali Yahya<alive@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Makes tape.watch() work with ResourceVariables. To this end, also adds a property, `device`, to TensorNode. PiperOrigin-RevId: 165726368 --- Commit 80bd004cd authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Implements SVDF model for keyword spotting tutorial. PiperOrigin-RevId: 165725938 --- Commit aaabf6b90 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix bug: Using a ComputationDataHandle from the wrong ComputationBuilder. PiperOrigin-RevId: 165724017 --- Commit 107d165d9 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Use 2-arg TraceMe constructor to prevent unnecessary StrCat computation when tracing is disabled. PiperOrigin-RevId: 165722280 --- Commit 7d01f89cc authored by Pete Warden<petewarden@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Android demo app for speech recognition PiperOrigin-RevId: 165714459 --- Commit a6729325a authored by Alexandre Passos<apassos@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Deletes convert_n_to_eager_tensor. Moves convert_to_eager_tensor to constant_op. PiperOrigin-RevId: 165704074 --- Commit 573b303ac authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BUILD cleanup in tensorflow/core/kernels PiperOrigin-RevId: 165688864 --- Commit 711be6adc authored by Derek Murray<mrry@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: `Dataset.from_generator()` constructs a dataset from a Python generator. With this change, it becomes possible to use a Python generator as the source dataset for a `tf.contrib.data` input pipeline. This enables easier integration with non-TensorFlow data sources. The generator can yield a nested structure of NumPy arrays, or values convertible to NumPy arrays. This addresses a concern raised in issue #7951. PiperOrigin-RevId: 165663857 --- Commit 00594ecdd authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: New landing page and leftnav for Programmer's Guide. PiperOrigin-RevId: 165660897 --- Commit 7359fec79 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Implement Batchnorm Inference by expanding them into smaller ops. 1. Add batch norm inference support in batchnorm_rewriter 2. Connect xla's batchnorm inference to tf's FusedBatchNorm RELNOTES: n/a PiperOrigin-RevId: 165655351 --- Commit f0da8bf56 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [Rematerialization] Reconsider to remat operations with control dependencies We added a conservartive logic to not rematerialize operations with control dependencies since the rematerialized operations could result in undesired ordering. However, we now realize that when we remat an operation, we also copy the dependencies of them, which guarantees the rematerialized operation has the same constraint as the original operation. PiperOrigin-RevId: 165654629 --- Commit a1225879c authored by Chris Leary<leary@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: [XLA] Propagate error code in computation replay tool. PiperOrigin-RevId: 165654497 --- Commit 513def0bb authored by Benoit Steiner<bsteiner@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fixed BuildOpInfoWithoutDevice PiperOrigin-RevId: 165653933 --- Commit d7e425f0b authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix linear algebra benchmarks. PiperOrigin-RevId: 165653891 --- Commit 465c40819 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fix the shape information propagation for Enter op. PiperOrigin-RevId: 165653579 --- Commit c0198fd8d authored by Derek Murray<derek.murray@gmail.com> Committed by gunan<gunan@google.com>: [CMake] Add missing dependencies on boosted_trees protos and other fixes (#12315) * [CMake] Add missing dependencies * Avoid rebuilding boosted_trees protos for Python. * Add GPU implementation ZeroInitializerOp to the CMake build. --- Commit 641943fd7 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Update ops-related pbtxt files. PiperOrigin-RevId: 165652758 --- Commit e31346452 authored by Jonathan Hseu<jhseu@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: TPUEstimator: Fix the outfeed thread join. PiperOrigin-RevId: 165651781 --- Commit 565a9d350 authored by Vijay Vasudevan<vrv@google.com> Committed by Andrew Harp<andrewharp@users.noreply.github.com>: Add missing 'type' keyword to ArgumentParser add_argument (#12275) Fixes #12210 --- Commit 19a55725a authored by Rohan Jain<rohanj@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Allowing functions to run across devices. This change expands the ProcessFunctionLibraryRuntime library to Instantiate and Run functions on different devices. When a FunctionLibraryRuntime encounters a function with a target that is another device, it delegates Instantiate() and Run() calls to the ProcessFunctionLibraryRuntime. This change also moves the table_ containing all function instantiations to the PFLR instead of the FunctionLibraryRuntime. PiperOrigin-RevId: 165651194 --- Commit 8c0853db7 authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add a test for negative and zero pow() input. PiperOrigin-RevId: 165650096 --- Commit a3c4e980e authored by Pete Warden<petewarden@google.com> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Fixed input shape for freezing audio graphs PiperOrigin-RevId: 165649546 --- Commit 9b9e5989d authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: Add a call_logit_fn utility for logit_fn's, similar to Estimator's _call_model_fn. PiperOrigin-RevId: 165649388 --- Commit 4ff1f4442 authored by Amit Patankar<amitpatankar@google.com> Committed by Amit Patankar<amitpatankar@google.com>: Remove the script as well if building tf_nightly. --- Commit 373d78987 authored by Amit Patankar<amitpatankar@google.com> Committed by Amit Patankar<amitpatankar@google.com>: Adding the break. --- Commit 0139ac983 authored by Amit Patankar<amitpatankar@google.com> Committed by Amit Patankar<amitpatankar@google.com>: Remove tensorboard as a required package if we are building tf_nightly. --- Commit a92bd5d5c authored by A. Unique TensorFlower<gardener@tensorflow.org> Committed by TensorFlower Gardener<gardener@tensorflow.org>: BEGIN_PUBLIC Automated g4 rollback of changelist 165630063 PiperOrigin-RevId: 165957821
Diffstat (limited to 'tensorflow/contrib/reduce_slice_ops')
-rw-r--r--tensorflow/contrib/reduce_slice_ops/BUILD110
-rw-r--r--tensorflow/contrib/reduce_slice_ops/__init__.py26
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc239
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h84
-rw-r--r--tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc100
-rw-r--r--tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc282
-rw-r--r--tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc41
-rw-r--r--tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py158
-rw-r--r--tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py30
9 files changed, 1070 insertions, 0 deletions
diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD
new file mode 100644
index 0000000000..5340a51e00
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/BUILD
@@ -0,0 +1,110 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs")
+load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_kernel_tests_linkstatic")
+
+tf_custom_op_library(
+ name = "python/ops/_reduce_slice_ops.so",
+ srcs = [
+ "kernels/reduce_slice_ops.cc",
+ "kernels/reduce_slice_ops.h",
+ "ops/reduce_slice_ops.cc",
+ ],
+ gpu_srcs = [
+ "kernels/reduce_slice_ops.h",
+ "kernels/reduce_slice_ops_gpu.cu.cc",
+ ],
+)
+
+tf_kernel_library(
+ name = "reduce_slice_ops_kernels",
+ srcs = [
+ "kernels/reduce_slice_ops.cc",
+ ],
+ hdrs = [
+ "kernels/reduce_slice_ops.h",
+ ],
+ gpu_srcs = [
+ "kernels/reduce_slice_ops.h",
+ "kernels/reduce_slice_ops_gpu.cu.cc",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["reduce_slice_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "reduce_slice_ops",
+ deps = [":reduce_slice_ops_op_lib"],
+)
+
+tf_custom_op_py_library(
+ name = "reduce_slice_ops_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/reduce_slice_ops.py",
+ ],
+ dso = [
+ ":python/ops/_reduce_slice_ops.so",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":reduce_slice_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:framework",
+ ],
+)
+
+cuda_py_test(
+ name = "reduce_slice_ops_test",
+ size = "small",
+ srcs = ["python/kernel_tests/reduce_slice_ops_test.py"],
+ additional_deps = [
+ ":reduce_slice_ops_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_cc_test(
+ name = "reduce_slice_ops_test_cc",
+ size = "small",
+ srcs = [
+ "ops/reduce_slice_ops_test.cc",
+ ],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":reduce_slice_ops_op_lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/reduce_slice_ops/__init__.py b/tensorflow/contrib/reduce_slice_ops/__init__.py
new file mode 100644
index 0000000000..d0364587b5
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/__init__.py
@@ -0,0 +1,26 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""reduce by slice
+
+@@reduce_slice_sum
+@@reduce_slice_prod
+@@reduce_slice_min
+@@reduce_slice_max
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.reduce_slice_ops.python.ops import *
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
new file mode 100644
index 0000000000..2def4f3f17
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.cc
@@ -0,0 +1,239 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
+#include <algorithm>
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+using CPUDevice = Eigen::ThreadPoolDevice;
+using thread::ThreadPool;
+
+namespace functor {
+
+#define CPUReduceSliceFunctorReduceop(reduceop, beginning) \
+ template <typename T, typename Index> \
+ struct ReduceSliceFunctor##reduceop<CPUDevice, T, Index> { \
+ private: \
+ struct XYZ { \
+ Index x, y, z; \
+ XYZ() = default; \
+ XYZ(Index x, Index y, Index z) : x(x), y(y), z(z) {} \
+ }; \
+ inline static XYZ global_index_to_xyz(Index global, XYZ size) { \
+ XYZ ret; \
+ ret.x = global / (size.y * size.z); \
+ ret.y = global % (size.y * size.z) / size.z; \
+ ret.z = global % size.z; \
+ return ret; \
+ } \
+ \
+ public: \
+ virtual ~ReduceSliceFunctor##reduceop() {} \
+ virtual void operator()(OpKernelContext* ctx, const CPUDevice& d, \
+ Index indices_width, \
+ typename TTypes<Index, 1>::ConstTensor indices, \
+ typename TTypes<T, 3>::ConstTensor data, \
+ typename TTypes<T, 3>::Tensor output) { \
+ Index bound = data.dimension(1); \
+ Index dim1 = output.dimension(0); \
+ Index dim2 = output.dimension(1); \
+ Index dim3 = output.dimension(2); \
+ Index size = dim1 * dim2 * dim3; \
+ if (size == 0) { \
+ return; \
+ } \
+ T zero = beginning<T>(); \
+ ThreadPool* thread_pool = \
+ ctx->device()->tensorflow_cpu_worker_threads()->workers; \
+ /* shard the work */ \
+ auto work = [&](Index start, Index end) { \
+ for (Index global = start; global < end; ++global) { \
+ XYZ xyz = global_index_to_xyz(global, XYZ(dim1, dim2, dim3)); \
+ Index x = xyz.x; \
+ Index y = xyz.y; \
+ Index z = xyz.z; \
+ output(x, y, z) = zero; \
+ Index slice_head = indices(y * indices_width); \
+ Index slice_end = std::min(indices(y * indices_width + 1), bound); \
+ for (Index i = slice_head; i < slice_end; ++i) { \
+ output(x, y, z) = reduceop(output(x, y, z), data(x, i, z)); \
+ } \
+ } \
+ }; \
+ /* Here assumes the number of average CPU cycles for each slice equals \
+ * the average length of each slice */ \
+ thread_pool->ParallelFor(size, std::max(bound / dim2, (Index)1), work); \
+ } \
+ };
+
+CALL_ALL_REDUCEOPS(CPUReduceSliceFunctorReduceop)
+#undef CPUReduceSliceFunctorReduceop
+
+#define DEFINE_CPU_SUMPROD_SPECS_INDEX(T, Index) \
+ template struct ReduceSliceFunctorSum<CPUDevice, T, Index>; \
+ template struct ReduceSliceFunctorProd<CPUDevice, T, Index>;
+
+#define DEFINE_CPU_MINMAX_SPECS_INDEX(T, Index) \
+ template struct ReduceSliceFunctorMax<CPUDevice, T, Index>; \
+ template struct ReduceSliceFunctorMin<CPUDevice, T, Index>;
+
+#define DEFINE_CPU_SUMPROD_SPECS(T) \
+ DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int32); \
+ DEFINE_CPU_SUMPROD_SPECS_INDEX(T, int64);
+
+#define DEFINE_CPU_MINMAX_SPECS(T) \
+ DEFINE_CPU_MINMAX_SPECS_INDEX(T, int32); \
+ DEFINE_CPU_MINMAX_SPECS_INDEX(T, int64);
+
+TF_CALL_NUMBER_TYPES(DEFINE_CPU_SUMPROD_SPECS)
+TF_CALL_REAL_NUMBER_TYPES(DEFINE_CPU_MINMAX_SPECS)
+
+#undef DEFINE_CPU_SUMPROD_SPECS_INDEX
+#undef DEFINE_CPU_MINMAX_SPECS_INDEX
+#undef DEFINE_CPU_SUMPROD_SPECS
+#undef DEFINE_CPU_MINMAX_SPECS
+
+} // namespace functor
+
+template <typename Device, typename T, typename Index,
+ template <typename Device2, typename T2, typename Index2>
+ class Functor>
+class ReduceSliceKernel : public OpKernel {
+ public:
+ explicit ReduceSliceKernel(OpKernelConstruction* context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& data = context->input(0);
+ const Tensor& indices = context->input(1);
+ const Tensor& _axis = context->input(2);
+ int64 axis = _axis.scalar<int64>()();
+
+ int indices_width = 2;
+ int out_axis_dim_size = indices.shape().dim_size(0);
+ if (indices.dims() == 1 || indices.shape().dim_size(1) == 1) {
+ indices_width = 1;
+ if (out_axis_dim_size > 0) {
+ out_axis_dim_size--;
+ }
+ }
+
+ TensorShape output_shape = data.shape();
+ output_shape.set_dim(axis, out_axis_dim_size);
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
+ auto functor = Functor<Device, T, Index>();
+ functor(context, context->eigen_device<Device>(), indices_width,
+ indices.flat<Index>(), data.flat_inner_outer_dims<T, 3>(axis - 1),
+ output->flat_inner_outer_dims<T, 3>(axis - 1));
+ }
+};
+
+#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorSum>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorProd>);
+
+#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMax>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<CPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMin>);
+
+#define REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL(type) \
+ REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int32); \
+ REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS(type, int64);
+
+#define REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL(type) \
+ REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int32); \
+ REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL)
+TF_CALL_NUMBER_TYPES(REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL)
+
+#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS
+#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS
+#undef REGISTER_CPU_SUMPROD_REDUCE_SLICE_KERNELS_ALL
+#undef REGISTER_CPU_MINMAX_REDUCE_SLICE_KERNELS_ALL
+
+#if GOOGLE_CUDA
+
+#define REGISTER_GPU_REDUCE_SLICE_KERNELS(type, index_type) \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceSum") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorSum>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceProd") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorProd>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMax") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMax>); \
+ REGISTER_KERNEL_BUILDER(Name("ReduceSliceMin") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("axis") \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices"), \
+ ReduceSliceKernel<GPUDevice, type, index_type, \
+ functor::ReduceSliceFunctorMin>);
+
+#define REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL(type) \
+ REGISTER_GPU_REDUCE_SLICE_KERNELS(type, int32); \
+ REGISTER_GPU_REDUCE_SLICE_KERNELS(type, int64);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL);
+
+#undef REGISTER_GPU_REDUCE_SLICE_KERNELS
+#undef REGISTER_GPU_REDUCE_SLICE_KERNELS_ALL
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
new file mode 100644
index 0000000000..c62a7b20d6
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+
+#define Sum(a, b) ((a) + (b))
+#define Prod(a, b) ((a) * (b))
+#define Max(a, b) ((a) > (b) ? (a) : (b))
+#define Min(a, b) ((a) < (b) ? (a) : (b))
+
+namespace tensorflow {
+
+class OpKernelContext;
+
+namespace functor {
+
+namespace reduce_functions {
+
+template <typename T>
+inline T zero() {
+ return T(0);
+}
+
+template <typename T>
+inline T one() {
+ return T(1);
+}
+
+template <typename T>
+inline T infinity() {
+ return std::max<T>(std::numeric_limits<T>::max(),
+ std::numeric_limits<T>::infinity());
+}
+
+template <typename T>
+inline T negative_infinity() {
+ return std::min<T>(-std::numeric_limits<T>::infinity(),
+ std::numeric_limits<T>::min());
+}
+
+} // namespace reduce_functions
+
+#define CALL_ALL_REDUCEOPS(func, ...) \
+ func(Sum, functor::reduce_functions::zero, ##__VA_ARGS__) \
+ func(Prod, functor::reduce_functions::one, ##__VA_ARGS__) func( \
+ Max, functor::reduce_functions::negative_infinity, ##__VA_ARGS__) \
+ func(Min, functor::reduce_functions::infinity, ##__VA_ARGS__)
+
+#define ReduceSliceFunctorReduceop(reduceop, dummy) \
+ template <typename Device, typename T, typename Index> \
+ struct ReduceSliceFunctor##reduceop { \
+ virtual ~ReduceSliceFunctor##reduceop() {} \
+ virtual void operator()(OpKernelContext* ctx, const Device& d, \
+ Index indices_width, \
+ typename TTypes<Index, 1>::ConstTensor indices, \
+ typename TTypes<T, 3>::ConstTensor data, \
+ typename TTypes<T, 3>::Tensor output); \
+ };
+
+CALL_ALL_REDUCEOPS(ReduceSliceFunctorReduceop)
+#undef ReduceSliceFunctorReduceop
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_PARTIAL_REDUCTION_OPS_H_
diff --git a/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
new file mode 100644
index 0000000000..8b205f7dd5
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops_gpu.cu.cc
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/contrib/reduce_slice_ops/kernels/reduce_slice_ops.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+
+using GPUDevice = Eigen::GpuDevice;
+
+namespace functor {
+
+#define GPUReduceSliceFunctorReduceop(reduceop, beginning) \
+ template <typename T, typename Index> \
+ __global__ void ReduceSliceDeviceKernel##reduceop( \
+ Cuda3DLaunchConfig config, Index indices_width, Index bound, \
+ const T begin, const Index *indices, const T *input, T *out) { \
+ CUDA_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) { \
+ CUDA_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) { \
+ CUDA_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) { \
+ Index outidx = x * config.virtual_thread_count.y * \
+ config.virtual_thread_count.z + \
+ y * config.virtual_thread_count.z + z; \
+ out[outidx] = begin; \
+ Index start = indices[y * indices_width]; \
+ Index end = Min(bound, indices[y * indices_width + 1]); \
+ for (Index yin = start; yin < end; yin++) { \
+ Index inidx = x * bound * config.virtual_thread_count.z + \
+ yin * config.virtual_thread_count.z + z; \
+ out[outidx] = reduceop(out[outidx], input[inidx]); \
+ } \
+ } \
+ } \
+ } \
+ } \
+ \
+ template <typename T, typename Index> \
+ struct ReduceSliceFunctor##reduceop<GPUDevice, T, Index> { \
+ virtual ~ReduceSliceFunctor##reduceop() {} \
+ virtual void operator()(OpKernelContext *ctx, const GPUDevice &d, \
+ Index indices_width, \
+ typename TTypes<Index, 1>::ConstTensor indices, \
+ typename TTypes<T, 3>::ConstTensor data, \
+ typename TTypes<T, 3>::Tensor output) { \
+ Index bound = data.dimension(1); \
+ int sizex = output.dimension(0); \
+ int sizey = output.dimension(1); \
+ int sizez = output.dimension(2); \
+ if (sizex * sizey * sizez == 0) { \
+ return; \
+ } \
+ Cuda3DLaunchConfig config = GetCuda3DLaunchConfig( \
+ sizex, sizey, sizez, d, ReduceSliceDeviceKernel##reduceop<T, Index>, \
+ 0, 0); \
+ \
+ ReduceSliceDeviceKernel##reduceop<T, Index> \
+ <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( \
+ config, indices_width, bound, beginning<T>(), indices.data(), \
+ data.data(), output.data()); \
+ } \
+ };
+
+CALL_ALL_REDUCEOPS(GPUReduceSliceFunctorReduceop)
+#undef GPUReduceSliceFunctorReduceop
+
+#define DEFINE_GPU_REDUCEOP_SPECS_INDEX(reduceop, dummy, T) \
+ template struct ReduceSliceFunctor##reduceop<GPUDevice, T, int32>; \
+ template struct ReduceSliceFunctor##reduceop<GPUDevice, T, int64>;
+
+#define DEFINE_GPU_SPECS(T) \
+ CALL_ALL_REDUCEOPS(DEFINE_GPU_REDUCEOP_SPECS_INDEX, T)
+
+TF_CALL_REAL_NUMBER_TYPES(DEFINE_GPU_SPECS)
+
+#undef DEFINE_GPU_REDUCEOP_SPECS_INDEX
+#undef DEFINE_GPU_SPECS
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif
diff --git a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc
new file mode 100644
index 0000000000..b8b56c0e22
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops.cc
@@ -0,0 +1,282 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
+namespace {
+
+Status ReduceSliceShapeFn(InferenceContext* c) {
+ ShapeHandle handle;
+ DimensionHandle dimhandle;
+ DimensionHandle dim_axis = c->UnknownDim();
+ // "axis" must be a scala
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &handle));
+ // "data" must have rank at least 1
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &handle));
+ // "indices" must have have rank 1 or rank 2 with the number of columns must
+ // be 2
+ if (c->RankKnown(c->input(1))) {
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &handle));
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 2, &handle));
+ if (c->Rank(c->input(1)) == 1) {
+ // if "indices" is a vector of 0 elements, then the axis dimension of
+ // output tensor should be of dimension 0.
+ DimensionHandle raw_dim_axis;
+ TF_RETURN_IF_ERROR(c->Max(c->Dim(c->input(1), 0), 1, &raw_dim_axis));
+ TF_RETURN_IF_ERROR(c->Subtract(raw_dim_axis, 1, &dim_axis));
+ } else { // c->Rank(c->input(1)) == 2
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(c->input(1), 1), c->MakeDim(2), &dimhandle));
+ dim_axis = c->Dim(c->input(1), 0);
+ }
+ }
+ // shape of output tensor
+ const Tensor* _axis = c->input_tensor(2);
+ if (nullptr == _axis) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(0))));
+ } else {
+ int64 axis = _axis->scalar<int64>()();
+ TF_RETURN_IF_ERROR(c->ReplaceDim(handle, axis, dim_axis, &handle));
+ c->set_output(0, handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ReduceSliceSum")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically sum over the first dimension of a tensor according to start and end
+indices specified at 'index'.
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 2, 3]
+ [ 40, 50, 60]
+ [ 700, 800, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 2, 3]
+ [ 0, 0, 0]
+ [41,52,63]].
+```
+
+The data must be at least rank 1. The indices must be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be zero. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound.
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed sum values.
+)doc");
+
+REGISTER_OP("ReduceSliceProd")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically compute the product over the first dimension of a tensor according
+to start and end indices specified at 'indices'.
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 2, 3]
+ [ 40, 50, 60]
+ [ 700, 800, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 2, 3]
+ [ 1, 1, 1]
+ [40,100,180]].
+```
+
+The data must be at least rank 1. The indices can be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be 1. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound. The indices can also be of shape (?), in this case, the
+start index of i will be the element at i, then end index of i will be the element
+at i+1. That is:
+
+```prettyprint
+indices = [0,5,11,115]
+
+is equivalent to
+
+indices = [ [0,5],
+ [5,11],
+ [11,115]]
+```
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed product values.
+)doc");
+
+REGISTER_OP("ReduceSliceMax")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically compute the maximum over the first dimension of a tensor according
+to start and end indices specified at "indices".
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 20, 3]
+ [ 400, 5, 60]
+ [ 70, 8, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 20, 3]
+ [ -BIG_VALUE, -BIG_VALUE, -BIG_VALUE]
+ [ 400, 20, 60]].
+```
+
+The data must be at least rank 1. The indices can be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be 1. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound. The indices can also be of shape (?), in this case, the
+start index of i will be the element at i, then end index of i will be the element
+at i+1. That is:
+
+```prettyprint
+indices = [0,5,11,115]
+
+is equivalent to
+
+indices = [ [0,5],
+ [5,11],
+ [11,115]]
+```
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed product values.
+)doc");
+
+REGISTER_OP("ReduceSliceMin")
+ .Input("data: T")
+ .Input("indices: Tindices")
+ .Input("axis: int64")
+ .Output("output: T")
+ .Attr("T: numbertype")
+ .Attr("Tindices: {int32,int64}")
+ .SetShapeFn(ReduceSliceShapeFn)
+ .Doc(R"doc(
+Dynamically compute the minimum over the first dimension of a tensor according
+to start and end indices specified at 'indices'.
+
+For example:
+
+```prettyprint
+# if 'data' is [[ 1, 20, 3]
+ [ 400, 5, 60]
+ [ 70, 8, 900]
+ [1000,2000,3000]],
+
+and 'indices' is [[0,1]
+ [1,1]
+ [0,2]],
+
+the the output will be [[ 1, 20, 3]
+ [ +BIG_VALUE, +BIG_VALUE, +BIG_VALUE]
+ [ 1, 5, 3]].
+```
+
+The data must be at least rank 1. The indices can be of shape (?,2) where the
+first column is start indices and the second column is end indices. The end indices
+are not included in the reduce operation, which means, if you want to do a reduce
+over indices 0,1,2, then you should have start index 0 and end index 3. If end
+index is smaller than or equal to start, the result will be 1. If end index is
+out of bounds, then the reduce operation will automatically stop at the bound, so
+feel free to put a large number as your end of your index if you want to do the
+reduction until the bound. The indices can also be of shape (?), in this case, the
+start index of i will be the element at i, then end index of i will be the element
+at i+1. That is:
+
+```prettyprint
+indices = [0,5,11,115]
+
+is equivalent to
+
+indices = [ [0,5],
+ [5,11],
+ [11,115]]
+```
+
+data: The source of data where the computation will be taken from.
+indices: start, end indices that controls which part to be included.
+T: the type of data.
+Tindices: the type of indices, must be int32 or int64.
+output: the computed product values.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc
new file mode 100644
index 0000000000..777ad9bf15
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/ops/reduce_slice_ops_test.cc
@@ -0,0 +1,41 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/shape_inference_testutil.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+TEST(ReduceSliceOpsTest, ReduceSliceSum_ShapeFn) {
+ ShapeInferenceTestOp op("ReduceSliceSum");
+ INFER_OK(op, "?;?;?", "?");
+ INFER_OK(op, "[10,20];[100,2];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[?,2];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[0];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[1];[]", "[?,?]");
+ INFER_OK(op, "[10,20];[?];[]", "[?,?]");
+ INFER_OK(op, "[?,?];[?,2];[]", "[?,?]");
+ INFER_OK(op, "[?,?];[25,2];[]", "[?,?]");
+ INFER_OK(op, "[?];[123,2];[]", "[?]");
+ INFER_OK(op, "[1,2,3,4];[100,2];[]", "[?,?,?,?]");
+
+ INFER_ERROR("must be rank 0", op, "?;[?,2];[?]");
+ INFER_ERROR("must be at least rank 1", op, "?;[];[]");
+ INFER_ERROR("must be at most rank 2", op, "?;[1,2,3];[]");
+ INFER_ERROR("must be equal, but are 1 and 2", op, "?;[?,1];[]");
+ INFER_ERROR("must be at least rank 1", op, "[];?;[]");
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
new file mode 100644
index 0000000000..8c8db295ff
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/python/kernel_tests/reduce_slice_ops_test.py
@@ -0,0 +1,158 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.contrib.reduce_slice_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import unittest
+
+from tensorflow.contrib.reduce_slice_ops.python.ops import reduce_slice_ops
+from tensorflow.python.framework.test_util import TensorFlowTestCase
+from tensorflow.python.platform import googletest
+
+
+class ReduceSliceTest(TensorFlowTestCase):
+
+ def testReduceSliceSum1D(self):
+ x = np.array([1, 40, 700], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array([1, 741, 40, 740, 41], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSum2D(self):
+ x = np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [741, 852, 963], [40, 50, 60], [740, 850, 960],
+ [41, 52, 63]],
+ dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSum3D(self):
+ x = np.array(
+ [[[1, 2], [3, 4]], [[50, 60], [70, 80]], [[600, 700], [800, 900]]],
+ dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[[1, 2], [3, 4]], [[651, 762], [873, 984]], [[50, 60], [70, 80]],
+ [[650, 760], [870, 980]], [[51, 62], [73, 84]]],
+ dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSumAxis1(self):
+ x = np.transpose(
+ np.array([[1, 2, 3], [40, 50, 60], [700, 800, 900]], dtype=np.int32))
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.transpose(
+ np.array(
+ [[1, 2, 3], [741, 852, 963], [40, 50, 60], [740, 850, 960],
+ [41, 52, 63]],
+ dtype=np.int32))
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 1).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceSum1DIndices(self):
+ x = np.array(
+ [[1, 2, 3], [40, 50, 60], [700, 800, 900], [1000, 2000, 3000],
+ [40000, 50000, 60000]],
+ dtype=np.int32)
+ indices = np.array([0, 0, 2, 5], dtype=np.int32)
+ result = np.array(
+ [[0, 0, 0], [41, 52, 63], [41700, 52800, 63900]], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceProd(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [28, 80, 162], [4, 5, 6], [28, 40, 54], [4, 10, 18]],
+ dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_prod(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceMax(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [7, 8, 9], [4, 5, 6], [7, 8, 9], [4, 5, 6]], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_max(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceMin(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.array(
+ [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6], [1, 2, 3]], dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_min(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmptyDataRows(self):
+ x = np.empty((0, 1, 2, 3, 4, 5, 6), dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.zeros((5, 1, 2, 3, 4, 5, 6), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmptyDataCols(self):
+ x = np.empty((100, 0, 2, 3, 4, 5, 6), dtype=np.int32)
+ indices = np.array([[0, 1], [0, 3], [1, 2], [1, 3], [0, 2]], dtype=np.int32)
+ result = np.empty((5, 0, 2, 3, 4, 5, 6), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmptyIndicesRows(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.empty((0, 2), dtype=np.int32)
+ result = np.empty((0, 3), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmpty0Indices1D(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.empty((0,), dtype=np.int32)
+ result = np.empty((0, 3), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+ def testReduceSliceEmpty1Indices1D(self):
+ x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32)
+ indices = np.array([0], dtype=np.int32)
+ result = np.empty((0, 3), dtype=np.int32)
+ with self.test_session(use_gpu=True):
+ y_tf = reduce_slice_ops.reduce_slice_sum(x, indices, 0).eval()
+ self.assertAllEqual(y_tf, result)
+
+
+if __name__ == "__main__":
+ googletest.main()
diff --git a/tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py b/tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py
new file mode 100644
index 0000000000..d0f02489bd
--- /dev/null
+++ b/tensorflow/contrib/reduce_slice_ops/python/ops/reduce_slice_ops.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for the reduce slice operators."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.util import loader
+from tensorflow.python.platform import resource_loader
+
+_reduce_slice_ops = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_reduce_slice_ops.so"))
+
+reduce_slice_sum = _reduce_slice_ops.reduce_slice_sum
+reduce_slice_prod = _reduce_slice_ops.reduce_slice_prod
+reduce_slice_max = _reduce_slice_ops.reduce_slice_max
+reduce_slice_min = _reduce_slice_ops.reduce_slice_min