diff options
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op_test.cc | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op_test.cc b/tensorflow/core/kernels/strided_slice_op_test.cc index 281ca0f58f..78bb15463c 100644 --- a/tensorflow/core/kernels/strided_slice_op_test.cc +++ b/tensorflow/core/kernels/strided_slice_op_test.cc @@ -76,20 +76,69 @@ static void SliceHelper(int iters, int size) { testing::UseRealTime(); } +template <typename T> +static void Dim8SliceHelper(int iters, int size) { + testing::StopTiming(); + Graph* g = new Graph(OpRegistry::Global()); + DataType dt = DataTypeToEnum<T>::v(); + int kDim = 100; + int kMaxSize = 15000; + CHECK_LT(size, kMaxSize); + + Tensor begin(DT_INT32, TensorShape({8})); + begin.flat<int32>()(10) = 10; + for (int i = 1; i < 7; ++i) { + begin.flat<int32>()(i) = 0; + } + begin.flat<int32>()(7) = 10; + + Tensor end(DT_INT32, TensorShape({8})); + end.flat<int32>()(0) = 10 + kDim; + for (int i = 1; i < 7; ++i) { + end.flat<int32>()(i) = 1; + } + end.flat<int32>()(7) = 10 + size; + + Tensor strides(DT_INT32, TensorShape({8})); + for (int i = 0; i < 8; ++i) { + strides.flat<int32>()(i) = 1; + } + + Tensor input(dt, TensorShape({2*kDim, 1, 1, 1, 1, 1, 1, kMaxSize})); + input.flat<T>().setRandom(); + + Node* node; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "StridedSlice") + .Input(test::graph::Constant(g, input)) + .Input(test::graph::Constant(g, begin)) + .Input(test::graph::Constant(g, end)) + .Input(test::graph::Constant(g, strides)) + .Attr("T", dt) + .Finalize(g, &node)); + + testing::BytesProcessed(static_cast<int64>(iters) * kDim * size * sizeof(T)); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); + testing::UseRealTime(); +} + static void BM_SliceFloat(int iters, int dim2) { SliceHelper<float>(iters, dim2); + Dim8SliceHelper<float>(iters, dim2); } BENCHMARK(BM_SliceFloat)->Arg(100)->Arg(1000)->Arg(10000); static void BM_SliceComplex64(int iters, int dim2) { SliceHelper<std::complex<float>>(iters, dim2); + Dim8SliceHelper<std::complex<float>>(iters, dim2); } BENCHMARK(BM_SliceComplex64)->Arg(100)->Arg(1000)->Arg(10000); static void BM_SliceBFloat16(int iters, int dim2) { SliceHelper<bfloat16>(iters, dim2); + Dim8SliceHelper<bfloat16>(iters, dim2); } BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000); |