diff options
author | 2017-08-17 11:44:43 -0700 | |
---|---|---|
committer | 2017-08-17 11:49:01 -0700 | |
commit | 8a898a481828885b68bc41f481f40fb2b83a8bd2 (patch) | |
tree | 8de26eef5adf5452f5c5f8e50b5d0738fb871bd4 /tensorflow/core/kernels/dequantize_op_test.cc | |
parent | 09c93e391caa33b75827bdc7d41ce140c937ec81 (diff) |
Fix dequantize_op benchmark. The requirement is that the ranges come in as
scalars, not single-element vectors, but this was not enforced until shape
inference was enabled recently for this code path.
PiperOrigin-RevId: 165606693
Diffstat (limited to 'tensorflow/core/kernels/dequantize_op_test.cc')
-rw-r--r-- | tensorflow/core/kernels/dequantize_op_test.cc | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc index 8992629d42..b6e9f64aaa 100644 --- a/tensorflow/core/kernels/dequantize_op_test.cc +++ b/tensorflow/core/kernels/dequantize_op_test.cc @@ -77,8 +77,8 @@ class DequantizeOpTest : public OpsTestBase { } TensorShape shape({static_cast<int64>(input.size())}); AddInputFromArray<T>(shape, input); - AddInputFromArray<float>(TensorShape({1}), {min_range}); - AddInputFromArray<float>(TensorShape({1}), {max_range}); + AddInputFromArray<float>(TensorShape({}), {min_range}); + AddInputFromArray<float>(TensorShape({}), {max_range}); TF_ASSERT_OK(RunOpKernel()); Tensor expected(allocator(), DT_FLOAT, shape); ComputeDequantizeMinCombinedUsingEigen<T>(GetInput(0), min_range, max_range, @@ -107,9 +107,8 @@ static void BM_DequantizeMinCombinedCpu(int iters) { std::vector<T> inputs; inputs.reserve(num_values); for (int i = 0; i < num_values; ++i) inputs.push_back(i); - ops::Dequantize(root, test::AsTensor<T>(inputs), - test::AsTensor<float>({-1.5f}), - test::AsTensor<float>({20.5f}), + ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f), + test::AsScalar<float>(20.5f), ops::Dequantize::Attrs().Mode("MIN_COMBINED")); TF_CHECK_OK(root.status()); Graph* g = new Graph(OpRegistry::Global()); |