aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/parse_testdata.cc
blob: 389688d552051ea735ce71533943af33df5059ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Parses tflite example input data.
// Format is ASCII
// TODO(aselle): Switch to protobuf, but the android team requested a simple
// ASCII file.
#include "tensorflow/contrib/lite/testing/parse_testdata.h"

#include <cinttypes>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <streambuf>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/testing/message.h"
#include "tensorflow/contrib/lite/testing/split.h"

namespace tflite {
namespace testing {
namespace {

// Parse-time assertion used while reading test data. If `x != y`, prints the
// failing check's location in this source file, the input `filename`, the
// 1-based input line number (`current_line + 1`) and the stringized
// operands, then returns kTfLiteError from the enclosing function or lambda.
// Wrapped in do/while(0) so it expands to exactly one statement and stays
// safe inside unbraced if/else.
#define PARSE_CHECK_EQ(filename, current_line, x, y)                           \
  do {                                                                         \
    if ((x) != (y)) {                                                          \
      fprintf(stderr, "Parse Error @ %s:%d\n  File %s\n  Line %d, %s != %s\n", \
              __FILE__, __LINE__, (filename), (current_line) + 1, #x, #y);     \
      return kTfLiteError;                                                     \
    }                                                                          \
  } while (0)

// Break up a ","-delimited line into its fields. Empty fields are kept, so
// "a,,b" yields {"a", "", "b"}, and the result always has at least one
// element ("" yields {""}).
// This is extremely inefficient, and just used for testing code.
// TODO(aselle): replace with absl when we use it.
std::vector<std::string> ParseLine(const std::string& line) {
  std::vector<std::string> fields;
  std::string::size_type start = 0;
  for (std::string::size_type comma = line.find(',');
       comma != std::string::npos; comma = line.find(',', start)) {
    fields.emplace_back(line, start, comma - start);
    start = comma + 1;
  }
  // Whatever follows the last comma (possibly nothing) is the final field.
  fields.emplace_back(line, start, std::string::npos);
  return fields;
}

}  // namespace

// Given a `filename`, produce a vector of Examples corresponding
// to test cases that can be applied to a tflite model.
//
// The file is comma-separated text, one record per line:
//   test_cases,<N>
//   then, N times:
//     inputs,<I>     followed by I tensor records
//     outputs,<O>    followed by O tensor records
// where each tensor record spans three lines:
//   dtype,...               (only the "dtype" label is checked)
//   shape,<d0>,<d1>,...     ("shape," with an empty field means 0-d)
//   <label>,<v0>,<v1>,...   (label unchecked; exactly prod(shape) values)
//
// Parsed examples are appended to `*examples`. On a parse error an
// explanation is printed to stderr and kTfLiteError is returned; examples
// completed before the error remain appended.
TfLiteStatus ParseExamples(const char* filename,
                           std::vector<Example>* examples) {
  std::ifstream fp(filename);
  if (!fp.good()) {
    fprintf(stderr, "Could not read '%s'\n", filename);
    return kTfLiteError;
  }

  // Split the file into lines, and each line into ","-delimited fields.
  // ParseLine() always yields at least one (possibly empty) field per line.
  std::vector<std::vector<std::string>> csv;
  std::string line;
  while (std::getline(fp, line)) {
    csv.push_back(ParseLine(line));
  }

  int current_line = 0;

  // Reports a parse error and returns false unless the current line exists
  // and has at least `min_fields` fields. This keeps every csv[...] access
  // below in bounds on truncated or malformed input.
  auto has_fields = [&filename, &current_line,
                     &csv](size_t min_fields) -> bool {
    if (static_cast<size_t>(current_line) < csv.size() &&
        csv[current_line].size() >= min_fields) {
      return true;
    }
    fprintf(stderr,
            "Parse Error\n  File %s\n  Line %d, expected at least %zu "
            "field(s)\n",
            filename, current_line + 1, min_fields);
    return false;
  };

  if (!has_fields(2)) return kTfLiteError;
  PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "test_cases");
  const int example_count = std::stoi(csv[current_line][1]);
  current_line++;

  // Parses one dtype/shape/data tensor record into *tensor, advancing
  // current_line past it.
  auto parse_tensor = [&filename, &current_line, &csv,
                       &has_fields](FloatTensor* tensor) -> TfLiteStatus {
    if (!has_fields(1)) return kTfLiteError;
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype");
    current_line++;

    // Parse shape.
    if (!has_fields(1)) return kTfLiteError;
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape");
    const std::vector<std::string>& shape_fields = csv[current_line];
    size_t elements = 1;
    for (size_t i = 1; i < shape_fields.size(); i++) {
      if (shape_fields[i].empty()) {
        // Case of a 0-dimensional shape.
        break;
      }
      const int dim = std::stoi(shape_fields[i]);
      elements *= dim;
      tensor->shape.push_back(dim);
    }
    current_line++;

    // Parse data: everything after the leading label field.
    if (!has_fields(1)) return kTfLiteError;
    const std::vector<std::string>& data_fields = csv[current_line];
    PARSE_CHECK_EQ(filename, current_line, data_fields.size() - 1, elements);
    for (size_t i = 1; i < data_fields.size(); i++) {
      tensor->flat_data.push_back(std::stof(data_fields[i]));
    }
    current_line++;

    return kTfLiteOk;
  };

  for (int example_idx = 0; example_idx < example_count; example_idx++) {
    Example example;

    if (!has_fields(2)) return kTfLiteError;
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs");
    const int inputs = std::stoi(csv[current_line][1]);
    current_line++;
    for (int input_index = 0; input_index < inputs; input_index++) {
      example.inputs.push_back(FloatTensor());
      TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back()));
    }

    if (!has_fields(2)) return kTfLiteError;
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs");
    const int outputs = std::stoi(csv[current_line][1]);
    current_line++;
    for (int output_index = 0; output_index < outputs; output_index++) {
      example.outputs.push_back(FloatTensor());
      TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back()));
    }

    examples->push_back(std::move(example));
  }
  return kTfLiteOk;
}

namespace {

// Converts each element of `values` to T and writes it to `dst`, a buffer
// with room for `capacity` elements. Returns false, writing nothing, when
// `values` does not fit.
template <typename T, typename Container>
bool CopyToTensorBuffer(const Container& values, T* dst, size_t capacity) {
  if (values.size() > capacity) return false;
  for (size_t idx = 0; idx < values.size(); idx++) {
    dst[idx] = static_cast<T>(values[idx]);
  }
  return true;
}

}  // namespace

// Resizes each interpreter input to the matching shape in `example`,
// allocates tensors, then copies the example's values into the input
// buffers. Values are converted to int32/int64 for integer inputs.
// Returns kTfLiteError if the example's input count differs from the
// model's, if an input is not float/int32/int64, or if an example holds
// more values than its tensor buffer can take.
TfLiteStatus FeedExample(tflite::Interpreter* interpreter,
                         const Example& example) {
  // Guard the example.inputs[i] indexing below.
  if (example.inputs.size() != interpreter->inputs().size()) {
    fprintf(stderr, "example has %zu inputs but the model has %zu\n",
            example.inputs.size(), interpreter->inputs().size());
    return kTfLiteError;
  }

  // Resize inputs to match example & allocate.
  for (size_t i = 0; i < interpreter->inputs().size(); i++) {
    const int input_index = interpreter->inputs()[i];
    TF_LITE_ENSURE_STATUS(
        interpreter->ResizeInputTensor(input_index, example.inputs[i].shape));
  }
  TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());

  // Copy data into tensors, never writing past a tensor's allocation.
  for (size_t i = 0; i < interpreter->inputs().size(); i++) {
    const int input_index = interpreter->inputs()[i];
    const auto& values = example.inputs[i].flat_data;
    const size_t tensor_bytes = interpreter->tensor(input_index)->bytes;
    bool copied = false;
    if (float* data = interpreter->typed_tensor<float>(input_index)) {
      copied = CopyToTensorBuffer(values, data, tensor_bytes / sizeof(float));
    } else if (int32_t* data =
                   interpreter->typed_tensor<int32_t>(input_index)) {
      copied =
          CopyToTensorBuffer(values, data, tensor_bytes / sizeof(int32_t));
    } else if (int64_t* data =
                   interpreter->typed_tensor<int64_t>(input_index)) {
      copied =
          CopyToTensorBuffer(values, data, tensor_bytes / sizeof(int64_t));
    } else {
      fprintf(stderr, "input[%zu] was not float or int data\n", i);
      return kTfLiteError;
    }
    if (!copied) {
      fprintf(stderr, "input[%zu] has %zu values, more than its tensor holds\n",
              i, values.size());
      return kTfLiteError;
    }
  }
  return kTfLiteOk;
}

// Compares every interpreter output against the reference values in
// `example` (output i against example.outputs[i]).
// - Float outputs pass when the absolute difference is at most
//   kAbsoluteThreshold (for |reference| < kRelativeThreshold). Otherwise the
//   difference may be at most kRelativeThreshold * |reference|.
// - Integer outputs must match exactly.
// Every mismatch is reported on stderr, and all outputs are checked before
// returning. Returns kTfLiteError if the output counts differ, an output is
// not float/int32/int64, or any value mismatches (including a reference
// longer than the computed tensor).
TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
                          const Example& example) {
  constexpr double kRelativeThreshold = 1e-2;
  constexpr double kAbsoluteThreshold = 1e-4;

  ErrorReporter* context = DefaultErrorReporter();
  // Both counts are ints so the comparison and the error report agree in type.
  const int model_outputs = static_cast<int>(interpreter->outputs().size());
  const int example_outputs = static_cast<int>(example.outputs.size());
  TF_LITE_ENSURE_EQ(context, model_outputs, example_outputs);

  bool any_output_differs = false;
  for (size_t i = 0; i < interpreter->outputs().size(); i++) {
    bool tensors_differ = false;
    const int output_index = interpreter->outputs()[i];
    const auto& reference_values = example.outputs[i].flat_data;
    const size_t tensor_bytes = interpreter->tensor(output_index)->bytes;

    // Number of values that can be compared without reading past the output
    // buffer. A reference longer than the buffer is itself a mismatch.
    auto comparable_count = [&](size_t element_size) -> size_t {
      const size_t capacity = tensor_bytes / element_size;
      if (reference_values.size() <= capacity) return reference_values.size();
      fprintf(stderr,
              "output[%zu] has %zu reference values but only %zu computed\n",
              i, reference_values.size(), capacity);
      tensors_differ = true;
      return capacity;
    };

    if (const float* data = interpreter->typed_tensor<float>(output_index)) {
      const size_t count = comparable_count(sizeof(float));
      for (size_t idx = 0; idx < count; idx++) {
        const float computed = data[idx];
        const float reference = reference_values[idx];
        const float diff = std::abs(computed - reference);
        // For very small numbers, try absolute error, otherwise go with
        // relative.
        const bool local_tensors_differ =
            std::abs(reference) < kRelativeThreshold
                ? diff > kAbsoluteThreshold
                : diff > kRelativeThreshold * std::abs(reference);
        if (local_tensors_differ) {
          fprintf(stderr, "output[%zu][%zu] did not match %f vs reference %f\n",
                  i, idx, computed, reference);
          tensors_differ = true;
        }
      }
    } else if (const int32_t* data =
                   interpreter->typed_tensor<int32_t>(output_index)) {
      const size_t count = comparable_count(sizeof(int32_t));
      for (size_t idx = 0; idx < count; idx++) {
        const int32_t computed = data[idx];
        const int32_t reference = static_cast<int32_t>(reference_values[idx]);
        // Compare directly: computed - reference could overflow.
        if (computed != reference) {
          fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n",
                  i, idx, computed, reference);
          tensors_differ = true;
        }
      }
    } else if (const int64_t* data =
                   interpreter->typed_tensor<int64_t>(output_index)) {
      const size_t count = comparable_count(sizeof(int64_t));
      for (size_t idx = 0; idx < count; idx++) {
        const int64_t computed = data[idx];
        const int64_t reference = static_cast<int64_t>(reference_values[idx]);
        if (computed != reference) {
          fprintf(stderr,
                  "output[%zu][%zu] did not match %" PRId64
                  " vs reference %" PRId64 "\n",
                  i, idx, computed, reference);
          tensors_differ = true;
        }
      }
    } else {
      fprintf(stderr, "output[%zu] was not float or int data\n", i);
      return kTfLiteError;
    }
    fprintf(stderr, "\n");
    if (tensors_differ) any_output_differs = true;
  }
  return any_output_differs ? kTfLiteError : kTfLiteOk;
}

// Handles an 'invoke' message. An 'id' field sets the invocation id. Each
// 'input' field is handed to the test runner for the next of its inputs, and
// each 'output' field sets the expectation for the next of its outputs.
// Supplying more inputs/outputs than the runner reports invalidates the
// runner. Finishing the message calls the runner's Invoke() and then
// CheckResults(). An 'invoke' message looks like:
//   invoke {
//     id: xyz
//     input: 1,2,1,1,1,2,3,4
//     output: 4,5,6
//   }
class Invoke : public Message {
 public:
  explicit Invoke(TestRunner* test_runner)
      : expected_inputs_(test_runner->GetInputs()),
        expected_outputs_(test_runner->GetOutputs()),
        test_runner_(test_runner) {}

  void SetField(const std::string& name, const std::string& value) override {
    if (name == "id") {
      test_runner_->SetInvocationId(value);
      return;
    }
    if (name == "input") {
      if (next_input_ == expected_inputs_.size()) {
        test_runner_->Invalidate("Too many inputs");
        return;
      }
      test_runner_->SetInput(expected_inputs_[next_input_], value);
      ++next_input_;
      return;
    }
    if (name == "output") {
      if (next_output_ == expected_outputs_.size()) {
        test_runner_->Invalidate("Too many outputs");
        return;
      }
      test_runner_->SetExpectation(expected_outputs_[next_output_], value);
      ++next_output_;
    }
  }

  void Finish() override {
    test_runner_->Invoke();
    test_runner_->CheckResults();
  }

 private:
  std::vector<int> expected_inputs_;
  std::vector<int> expected_outputs_;
  // Positions of the next unassigned input/output in the vectors above.
  size_t next_input_ = 0;
  size_t next_output_ = 0;

  TestRunner* test_runner_;
};

// Handles a 'reshape' message. Each 'input' field asks the test runner to
// reshape the next of its inputs. Supplying more 'input' fields than the
// runner has inputs invalidates the runner. A 'reshape' message looks like:
//   reshape {
//     input: 1,2,1,1,1,2,3,4
//   }
class Reshape : public Message {
 public:
  explicit Reshape(TestRunner* test_runner)
      : expected_inputs_(test_runner->GetInputs()), test_runner_(test_runner) {}

  void SetField(const std::string& name, const std::string& value) override {
    if (name != "input") return;
    if (next_input_ == expected_inputs_.size()) {
      test_runner_->Invalidate("Too many inputs to reshape");
      return;
    }
    test_runner_->ReshapeTensor(expected_inputs_[next_input_], value);
    ++next_input_;
  }

 private:
  std::vector<int> expected_inputs_;
  // Position of the next input to be reshaped in expected_inputs_.
  size_t next_input_ = 0;
  TestRunner* test_runner_;
};

// The top-level message in a test file.
// - 'load_model' loads the named model.
// - 'init_state' allocates tensors, then resets each comma-separated id.
// - 'invoke' children are accepted until `max_invocations_` is reached
//   (-1 means unlimited). 'reshape' children are always accepted.
class TestData : public Message {
 public:
  explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {}

  void SetMaxInvocations(int max) { max_invocations_ = max; }

  void SetField(const std::string& name, const std::string& value) override {
    if (name == "load_model") {
      test_runner_->LoadModel(value);
      return;
    }
    if (name != "init_state") return;
    test_runner_->AllocateTensors();
    for (int id : Split<int>(value, ",")) {
      test_runner_->ResetTensor(id);
    }
  }

  Message* AddChild(const std::string& s) override {
    if (s == "reshape") return Store(new Reshape(test_runner_));
    if (s != "invoke") return nullptr;

    test_runner_->AllocateTensors();
    const bool limited = max_invocations_ != -1;
    if (limited && num_invocations_ >= max_invocations_) return nullptr;
    ++num_invocations_;
    return Store(new Invoke(test_runner_));
  }

 private:
  TestRunner* test_runner_;
  int num_invocations_ = 0;
  int max_invocations_ = -1;
};

bool ParseAndRunTests(std::istream* input, TestRunner* test_runner,
                      int max_invocations) {
  TestData test_data(test_runner);
  test_data.SetMaxInvocations(max_invocations);
  Message::Read(input, &test_data);
  return test_runner->IsValid() && test_runner->GetOverallSuccess();
}

}  // namespace testing
}  // namespace tflite