diff options
author | Anna R <annarev@google.com> | 2017-12-04 12:31:03 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-04 12:37:12 -0800 |
commit | 8f1e63d5629bda4f6c91fdec7a3b8418ed96786e (patch) | |
tree | 0f56fddabcd4e5cf91090acf5152cc54b3651e6a /tensorflow/python/framework/python_op_gen_main.cc | |
parent | a1c29139ccf441ad4de97c4e7fe2729e6130fcb8 (diff) |
Actually use ApiDef when generating Python API.
PiperOrigin-RevId: 177851421
Diffstat (limited to 'tensorflow/python/framework/python_op_gen_main.cc')
-rw-r--r-- | tensorflow/python/framework/python_op_gen_main.cc | 56 |
1 files changed, 29 insertions, 27 deletions
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index 61b1d02a5e..bc5ca195da 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -34,12 +34,6 @@ limitations under the License. namespace tensorflow { namespace { -constexpr char kBaseApiDef[] = - "tensorflow/core/api_def/base_api/*.pbtxt"; -constexpr char kPythonApiDef[] = - "tensorflow/core/api_def/python_api/*.pbtxt"; -constexpr bool kUseApiDef = false; - Status ReadOpListFromFile(const string& filename, std::vector<string>* op_list) { std::unique_ptr<RandomAccessFile> file; @@ -110,22 +104,23 @@ string InferSourceFileName(const char* argv_zero) { } void PrintAllPythonOps(const std::vector<string>& op_list, + const std::vector<string>& api_def_dirs, const string& source_file_name, bool require_shapes, bool op_list_is_whitelist) { OpList ops; OpRegistry::Global()->Export(false, &ops); ApiDefMap api_def_map(ops); - if (kUseApiDef) { + if (!api_def_dirs.empty()) { Env* env = Env::Default(); - std::vector<string> base_api_files; - std::vector<string> python_api_files; - TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files)); - TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files)); - - TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files)); - TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files)); + for (const auto& api_def_dir : api_def_dirs) { + std::vector<string> api_files; + TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"), + &api_files)); + TF_CHECK_OK(api_def_map.LoadFileList(env, api_files)); + } + api_def_map.UpdateDocs(); } if (op_list_is_whitelist) { @@ -154,23 +149,30 @@ int main(int argc, char* argv[]) { tensorflow::InferSourceFileName(argv[0]); // Usage: - // gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1] - if (argc == 2) { - tensorflow::PrintAllPythonOps({}, source_file_name, - tensorflow::string(argv[1]) == "1", - false /* op_list_is_whitelist */); - } else if (argc == 3) { - std::vector<tensorflow::string> hidden_ops; - TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops)); - tensorflow::PrintAllPythonOps(hidden_ops, source_file_name, + // gen_main api_def_dir1,api_def_dir2,... + // [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1] + if (argc < 3) { + return -1; + } + std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split( + argv[1], ",", tensorflow::str_util::SkipEmpty()); + + if (argc == 3) { + tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name, tensorflow::string(argv[2]) == "1", false /* op_list_is_whitelist */); } else if (argc == 4) { + std::vector<tensorflow::string> hidden_ops; + TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops)); + tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name, + tensorflow::string(argv[3]) == "1", + false /* op_list_is_whitelist */); + } else if (argc == 5) { std::vector<tensorflow::string> op_list; - TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list)); - tensorflow::PrintAllPythonOps(op_list, source_file_name, - tensorflow::string(argv[2]) == "1", - tensorflow::string(argv[3]) == "1"); + TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list)); + tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name, + tensorflow::string(argv[3]) == "1", + tensorflow::string(argv[4]) == "1"); } else { return -1; } |