diff options
author | 2016-09-15 15:29:30 -0800 | |
---|---|---|
committer | 2016-09-15 16:33:22 -0700 | |
commit | caceb02f75ff80a8e48440720cec3d7d6fa3297e (patch) | |
tree | 931c6b37cfd5881fe626bb246a73140ba60f3bfc /tensorflow/core/kernels/strided_slice_op_test.cc | |
parent | cb9d147e9c788cc60ebb255fd26971719c7e2db2 (diff) |
Add StridedSlice C++ shape fn and change python to use it.
Change: 133323531
Diffstat (limited to 'tensorflow/core/kernels/strided_slice_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/strided_slice_op_test.cc | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/strided_slice_op_test.cc b/tensorflow/core/kernels/strided_slice_op_test.cc index d4e057b49d..5a123559be 100644 --- a/tensorflow/core/kernels/strided_slice_op_test.cc +++ b/tensorflow/core/kernels/strided_slice_op_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/node_builder.h" @@ -30,6 +31,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/strided_slice_op.h" namespace tensorflow { namespace { @@ -87,5 +89,38 @@ static void BM_SliceBFloat16(int iters, int dim2) { BENCHMARK(BM_SliceBFloat16)->Arg(100)->Arg(1000)->Arg(10000); +static void BM_ValidateStridedSliceOp(int iters) { + testing::StopTiming(); + int kDim = 100; + int kMaxSize = 15000; + int size = 100; + Tensor begin = test::AsTensor<int32>({10, 10}); + Tensor end = test::AsTensor<int32>({10 + kDim, 10 + size}); + Tensor strides = test::AsTensor<int32>({1, 1}); + TensorShape input_shape({2 * kDim, kMaxSize}); + + testing::StartTiming(); + for (int i = 0; i < iters; ++i) { + TensorShape processing_shape, final_shape; + bool is_identity = true, slice_dim0 = true, is_simple_slice = true; + gtl::InlinedVector<int64, 4> begin_out, end_out, strides_out; + const int32 begin_mask = 0; + const int32 end_mask = 0; + const int32 ellipsis_mask = 0; + const int32 new_axis_mask = 0; + const int32 shrink_axis_mask = 0; + + ShapeReadWriteFromTensorShape wrapped_processing_shape(&processing_shape); + ShapeReadWriteFromTensorShape wrapped_final_shape(&final_shape); + TF_CHECK_OK(ValidateStridedSliceOp( + begin, end, strides, ShapeReadWriteFromTensorShape(&input_shape), + begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, + &wrapped_processing_shape, &wrapped_final_shape, &is_identity, + &is_simple_slice, &slice_dim0, &begin_out, &end_out, &strides_out)); + } +} + +BENCHMARK(BM_ValidateStridedSliceOp); + } // namespace } // namespace tensorflow |