aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 15:22:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 15:22:51 -0700
commit3607217a7dc296baea06f190cfaa831f6b1471e6 (patch)
tree4586dc8d65833daa6da880c7232426e2a927769a
parent004d967fdd60e84c9e749a8cbf260145e7363c2f (diff)
parent80bcaabf159a45bcd4ed5ef0e749b1787690dc44 (diff)
Merge pull request #20685 from yongtang:06252018-roll-fast-bound-check
PiperOrigin-RevId: 204199229
-rw-r--r--tensorflow/core/kernels/BUILD15
-rw-r--r--tensorflow/core/kernels/roll_op.cc3
2 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 3426ea8aa2..7599cf7db2 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -2706,17 +2706,16 @@ cc_library(
],
)
-MANIP_DEPS = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:manip_ops_op_lib",
- "//third_party/eigen3",
-]
-
tf_kernel_library(
name = "roll_op",
prefix = "roll_op",
- deps = MANIP_DEPS,
+ deps = [
+ ":bounds_check",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:manip_ops_op_lib",
+ "//third_party/eigen3",
+ ],
)
tf_cc_test(
diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc
index 722116f86f..efa30438d9 100644
--- a/tensorflow/core/kernels/roll_op.cc
+++ b/tensorflow/core/kernels/roll_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/register_types_traits.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"
@@ -258,7 +259,7 @@ class RollOp : public OpKernel {
if (axis < 0) {
axis += num_dims;
}
- OP_REQUIRES(context, 0 <= axis && axis < num_dims,
+ OP_REQUIRES(context, FastBoundsCheck(axis, num_dims),
errors::InvalidArgument("axis ", axis, " is out of range"));
const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1);
const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i));