author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-19 08:57:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-06-19 09:00:21 -0700
commit     316fee40d4978db2f6abbb5ff35cf8d979bee93e (patch)
tree       7677176700a716c710a00cc807c9f7d951818dac /tensorflow/contrib
parent     2f7c783d9ff5bc059fb58b875c9b9dae2fc96392 (diff)
Update TFLite "minimal" example
PiperOrigin-RevId: 201183828
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/lite/examples/minimal/BUILD        27
-rw-r--r--  tensorflow/contrib/lite/examples/minimal/minimal.cc   24
-rw-r--r--  tensorflow/contrib/lite/optional_debug_tools.cc       13
-rw-r--r--  tensorflow/contrib/lite/optional_debug_tools.h         3
4 files changed, 46 insertions, 21 deletions
diff --git a/tensorflow/contrib/lite/examples/minimal/BUILD b/tensorflow/contrib/lite/examples/minimal/BUILD
new file mode 100644
index 0000000000..b403628d6c
--- /dev/null
+++ b/tensorflow/contrib/lite/examples/minimal/BUILD
@@ -0,0 +1,27 @@
+# Description:
+# TensorFlow Lite minimal example.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
+load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts")
+
+tf_cc_binary(
+ name = "minimal",
+ srcs = [
+ "minimal.cc",
+ ],
+ linkopts = tflite_linkopts() + select({
+ "//tensorflow:android": [
+ "-pie", # Android 5.0 and later supports only PIE
+ "-lm", # some builtin ops, e.g., tanh, need -lm
+ ],
+ "//conditions:default": [],
+ }),
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:builtin_ops",
+ ],
+)
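Usage note (an assumption, not part of the commit): with a standard TensorFlow source checkout, this new target would typically be built with bazel build //tensorflow/contrib/lite/examples/minimal:minimal and then run as minimal <model.tflite>. The Android-only -pie and -lm linkopts exist because, per the comments above, Android 5.0+ only runs position-independent executables and some builtin ops such as tanh need libm.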
diff --git a/tensorflow/contrib/lite/examples/minimal/minimal.cc b/tensorflow/contrib/lite/examples/minimal/minimal.cc
index 8b0ace96cc..8b65cde7b7 100644
--- a/tensorflow/contrib/lite/examples/minimal/minimal.cc
+++ b/tensorflow/contrib/lite/examples/minimal/minimal.cc
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/lite/model.h"
+#include <cstdio>
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
-#include <cstdio>
+#include "tensorflow/contrib/lite/model.h"
+#include "tensorflow/contrib/lite/optional_debug_tools.h"
// This is an example that is minimal to read a model
// from disk and perform inference. There is no data being loaded
@@ -29,14 +30,13 @@ limitations under the License.
using namespace tflite;
-#define TFLITE_MINIMAL_CHECK(x) \
- if(!(x)) { \
- fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
- exit(1); \
+#define TFLITE_MINIMAL_CHECK(x) \
+ if (!(x)) { \
+ fprintf(stderr, "Error at %s:%d\n", __FILE__, __LINE__); \
+ exit(1); \
}
-
-int main(int argc, char *argv[]) {
+int main(int argc, char* argv[]) {
if(argc != 2) {
fprintf(stderr, "minimal <tflite model>\n");
return 1;
@@ -44,8 +44,8 @@ int main(int argc, char *argv[]) {
const char* filename = argv[1];
// Load model
- std::unique_ptr<tflite::FlatBufferModel> model
- = tflite::FlatBufferModel::BuildFromFile(filename);
+ std::unique_ptr<tflite::FlatBufferModel> model =
+ tflite::FlatBufferModel::BuildFromFile(filename);
TFLITE_MINIMAL_CHECK(model != nullptr);
// Build the interpreter
@@ -57,12 +57,16 @@ int main(int argc, char *argv[]) {
// Allocate tensor buffers.
TFLITE_MINIMAL_CHECK(interpreter->AllocateTensors() == kTfLiteOk);
+ printf("=== Pre-invoke Interpreter State ===\n");
+ tflite::PrintInterpreterState(interpreter.get());
// Fill input buffers
// TODO(user): Insert code to fill input tensors
// Run inference
TFLITE_MINIMAL_CHECK(interpreter->Invoke() == kTfLiteOk);
+ printf("\n\n=== Post-invoke Interpreter State ===\n");
+ tflite::PrintInterpreterState(interpreter.get());
// Read output buffers
// TODO(user): Insert getting data out code.
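The two TODO comments above are intentionally left for the user. As a minimal sketch of what they might contain, assuming a model whose first input and first output are float tensors (index 0, the zero-fill, and the ten-value printout are illustrative assumptions, not part of this commit), the bodies could be dropped into main() roughly as follows:

  // Fill input buffers (sketch): zero-fill the first input tensor, assuming it
  // holds floats; the element count is derived from the tensor's byte size.
  TfLiteTensor* input_tensor = interpreter->tensor(interpreter->inputs()[0]);
  float* input = interpreter->typed_input_tensor<float>(0);
  size_t input_count = input_tensor->bytes / sizeof(float);
  for (size_t i = 0; i < input_count; ++i) input[i] = 0.0f;

  // Read output buffers (sketch): print up to the first ten values of the
  // first output tensor, again assuming a float tensor.
  TfLiteTensor* output_tensor = interpreter->tensor(interpreter->outputs()[0]);
  float* output = interpreter->typed_output_tensor<float>(0);
  size_t output_count = output_tensor->bytes / sizeof(float);
  for (size_t i = 0; i < output_count && i < 10; ++i) {
    printf("output[%zu] = %f\n", i, output[i]);
  }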
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index 3af809a2a1..99c35b9caf 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -84,13 +84,13 @@ void PrintInterpreterState(Interpreter* interpreter) {
for (int tensor_index = 0; tensor_index < interpreter->tensors_size();
tensor_index++) {
TfLiteTensor* tensor = interpreter->tensor(tensor_index);
- printf("Tensor %3d %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
- TensorTypeName(tensor->type), AllocTypeName(tensor->allocation_type),
- tensor->bytes, float(tensor->bytes) / float(1 << 20));
+ printf("Tensor %3d %-20s %10s %15s %10zu bytes (%4.1f MB) ", tensor_index,
+ tensor->name, TensorTypeName(tensor->type),
+ AllocTypeName(tensor->allocation_type), tensor->bytes,
+ (static_cast<float>(tensor->bytes) / (1 << 20)));
PrintTfLiteIntVector(tensor->dims);
- printf("\n");
}
-
+ printf("\n");
for (int node_index = 0; node_index < interpreter->nodes_size();
node_index++) {
const std::pair<TfLiteNode, TfLiteRegistration>* node_and_reg =
@@ -106,7 +106,4 @@ void PrintInterpreterState(Interpreter* interpreter) {
}
}
-// Prints a dump of what tensors and what nodes are in the interpreter.
-TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
-
} // namespace tflite
diff --git a/tensorflow/contrib/lite/optional_debug_tools.h b/tensorflow/contrib/lite/optional_debug_tools.h
index 1b6998cda3..7fb4b8d8b7 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.h
+++ b/tensorflow/contrib/lite/optional_debug_tools.h
@@ -24,9 +24,6 @@ namespace tflite {
// Prints a dump of what tensors and what nodes are in the interpreter.
void PrintInterpreterState(Interpreter* interpreter);
-// Prints a dump of what tensors and what nodes are in the interpreter.
-TfLiteStatus ValidateInterpreterState(const Interpreter* interpreter);
-
} // namespace tflite
#endif // TENSORFLOW_CONTRIB_LITE_DEBUG_TOOLS_H_
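Usage note: as the updated minimal.cc above shows, callers typically include optional_debug_tools.h and invoke tflite::PrintInterpreterState(interpreter.get()) after AllocateTensors() (and optionally again after Invoke()) to dump the interpreter's tensors and nodes.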