aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dequantize_op_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-08-17 11:44:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-17 11:49:01 -0700
commit8a898a481828885b68bc41f481f40fb2b83a8bd2 (patch)
tree8de26eef5adf5452f5c5f8e50b5d0738fb871bd4 /tensorflow/core/kernels/dequantize_op_test.cc
parent09c93e391caa33b75827bdc7d41ce140c937ec81 (diff)
Fix dequantize_op benchmark. The requirement is that the ranges come in as
scalars, not single-element vectors, but this was not enforced until shape inference was enabled recently for this code path. PiperOrigin-RevId: 165606693
Diffstat (limited to 'tensorflow/core/kernels/dequantize_op_test.cc')
-rw-r--r--tensorflow/core/kernels/dequantize_op_test.cc9
1 files changed, 4 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/dequantize_op_test.cc b/tensorflow/core/kernels/dequantize_op_test.cc
index 8992629d42..b6e9f64aaa 100644
--- a/tensorflow/core/kernels/dequantize_op_test.cc
+++ b/tensorflow/core/kernels/dequantize_op_test.cc
@@ -77,8 +77,8 @@ class DequantizeOpTest : public OpsTestBase {
}
TensorShape shape({static_cast<int64>(input.size())});
AddInputFromArray<T>(shape, input);
- AddInputFromArray<float>(TensorShape({1}), {min_range});
- AddInputFromArray<float>(TensorShape({1}), {max_range});
+ AddInputFromArray<float>(TensorShape({}), {min_range});
+ AddInputFromArray<float>(TensorShape({}), {max_range});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_FLOAT, shape);
ComputeDequantizeMinCombinedUsingEigen<T>(GetInput(0), min_range, max_range,
@@ -107,9 +107,8 @@ static void BM_DequantizeMinCombinedCpu(int iters) {
std::vector<T> inputs;
inputs.reserve(num_values);
for (int i = 0; i < num_values; ++i) inputs.push_back(i);
- ops::Dequantize(root, test::AsTensor<T>(inputs),
- test::AsTensor<float>({-1.5f}),
- test::AsTensor<float>({20.5f}),
+ ops::Dequantize(root, test::AsTensor<T>(inputs), test::AsScalar<float>(-1.5f),
+ test::AsScalar<float>(20.5f),
ops::Dequantize::Attrs().Mode("MIN_COMBINED"));
TF_CHECK_OK(root.status());
Graph* g = new Graph(OpRegistry::Global());