aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/transpose_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-27 06:12:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 06:16:42 -0700
commitabf26356209cba1ba895a06d9ce55ad01dad7fc6 (patch)
tree5ef1c907a30bf89d08ba241ef985b19938427420 /tensorflow/contrib/lite/kernels/transpose_test.cc
parent19d8963bc0ea64e10ff08ad4e7cc76813a182196 (diff)
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214763814
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/transpose_test.cc24
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc
index 337bc144b9..79ef0a7c56 100644
--- a/tensorflow/contrib/lite/kernels/transpose_test.cc
+++ b/tensorflow/contrib/lite/kernels/transpose_test.cc
@@ -51,21 +51,21 @@ void RunTestPermutation(const std::vector<int>& shape,
reversed_perms[k] = k;
}
- // Make input and output dims (i.e. reversed shape and dest_shape).
- Dims<4> input_dims = GetTensorDims(shape);
- Dims<4> output_dims;
- for (int i = 0; i < 4; i++) {
- output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]];
+ // Make input and output shapes.
+ const RuntimeShape input_shape = GetTensorShape(shape);
+ RuntimeShape output_shape(perms.size());
+ for (int i = 0; i < perms.size(); i++) {
+ output_shape.SetDim(i, input_shape.Dims(perms[i]));
}
- output_dims.strides[0] = 1;
- for (int k = 1; k < 4; k++) {
- output_dims.strides[k] =
- output_dims.strides[k - 1] * output_dims.sizes[k - 1];
+
+ TransposeParams params;
+ params.perm_count = perms.size();
+ for (int i = 0; i < perms.size(); ++i) {
+ params.perm[i] = perms[i];
}
- reference_ops::Transpose<float>(input.data(), input_dims,
- input_transposed->data(), output_dims,
- reversed_perms);
+ reference_ops::Transpose<float>(params, input_shape, input.data(),
+ output_shape, input_transposed->data());
}
TEST(TransposeTest, TestRefOps1D) {