diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-27 06:12:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 06:16:42 -0700 |
commit | abf26356209cba1ba895a06d9ce55ad01dad7fc6 (patch) | |
tree | 5ef1c907a30bf89d08ba241ef985b19938427420 /tensorflow/contrib/lite/kernels/transpose_test.cc | |
parent | 19d8963bc0ea64e10ff08ad4e7cc76813a182196 (diff) |
Update kernel evals to use new kernel signatures.
PiperOrigin-RevId: 214763814
Diffstat (limited to 'tensorflow/contrib/lite/kernels/transpose_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/transpose_test.cc | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc index 337bc144b9..79ef0a7c56 100644 --- a/tensorflow/contrib/lite/kernels/transpose_test.cc +++ b/tensorflow/contrib/lite/kernels/transpose_test.cc @@ -51,21 +51,21 @@ void RunTestPermutation(const std::vector<int>& shape, reversed_perms[k] = k; } - // Make input and output dims (i.e. reversed shape and dest_shape). - Dims<4> input_dims = GetTensorDims(shape); - Dims<4> output_dims; - for (int i = 0; i < 4; i++) { - output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]]; + // Make input and output shapes. + const RuntimeShape input_shape = GetTensorShape(shape); + RuntimeShape output_shape(perms.size()); + for (int i = 0; i < perms.size(); i++) { + output_shape.SetDim(i, input_shape.Dims(perms[i])); } - output_dims.strides[0] = 1; - for (int k = 1; k < 4; k++) { - output_dims.strides[k] = - output_dims.strides[k - 1] * output_dims.sizes[k - 1]; + + TransposeParams params; + params.perm_count = perms.size(); + for (int i = 0; i < perms.size(); ++i) { + params.perm[i] = perms[i]; } - reference_ops::Transpose<float>(input.data(), input_dims, - input_transposed->data(), output_dims, - reversed_perms); + reference_ops::Transpose<float>(params, input_shape, input.data(), + output_shape, input_transposed->data()); } TEST(TransposeTest, TestRefOps1D) { |