diff options
author | 2018-07-11 15:22:41 -0700 | |
---|---|---|
committer | 2018-07-11 15:22:51 -0700 | |
commit | 3607217a7dc296baea06f190cfaa831f6b1471e6 (patch) | |
tree | 4586dc8d65833daa6da880c7232426e2a927769a | |
parent | 004d967fdd60e84c9e749a8cbf260145e7363c2f (diff) | |
parent | 80bcaabf159a45bcd4ed5ef0e749b1787690dc44 (diff) |
Merge pull request #20685 from yongtang:06252018-roll-fast-bound-check
PiperOrigin-RevId: 204199229
-rw-r--r-- | tensorflow/core/kernels/BUILD | 15 | ||||
-rw-r--r-- | tensorflow/core/kernels/roll_op.cc | 3 |
2 files changed, 9 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 3426ea8aa2..7599cf7db2 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2706,17 +2706,16 @@ cc_library( ], ) -MANIP_DEPS = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:manip_ops_op_lib", - "//third_party/eigen3", -] - tf_kernel_library( name = "roll_op", prefix = "roll_op", - deps = MANIP_DEPS, + deps = [ + ":bounds_check", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:manip_ops_op_lib", + "//third_party/eigen3", + ], ) tf_cc_test( diff --git a/tensorflow/core/kernels/roll_op.cc b/tensorflow/core/kernels/roll_op.cc index 722116f86f..efa30438d9 100644 --- a/tensorflow/core/kernels/roll_op.cc +++ b/tensorflow/core/kernels/roll_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types_traits.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/work_sharder.h" @@ -258,7 +259,7 @@ class RollOp : public OpKernel { if (axis < 0) { axis += num_dims; } - OP_REQUIRES(context, 0 <= axis && axis < num_dims, + OP_REQUIRES(context, FastBoundsCheck(axis, num_dims), errors::InvalidArgument("axis ", axis, " is out of range")); const int ds = std::max<int>(static_cast<int>(input.dim_size(axis)), 1); const int sum = shift_mod_sum[axis] + static_cast<int>(shift_flat(i)); |