aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/strided_slice_op_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_test.cc')
-rw-r--r--tensorflow/core/kernels/strided_slice_op_test.cc49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op_test.cc b/tensorflow/core/kernels/strided_slice_op_test.cc
index 281ca0f58f..78bb15463c 100644
--- a/tensorflow/core/kernels/strided_slice_op_test.cc
+++ b/tensorflow/core/kernels/strided_slice_op_test.cc
@@ -76,20 +76,69 @@ static void SliceHelper(int iters, int size) {
testing::UseRealTime();
}
+template <typename T>
+static void Dim8SliceHelper(int iters, int size) {
+ testing::StopTiming();
+ Graph* g = new Graph(OpRegistry::Global());
+ DataType dt = DataTypeToEnum<T>::v();
+ int kDim = 100;
+ int kMaxSize = 15000;
+ CHECK_LT(size, kMaxSize);
+
+ Tensor begin(DT_INT32, TensorShape({8}));
+ begin.flat<int32>()(10) = 10;
+ for (int i = 1; i < 7; ++i) {
+ begin.flat<int32>()(i) = 0;
+ }
+ begin.flat<int32>()(7) = 10;
+
+ Tensor end(DT_INT32, TensorShape({8}));
+ end.flat<int32>()(0) = 10 + kDim;
+ for (int i = 1; i < 7; ++i) {
+ end.flat<int32>()(i) = 1;
+ }
+ end.flat<int32>()(7) = 10 + size;
+
+ Tensor strides(DT_INT32, TensorShape({8}));
+ for (int i = 0; i < 8; ++i) {
+ strides.flat<int32>()(i) = 1;
+ }
+
+ Tensor input(dt, TensorShape({2*kDim, 1, 1, 1, 1, 1, 1, kMaxSize}));
+ input.flat<T>().setRandom();
+
+ Node* node;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "StridedSlice")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, begin))
+ .Input(test::graph::Constant(g, end))
+ .Input(test::graph::Constant(g, strides))
+ .Attr("T", dt)
+ .Finalize(g, &node));
+
+ testing::BytesProcessed(static_cast<int64>(iters) * kDim * size * sizeof(T));
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+ testing::UseRealTime();
+}
+
static void BM_SliceFloat(int iters, int dim2) {
SliceHelper<float>(iters, dim2);
+ Dim8SliceHelper<float>(iters, dim2);
}
BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000);
static void BM_SliceComplex64(int iters, int dim2) {
SliceHelper<std::complex<float>>(iters, dim2);
+ Dim8SliceHelper<std::complex<float>>(iters, dim2);
}
BENCHMARK(BM_SliceComplex64)->Arg(100)->Arg(1000)->Arg(10000);
static void BM_SliceBFloat16(int iters, int dim2) {
SliceHelper<bfloat16>(iters, dim2);
+ Dim8SliceHelper<bfloat16>(iters, dim2);
}
BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000);