aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/python_op_gen_main.cc
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-12-04 12:31:03 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-04 12:37:12 -0800
commit8f1e63d5629bda4f6c91fdec7a3b8418ed96786e (patch)
tree0f56fddabcd4e5cf91090acf5152cc54b3651e6a /tensorflow/python/framework/python_op_gen_main.cc
parenta1c29139ccf441ad4de97c4e7fe2729e6130fcb8 (diff)
Actually use ApiDef when generating Python API.
PiperOrigin-RevId: 177851421
Diffstat (limited to 'tensorflow/python/framework/python_op_gen_main.cc')
-rw-r--r--tensorflow/python/framework/python_op_gen_main.cc56
1 files changed, 29 insertions, 27 deletions
diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc
index 61b1d02a5e..bc5ca195da 100644
--- a/tensorflow/python/framework/python_op_gen_main.cc
+++ b/tensorflow/python/framework/python_op_gen_main.cc
@@ -34,12 +34,6 @@ limitations under the License.
namespace tensorflow {
namespace {
-constexpr char kBaseApiDef[] =
- "tensorflow/core/api_def/base_api/*.pbtxt";
-constexpr char kPythonApiDef[] =
- "tensorflow/core/api_def/python_api/*.pbtxt";
-constexpr bool kUseApiDef = false;
-
Status ReadOpListFromFile(const string& filename,
std::vector<string>* op_list) {
std::unique_ptr<RandomAccessFile> file;
@@ -110,22 +104,23 @@ string InferSourceFileName(const char* argv_zero) {
}
void PrintAllPythonOps(const std::vector<string>& op_list,
+ const std::vector<string>& api_def_dirs,
const string& source_file_name, bool require_shapes,
bool op_list_is_whitelist) {
OpList ops;
OpRegistry::Global()->Export(false, &ops);
ApiDefMap api_def_map(ops);
- if (kUseApiDef) {
+ if (!api_def_dirs.empty()) {
Env* env = Env::Default();
- std::vector<string> base_api_files;
- std::vector<string> python_api_files;
- TF_CHECK_OK(env->GetMatchingPaths(kBaseApiDef, &base_api_files));
- TF_CHECK_OK(env->GetMatchingPaths(kPythonApiDef, &python_api_files));
-
- TF_CHECK_OK(api_def_map.LoadFileList(env, base_api_files));
- TF_CHECK_OK(api_def_map.LoadFileList(env, python_api_files));
+ for (const auto& api_def_dir : api_def_dirs) {
+ std::vector<string> api_files;
+ TF_CHECK_OK(env->GetMatchingPaths(io::JoinPath(api_def_dir, "*.pbtxt"),
+ &api_files));
+ TF_CHECK_OK(api_def_map.LoadFileList(env, api_files));
+ }
+ api_def_map.UpdateDocs();
}
if (op_list_is_whitelist) {
@@ -154,23 +149,30 @@ int main(int argc, char* argv[]) {
tensorflow::InferSourceFileName(argv[0]);
// Usage:
- // gen_main [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
- if (argc == 2) {
- tensorflow::PrintAllPythonOps({}, source_file_name,
- tensorflow::string(argv[1]) == "1",
- false /* op_list_is_whitelist */);
- } else if (argc == 3) {
- std::vector<tensorflow::string> hidden_ops;
- TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &hidden_ops));
- tensorflow::PrintAllPythonOps(hidden_ops, source_file_name,
+ // gen_main api_def_dir1,api_def_dir2,...
+ // [ @FILENAME | OpName[,OpName]* ] (0 | 1) [0 | 1]
+ if (argc < 3) {
+ return -1;
+ }
+ std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
+ argv[1], ",", tensorflow::str_util::SkipEmpty());
+
+ if (argc == 3) {
+ tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name,
tensorflow::string(argv[2]) == "1",
false /* op_list_is_whitelist */);
} else if (argc == 4) {
+ std::vector<tensorflow::string> hidden_ops;
+ TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops));
+ tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name,
+ tensorflow::string(argv[3]) == "1",
+ false /* op_list_is_whitelist */);
+ } else if (argc == 5) {
std::vector<tensorflow::string> op_list;
- TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[1], &op_list));
- tensorflow::PrintAllPythonOps(op_list, source_file_name,
- tensorflow::string(argv[2]) == "1",
- tensorflow::string(argv[3]) == "1");
+ TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &op_list));
+ tensorflow::PrintAllPythonOps(op_list, api_def_dirs, source_file_name,
+ tensorflow::string(argv[3]) == "1",
+ tensorflow::string(argv[4]) == "1");
} else {
return -1;
}