// Source: tensorflow/core/lib/core/command_line_flags.cc
// (blob 0f1072ffaa1c97d47eba287533cd13156568ade3)
#include "tensorflow/core/lib/core/command_line_flags.h"

#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"

namespace tensorflow {
namespace {

// Templated function to convert a string to target values.
// Return true if the conversion is successful. Otherwise, return false.
template <typename T>
bool StringToValue(const string& content, T* value);

template <>
bool StringToValue<int32>(const string& content, int* value) {
  return str_util::NumericParse32(content, value);
}

// Parses a single argument of the form "--name=value" by linearly scanning the
// flags registered for type T, in registry order.
//
// Returns:
//   OK               - a flag registered for T matched the "--name=" prefix
//                      and the remaining text was converted and stored into
//                      that flag's value.
//   INVALID_ARGUMENT - the "--name=" prefix matched, but StringToValue<T>
//                      could not convert the remaining text.
//   NOT_FOUND        - no flag registered for T matched the argument.
//
// Requires a StringToValue<T> specialization (this file provides int32).
template <typename T>
Status ParseArgument(const string& argument) {
  // Search the registry for the *requested* type T. This was previously
  // hard-coded to CommandLineFlagRegistry<int>, which ignored T and only
  // worked because the sole instantiation was T == int32.
  for (auto& command :
       internal::CommandLineFlagRegistry<T>::Instance()->commands) {
    const string prefix = strings::StrCat("--", command.name, "=");
    if (tensorflow::StringPiece(argument).starts_with(prefix)) {
      const string content = argument.substr(prefix.length());
      if (StringToValue<T>(content, command.value)) {
        return Status::OK();
      }
      // NOTE(review): the message says "integer" because int32 is the only
      // value-carrying instantiation today; revisit if more types are added.
      return Status(error::INVALID_ARGUMENT,
                    strings::StrCat("Cannot parse integer in: ", argument));
    }
  }
  return Status(error::NOT_FOUND,
                strings::StrCat("Unknown command: ", argument));
}

// Boolean specialization. Boolean flags carry no "=value" part:
//   "--name"   sets the flag to true,
//   "--noname" sets the flag to false.
// Registered bool flags are checked in registry order, and the first flag
// whose name matches either spelling is updated.
//
// Returns OK when a flag was set, and NOT_FOUND when the argument matches no
// registered bool flag. This specialization never reports INVALID_ARGUMENT.
template <>
Status ParseArgument<bool>(const string& argument) {
  for (auto& flag :
       internal::CommandLineFlagRegistry<bool>::Instance()->commands) {
    const bool enable = argument == strings::StrCat("--", flag.name);
    const bool disable =
        !enable && argument == strings::StrCat("--no", flag.name);
    if (enable || disable) {
      // For a match, `enable` is exactly the value the flag should take.
      *flag.value = enable;
      return Status::OK();
    }
  }
  return Status(error::NOT_FOUND,
                strings::StrCat("Unknown command: ", argument));
}
}  // namespace

// Consumes every argv entry after argv[0] that matches a registered bool or
// int32 flag. Bool flags are tried first, then int32 flags. Entries that match
// neither are compacted to the front of argv in their original relative order.
// On success, *argc is set to the number of entries kept, including argv[0].
//
// If either parser reports an error other than NOT_FOUND (e.g.
// INVALID_ARGUMENT for an unparsable int32 value), that status is returned
// immediately. In that case *argc is left unchanged and argv may already be
// partially reordered.
Status ParseCommandLineFlags(int* argc, char* argv[]) {
  int next_kept = 1;  // argv[0] (the program name) always stays in place.
  for (int i = 1; i < *argc; ++i) {
    const string arg(argv[i]);

    Status status = ParseArgument<bool>(arg);
    if (status.code() == error::NOT_FOUND) {
      status = ParseArgument<int32>(arg);
    }

    if (status.code() == error::NOT_FOUND) {
      // Not a flag we know about: move the pointer down into the kept prefix.
      std::swap(argv[next_kept], argv[i]);
      ++next_kept;
    } else if (!status.ok()) {
      return status;
    }
  }
  *argc = next_kept;
  return Status::OK();
}

}  // namespace tensorflow