aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/label_image
diff options
context:
space:
mode:
authorGravatar Pete Warden <pete@petewarden.com>2016-01-16 14:59:03 -0800
committerGravatar Manjunath Kudlur <keveman@gmail.com>2016-01-18 12:21:50 -0800
commit8ffdf4a33941e8b1f259c3834d4376cd1cdf3790 (patch)
treef723f7af3e6ea63b1a363d40eb340b47f08897b0 /tensorflow/examples/label_image
parent48a95edc43fa420342c8002b3643afbbbffb0065 (diff)
Moved the flag parsing into a separate module.
Tidying up the label_image example so that the argument handling is taken care of outside of the main flow of the code, to make understanding it easier. Change: 112333770
Diffstat (limited to 'tensorflow/examples/label_image')
-rw-r--r--tensorflow/examples/label_image/BUILD17
-rw-r--r--tensorflow/examples/label_image/command_line_flags.cc135
-rw-r--r--tensorflow/examples/label_image/command_line_flags.h59
-rw-r--r--tensorflow/examples/label_image/command_line_flags_test.cc79
-rw-r--r--tensorflow/examples/label_image/main.cc99
5 files changed, 308 insertions, 81 deletions
diff --git a/tensorflow/examples/label_image/BUILD b/tensorflow/examples/label_image/BUILD
index c13b530e85..1498a21c53 100644
--- a/tensorflow/examples/label_image/BUILD
+++ b/tensorflow/examples/label_image/BUILD
@@ -12,7 +12,10 @@ exports_files(["LICENSE"])
cc_binary(
name = "label_image",
- srcs = ["main.cc"],
+ srcs = [
+ "command_line_flags.cc",
+ "main.cc",
+ ],
linkopts = ["-lm"],
deps = [
"//tensorflow/cc:cc_ops",
@@ -20,6 +23,18 @@ cc_binary(
],
)
+cc_test(
+ name = "command_line_flags_test",
+ srcs = [
+ "command_line_flags.cc",
+ "command_line_flags_test.cc",
+ ],
+ deps = [
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/platform/default/build_config:test_main",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/examples/label_image/command_line_flags.cc b/tensorflow/examples/label_image/command_line_flags.cc
new file mode 100644
index 0000000000..9806ccb624
--- /dev/null
+++ b/tensorflow/examples/label_image/command_line_flags.cc
@@ -0,0 +1,135 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/examples/label_image/command_line_flags.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/logging.h"
+
+using tensorflow::string;
+
+namespace {
+
+bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ string* dst, bool* value_parsing_ok) {
+ *value_parsing_ok = true;
+ if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ *dst = arg.ToString();
+ return true;
+ }
+
+ return false;
+}
+
+bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ tensorflow::int32* dst, bool* value_parsing_ok) {
+ *value_parsing_ok = true;
+ if (arg.Consume("--") && arg.Consume(flag) && arg.Consume("=")) {
+ char extra;
+ if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
+ LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
+ << ".";
+ *value_parsing_ok = false;
+ }
+ return true;
+ }
+
+ return false;
+}
+
+bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
+ bool* dst, bool* value_parsing_ok) {
+ *value_parsing_ok = true;
+ if (arg.Consume("--") && arg.Consume(flag)) {
+ if (arg.empty()) {
+ *dst = true;
+ return true;
+ }
+
+ if (arg == "=true") {
+ *dst = true;
+ return true;
+ } else if (arg == "=false") {
+ *dst = false;
+ return true;
+ } else {
+ LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
+ << ".";
+ *value_parsing_ok = false;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
+Flag::Flag(const char* name, tensorflow::int32* dst)
+ : name_(name), type_(TYPE_INT), int_value_(dst) {}
+
+Flag::Flag(const char* name, bool* dst)
+ : name_(name), type_(TYPE_BOOL), bool_value_(dst) {}
+
+Flag::Flag(const char* name, string* dst)
+ : name_(name), type_(TYPE_STRING), string_value_(dst) {}
+
+bool Flag::Parse(string arg, bool* value_parsing_ok) const {
+ bool result = false;
+ if (type_ == TYPE_INT) {
+ result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok);
+ } else if (type_ == TYPE_BOOL) {
+ result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
+ } else if (type_ == TYPE_STRING) {
+ result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
+ }
+ return result;
+}
+
+bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list) {
+ bool result = true;
+ std::vector<char*> unknown_flags;
+ for (int i = 1; i < *argc; ++i) {
+ if (string(argv[i]) == "--") {
+ while (i < *argc) {
+ unknown_flags.push_back(argv[i]);
+ ++i;
+ }
+ break;
+ }
+
+ bool was_found = false;
+ for (const Flag& flag : flag_list) {
+ bool value_parsing_ok;
+ was_found = flag.Parse(argv[i], &value_parsing_ok);
+ if (!value_parsing_ok) {
+ result = false;
+ }
+ if (was_found) {
+ break;
+ }
+ }
+ if (!was_found) {
+ unknown_flags.push_back(argv[i]);
+ }
+ }
+ // Passthrough any extra flags.
+ int dst = 1; // Skip argv[0]
+ for (char* f : unknown_flags) {
+ argv[dst++] = f;
+ }
+ argv[dst++] = nullptr;
+ *argc = unknown_flags.size() + 1;
+ return result;
+}
diff --git a/tensorflow/examples/label_image/command_line_flags.h b/tensorflow/examples/label_image/command_line_flags.h
new file mode 100644
index 0000000000..dcc8f3fdde
--- /dev/null
+++ b/tensorflow/examples/label_image/command_line_flags.h
@@ -0,0 +1,59 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_EXAMPLES_LABEL_IMAGE_COMMAND_LINE_FLAGS_H
+#define TENSORFLOW_EXAMPLES_LABEL_IMAGE_COMMAND_LINE_FLAGS_H
+
+#include <vector>
+#include "tensorflow/core/platform/types.h"
+
+// This is a simple command-line argument parsing module to help us handle
+// example parameters. The recommended way of using it is with local variables
+// and an initializer list of Flag objects, for example:
+//
+// int some_int = 10;
+// bool some_switch = false;
+// string some_name = "something";
+// bool parsed_values_ok = ParseFlags(&argc, argv, {
+// Flag("some_int", &some_int),
+// Flag("some_switch", &some_switch),
+// Flag("some_name", &some_name)});
+//
+// The argc and argv values are adjusted by the Parse function so all that
+// remains is the program name (at argv[0]) and any unknown arguments fill the
+// rest of the array. This means you can check for flags that weren't understood
+// by seeing if argv is greater than 1.
+// The result indicates if there were any errors parsing the values that were
+// passed to the command-line switches. For example, --some_int=foo would return
+// false because the argument is expected to be an integer.
+class Flag {
+ public:
+ Flag(const char* name, tensorflow::int32* dst1);
+ Flag(const char* name, bool* dst);
+ Flag(const char* name, tensorflow::string* dst);
+
+ bool Parse(tensorflow::string arg, bool* value_parsing_ok) const;
+
+ private:
+ tensorflow::string name_;
+ enum { TYPE_INT, TYPE_BOOL, TYPE_STRING } type_;
+ int* int_value_;
+ bool* bool_value_;
+ tensorflow::string* string_value_;
+};
+
+bool ParseFlags(int* argc, char** argv, const std::vector<Flag>& flag_list);
+
+#endif // TENSORFLOW_EXAMPLES_LABEL_IMAGE_COMMAND_LINE_FLAGS_H
diff --git a/tensorflow/examples/label_image/command_line_flags_test.cc b/tensorflow/examples/label_image/command_line_flags_test.cc
new file mode 100644
index 0000000000..dd9cda5bf6
--- /dev/null
+++ b/tensorflow/examples/label_image/command_line_flags_test.cc
@@ -0,0 +1,79 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/examples/label_image/command_line_flags.h"
+#include <gtest/gtest.h>
+
+namespace {
+// The returned array is only valid for the lifetime of the input vector.
+// We're using const casting because we need to pass in an argv-style array of
+// char* pointers for the API, even though we know they won't be altered.
+std::vector<char*> CharPointerVectorFromStrings(
+ const std::vector<tensorflow::string>& strings) {
+ std::vector<char*> result;
+ for (const tensorflow::string& string : strings) {
+ result.push_back(const_cast<char*>(string.c_str()));
+ }
+ return result;
+}
+}
+
+TEST(CommandLineFlagsTest, BasicUsage) {
+ int some_int = 10;
+ bool some_switch = false;
+ string some_name = "something";
+ int argc = 4;
+ std::vector<tensorflow::string> argv_strings = {
+ "program_name", "--some_int=20", "--some_switch",
+ "--some_name=somethingelse"};
+ std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int),
+ Flag("some_switch", &some_switch),
+ Flag("some_name", &some_name)});
+ EXPECT_EQ(true, parsed_ok);
+ EXPECT_EQ(20, some_int);
+ EXPECT_EQ(true, some_switch);
+ EXPECT_EQ("somethingelse", some_name);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, BadIntValue) {
+ int some_int = 10;
+ int argc = 2;
+ std::vector<tensorflow::string> argv_strings = {"program_name",
+ "--some_int=notanumber"};
+ std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ ParseFlags(&argc, argv_array.data(), {Flag("some_int", &some_int)});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(10, some_int);
+ EXPECT_EQ(argc, 1);
+}
+
+TEST(CommandLineFlagsTest, BadBoolValue) {
+ bool some_switch = false;
+ int argc = 2;
+ std::vector<tensorflow::string> argv_strings = {"program_name",
+ "--some_switch=notabool"};
+ std::vector<char*> argv_array = CharPointerVectorFromStrings(argv_strings);
+ bool parsed_ok =
+ ParseFlags(&argc, argv_array.data(), {Flag("some_switch", &some_switch)});
+
+ EXPECT_EQ(false, parsed_ok);
+ EXPECT_EQ(false, some_switch);
+ EXPECT_EQ(argc, 1);
+}
diff --git a/tensorflow/examples/label_image/main.cc b/tensorflow/examples/label_image/main.cc
index 9bbc21a575..4460fefa3f 100644
--- a/tensorflow/examples/label_image/main.cc
+++ b/tensorflow/examples/label_image/main.cc
@@ -49,6 +49,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/examples/label_image/command_line_flags.h"
// These are all common classes it's handy to reference with no namespace.
using tensorflow::Tensor;
@@ -223,50 +224,6 @@ Status CheckTopLabel(const std::vector<Tensor>& outputs, int expected,
return Status::OK();
}
-namespace {
-
-bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- string* dst) {
- if (arg.Consume(flag) && arg.Consume("=")) {
- *dst = arg.ToString();
- return true;
- }
-
- return false;
-}
-
-bool ParseInt32Flag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- int32* dst) {
- if (arg.Consume(flag) && arg.Consume("=")) {
- char extra;
- return (sscanf(arg.data(), "%d%c", dst, &extra) == 1);
- }
-
- return false;
-}
-
-bool ParseBoolFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
- bool* dst) {
- if (arg.Consume(flag)) {
- if (arg.empty()) {
- *dst = true;
- return true;
- }
-
- if (arg == "=true") {
- *dst = true;
- return true;
- } else if (arg == "=false") {
- *dst = false;
- return true;
- }
- }
-
- return false;
-}
-
-} // namespace
-
int main(int argc, char* argv[]) {
// These are the command-line flags the program can understand.
// They define where the graph and input data is located, and what kind of
@@ -283,51 +240,33 @@ int main(int argc, char* argv[]) {
int32 input_height = 299;
int32 input_mean = 128;
int32 input_std = 128;
-
string input_layer = "Mul";
string output_layer = "softmax";
bool self_test = false;
string root_dir = "";
-
- std::vector<char*> unknown_flags;
- for (int i = 1; i < argc; ++i) {
- if (string(argv[i]) == "--") {
- while (i < argc) {
- unknown_flags.push_back(argv[i]);
- ++i;
- }
- break;
- }
-
- if (ParseStringFlag(argv[i], "--image", &image) ||
- ParseStringFlag(argv[i], "--graph", &graph) ||
- ParseStringFlag(argv[i], "--labels", &labels) ||
- ParseInt32Flag(argv[i], "--input_width", &input_width) ||
- ParseInt32Flag(argv[i], "--input_height", &input_height) ||
- ParseInt32Flag(argv[i], "--input_mean", &input_mean) ||
- ParseInt32Flag(argv[i], "--input_std", &input_std) ||
- ParseStringFlag(argv[i], "--input_layer", &input_layer) ||
- ParseStringFlag(argv[i], "--output_layer", &output_layer) ||
- ParseBoolFlag(argv[i], "--self_test", &self_test) ||
- ParseStringFlag(argv[i], "--root_dir", &root_dir)) {
- continue;
- }
-
- fprintf(stderr, "Unknown flag: %s\n", argv[i]);
+ const bool parse_result =
+ ParseFlags(&argc, argv, {Flag("image", &image), //
+ Flag("graph", &graph), //
+ Flag("labels", &labels), //
+ Flag("input_width", &input_width), //
+ Flag("input_height", &input_height), //
+ Flag("input_mean", &input_mean), //
+ Flag("input_std", &input_std), //
+ Flag("input_layer", &input_layer), //
+ Flag("output_layer", &output_layer), //
+ Flag("self_test", &self_test), //
+ Flag("root_dir", &root_dir)});
+ if (!parse_result) {
+ LOG(ERROR) << "Error parsing command-line flags.";
return -1;
}
- // Passthrough any extra flags.
- int dst = 1; // Skip argv[0]
-
- for (char* f : unknown_flags) {
- argv[dst++] = f;
- }
- argv[dst++] = nullptr;
- argc = unknown_flags.size() + 1;
-
// We need to call this to set up global state for TensorFlow.
tensorflow::port::InitMain(argv[0], &argc, &argv);
+ if (argc > 1) {
+ LOG(ERROR) << "Unknown argument " << argv[1];
+ return -1;
+ }
// First we load and initialize the model.
std::unique_ptr<tensorflow::Session> session;