aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/strided_slice_op_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-15 15:29:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-15 16:33:22 -0700
commitcaceb02f75ff80a8e48440720cec3d7d6fa3297e (patch)
tree931c6b37cfd5881fe626bb246a73140ba60f3bfc /tensorflow/core/kernels/strided_slice_op_test.cc
parentcb9d147e9c788cc60ebb255fd26971719c7e2db2 (diff)
Add StridedSlice C++ shape fn and change python to use it.
Change: 133323531
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_test.cc')
-rw-r--r--tensorflow/core/kernels/strided_slice_op_test.cc35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op_test.cc b/tensorflow/core/kernels/strided_slice_op_test.cc
index d4e057b49d..5a123559be 100644
--- a/tensorflow/core/kernels/strided_slice_op_test.cc
+++ b/tensorflow/core/kernels/strided_slice_op_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -30,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/util/strided_slice_op.h"
namespace tensorflow {
namespace {
@@ -87,5 +89,38 @@ static void BM_SliceBFloat16(int iters, int dim2) {
BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000);
+static void BM_ValidateStridedSliceOp(int iters) {
+ testing::StopTiming();
+ int kDim = 100;
+ int kMaxSize = 15000;
+ int size = 100;
+ Tensor begin = test::AsTensor<int32>({10, 10});
+ Tensor end = test::AsTensor<int32>({10 + kDim, 10 + size});
+ Tensor strides = test::AsTensor<int32>({1, 1});
+ TensorShape input_shape({2 * kDim, kMaxSize});
+
+ testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ TensorShape processing_shape, final_shape;
+ bool is_identity = true, slice_dim0 = true, is_simple_slice = true;
+ gtl::InlinedVector<int64, 4> begin_out, end_out, strides_out;
+ const int32 begin_mask = 0;
+ const int32 end_mask = 0;
+ const int32 ellipsis_mask = 0;
+ const int32 new_axis_mask = 0;
+ const int32 shrink_axis_mask = 0;
+
+ ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape);
+ ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape);
+ TF_CHECK_OK(ValidateStridedSliceOp(
+ begin, end, strides, ShapeReadWriteFromTensorShape(&input_shape),
+ begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
+ &wrapped_processing_shape, &wrapped_final_shape, &is_identity,
+ &is_simple_slice, &slice_dim0, &begin_out, &end_out, &strides_out));
+ }
+}
+
+BENCHMARK(BM_ValidateStridedSliceOp);
+
} // namespace
} // namespace tensorflow