aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/test_util.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-04 18:49:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-05 08:45:32 -0700
commit5fb53fe69afe7f9106a8bcb5632cea23cf227d78 (patch)
tree2658e05fa2481666efbea50c56909abfec3f938f /tensorflow/contrib/lite/kernels/test_util.h
parentdd5ef1b9fc22b37e5eec87d659a3af064ca54b8b (diff)
add support for PadV2
PiperOrigin-RevId: 195503894
Diffstat (limited to 'tensorflow/contrib/lite/kernels/test_util.h')
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h85
1 files changed, 81 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index 6fb6fe27eb..6a9fdf1112 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -116,9 +116,14 @@ class SingleOpModel {
int AddInput(TensorType type) { return AddInput(TensorData{type}); }
int AddInput(const TensorData& t);
- // Add a Tensor containing const data and return the tensor id.
- int AddConstInput(TensorType type, std::initializer_list<int> data,
- std::initializer_list<int> shape);
+ // Templated version of AddConstInput().
+ template <typename T>
+ int AddConstInput(TensorType type, std::initializer_list<T> data,
+ std::initializer_list<int> shape) {
+ int id = AddTensor(TensorData{type, shape}, data);
+ inputs_.push_back(id);
+ return id;
+ }
// Add a null input tensor (optional input) and return kOptionalTensor.
int AddNullInput();
@@ -224,7 +229,79 @@ class SingleOpModel {
std::unique_ptr<OpResolver> resolver_;
private:
- int AddTensor(TensorData t, std::initializer_list<int> data);
+ // TODO(gavinbelson): sync this method with
+ // //tensorflow/contrib/lite/kernels/internal/quantization_util.h?l=31
+ template <typename T>
+ std::pair<float, int32_t> QuantizationParams(float f_min, float f_max) {
+ // These are required by many quantized operations.
+ CHECK_LE(f_min, 0);
+ CHECK_GE(f_max, 0);
+ T q_min = std::numeric_limits<T>::min();
+ T q_max = std::numeric_limits<T>::max();
+ float range = q_max - q_min;
+ float scale = (f_max - f_min) / range;
+ int32_t zero_point = std::min(
+ q_max,
+ std::max(q_min, static_cast<T>(std::round(q_min - f_min / scale))));
+ return {scale, zero_point};
+ }
+
+ template <typename T>
+ int AddTensor(TensorData t, std::initializer_list<T> data) {
+ int id = tensors_.size();
+
+ // This is slightly different depending on whether we are adding a
+ // quantized or a regular tensor.
+ bool is_quantized = (t.min != 0 || t.max != 0 || t.scale != 0);
+
+ flatbuffers::Offset<QuantizationParameters> q_params = 0;
+
+ if (is_quantized) {
+ if (t.min != 0 || t.max != 0) {
+ if (t.type == TensorType_UINT8) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<uint8_t>(t.min, t.max);
+ } else if (t.type == TensorType_INT32) {
+ std::tie(t.scale, t.zero_point) =
+ QuantizationParams<int32_t>(t.min, t.max);
+ } else {
+ LOG(FATAL) << "No support for the requested quantized type";
+ }
+ t.min = 0;
+ t.max = 0;
+ }
+
+ q_params = CreateQuantizationParameters(
+ builder_, /*min=*/0, /*max=*/0,
+ builder_.CreateVector<float>({t.scale}),
+ builder_.CreateVector<int64_t>({t.zero_point}));
+ }
+
+ int buffer_id = 0;
+ if (data.size()) {
+ // Initialize buffers list with empty buffer to allow for non-const
+ // tensors.
+ if (buffers_.empty()) {
+ buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({})));
+ }
+
+ // Add data as a Buffer to buffers list.
+ buffer_id = buffers_.size();
+ auto data_buffer =
+ builder_.CreateVector(reinterpret_cast<const uint8_t*>(data.begin()),
+ sizeof(T) * data.size());
+ buffers_.push_back(CreateBuffer(builder_, data_buffer));
+ }
+
+ tensors_.push_back(CreateTensor(builder_,
+ builder_.CreateVector<int>(t.shape), t.type,
+ /*buffer=*/buffer_id,
+ /*name=*/0, q_params));
+
+ tensor_data_[id] = t;
+
+ return id;
+ }
std::map<int, TensorData> tensor_data_;
std::vector<int32_t> inputs_;