diff options
author | 2018-01-03 11:00:21 -0800 | |
---|---|---|
committer | 2018-01-03 11:04:25 -0800 | |
commit | 613a160b55bf23d886e20f7512a79d57b9d2f83f (patch) | |
tree | 16d579355fffc85cecc9ee0e257c6aea30c0e69f /tensorflow/core/api_def/api_test.cc | |
parent | 0f2fa9daa6b36e7dcad0b739ef4d08944e69ecce (diff) |
Automated g4 rollback of changelist 180670333
PiperOrigin-RevId: 180691955
Diffstat (limited to 'tensorflow/core/api_def/api_test.cc')
-rw-r--r-- | tensorflow/core/api_def/api_test.cc | 406 |
1 files changed, 272 insertions, 134 deletions
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index d689bf0480..2cdc14843f 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Test that validates tensorflow/core/api_def/base_api/api_def*.pbtxt files. +// Test that verifies tensorflow/core/api_def/base_api/api_def*.pbtxt files +// are correct. If api_def*.pbtxt do not match expected contents, run +// tensorflow/core/api_def/base_api/update_api_def.sh script to update them. #include <ctype.h> #include <algorithm> @@ -42,173 +44,309 @@ namespace tensorflow { namespace { constexpr char kDefaultApiDefDir[] = "tensorflow/core/api_def/base_api"; +constexpr char kOverridesFilePath[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt"; constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -} // namespace -// Returns a list of ops excluded from ApiDef. -// TODO(annarev): figure out if we should keep ApiDefs for these ops as well. -const std::unordered_set<string>* GetExcludedOps() { - static std::unordered_set<string>* excluded_ops = - new std::unordered_set<string>( - {"BigQueryReader", "GenerateBigQueryReaderPartitions"}); - return excluded_ops; +void FillBaseApiDef(ApiDef* api_def, const OpDef& op) { + api_def->set_graph_op_name(op.name()); + // Add arg docs + for (auto& input_arg : op.input_arg()) { + if (!input_arg.description().empty()) { + auto* api_def_in_arg = api_def->add_in_arg(); + api_def_in_arg->set_name(input_arg.name()); + api_def_in_arg->set_description(input_arg.description()); + } + } + for (auto& output_arg : op.output_arg()) { + if (!output_arg.description().empty()) { + auto* api_def_out_arg = api_def->add_out_arg(); + api_def_out_arg->set_name(output_arg.name()); + api_def_out_arg->set_description(output_arg.description()); + } + } + // Add attr docs + for (auto& attr : op.attr()) { + if (!attr.description().empty()) { + auto* api_def_attr = api_def->add_attr(); + api_def_attr->set_name(attr.name()); + api_def_attr->set_description(attr.description()); + } + } + // Add docs + api_def->set_summary(op.summary()); + api_def->set_description(op.description()); } -// Reads golden ApiDef files and returns a map from file name to ApiDef file -// contents. -void GetGoldenApiDefs(Env* env, const string& api_files_dir, - std::unordered_map<string, ApiDef>* name_to_api_def) { - std::vector<string> matching_paths; - TF_CHECK_OK(env->GetMatchingPaths( - io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths)); - - for (auto& file_path : matching_paths) { - string file_contents; - TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents)); - file_contents = PBTxtFromMultiline(file_contents); - - ApiDefs api_defs; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents, - &api_defs)) - << "Failed to load " << file_path; - CHECK_EQ(api_defs.op_size(), 1); - (*name_to_api_def)[api_defs.op(0).graph_op_name()] = api_defs.op(0); +// Checks if arg1 should be before arg2 according to ordering in args. +bool CheckArgBefore(const ApiDef::Arg* arg1, const ApiDef::Arg* arg2, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) { + for (auto& arg : args) { + if (arg.name() == arg2->name()) { + return false; + } else if (arg.name() == arg1->name()) { + return true; + } } + return false; } -class ApiTest : public ::testing::Test { - protected: - ApiTest() { - OpRegistry::Global()->Export(false, &ops_); - const std::vector<string> multi_line_fields = {"description"}; +// Checks if attr1 should be before attr2 according to ordering in op_def. +bool CheckAttrBefore(const ApiDef::Attr* attr1, const ApiDef::Attr* attr2, + const OpDef& op_def) { + for (auto& attr : op_def.attr()) { + if (attr.name() == attr2->name()) { + return false; + } else if (attr.name() == attr1->name()) { + return true; + } + } + return false; +} - Env* env = Env::Default(); - GetGoldenApiDefs(env, kDefaultApiDefDir, &api_defs_map_); +// Applies renames to args. +void ApplyArgOverrides( + protobuf::RepeatedPtrField<ApiDef::Arg>* args, + const protobuf::RepeatedPtrField<OpGenOverride::Rename>& renames, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& op_args, + const string& op_name) { + for (auto& rename : renames) { + // First check if rename is valid. + bool valid = false; + for (const auto& op_arg : op_args) { + if (op_arg.name() == rename.from()) { + valid = true; + } + } + QCHECK(valid) << rename.from() << " is not a valid argument for " + << op_name; + bool found_arg = false; + // If Arg is already in ApiDef, just update it. + for (int i = 0; i < args->size(); ++i) { + auto* arg = args->Mutable(i); + if (arg->name() == rename.from()) { + arg->set_rename_to(rename.to()); + found_arg = true; + break; + } + } + if (!found_arg) { // not in ApiDef, add a new arg. + auto* new_arg = args->Add(); + new_arg->set_name(rename.from()); + new_arg->set_rename_to(rename.to()); + } } - OpList ops_; - std::unordered_map<string, ApiDef> api_defs_map_; -}; + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(args->pointer_begin(), args->pointer_end(), + [&](ApiDef::Arg* arg1, ApiDef::Arg* arg2) { + return CheckArgBefore(arg1, arg2, op_args); + }); +} -// Check that all ops have an ApiDef. -TEST_F(ApiTest, AllOpsAreInApiDef) { - auto* excluded_ops = GetExcludedOps(); - for (const auto& op : ops_.op()) { - if (excluded_ops->find(op.name()) != excluded_ops->end()) { - continue; +// Returns existing attribute with the given name if such +// attribute exists. Otherwise, adds a new attribute and returns it. +ApiDef::Attr* FindOrAddAttr(ApiDef* api_def, const string attr_name) { + // If Attr is already in ApiDef, just update it. + for (int i = 0; i < api_def->attr_size(); ++i) { + auto* attr = api_def->mutable_attr(i); + if (attr->name() == attr_name) { + return attr; } - ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end()) - << op.name() << " op does not have api_def_*.pbtxt file. " - << "Please add api_def_" << op.name() << ".pbtxt file " - << "under tensorflow/core/api_def/base_api/ directory."; } + // Add a new Attr. + auto* new_attr = api_def->add_attr(); + new_attr->set_name(attr_name); + return new_attr; } -// Check that ApiDefs have a corresponding op. -TEST_F(ApiTest, AllApiDefsHaveCorrespondingOp) { - std::unordered_set<string> op_names; - for (const auto& op : ops_.op()) { - op_names.insert(op.name()); +// Applies renames and default values to attributes. +void ApplyAttrOverrides(ApiDef* api_def, const OpGenOverride& op_override, + const OpDef& op_def) { + for (auto& attr_rename : op_override.attr_rename()) { + auto* attr = FindOrAddAttr(api_def, attr_rename.from()); + attr->set_rename_to(attr_rename.to()); } - for (const auto& name_and_api_def : api_defs_map_) { - ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) - << name_and_api_def.first << " op has ApiDef but missing from ops. " - << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; + + for (auto& attr_default : op_override.attr_default()) { + auto* attr = FindOrAddAttr(api_def, attr_default.name()); + *(attr->mutable_default_value()) = attr_default.value(); } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(api_def->mutable_attr()->pointer_begin(), + api_def->mutable_attr()->pointer_end(), + [&](ApiDef::Attr* attr1, ApiDef::Attr* attr2) { + return CheckAttrBefore(attr1, attr2, op_def); + }); } -string GetOpDefHasDocStringError(const string& op_name) { - return strings::Printf( - "OpDef for %s has a doc string. " - "Doc strings must be defined in ApiDef instead of OpDef. " - "Please, add summary and descriptions in api_def_%s" - ".pbtxt file instead", - op_name.c_str(), op_name.c_str()); +void ApplyOverridesToApiDef(ApiDef* api_def, const OpDef& op, + const OpGenOverride& op_override) { + // Fill ApiDef with data based on op and op_override. + // Set visibility + if (op_override.skip()) { + api_def->set_visibility(ApiDef_Visibility_SKIP); + } else if (op_override.hide()) { + api_def->set_visibility(ApiDef_Visibility_HIDDEN); + } + // Add endpoints + if (!op_override.rename_to().empty()) { + api_def->add_endpoint()->set_name(op_override.rename_to()); + } else if (!op_override.alias().empty()) { + api_def->add_endpoint()->set_name(op.name()); + } + + for (auto& alias : op_override.alias()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(alias); + } + + ApplyArgOverrides(api_def->mutable_in_arg(), op_override.input_rename(), + op.input_arg(), api_def->graph_op_name()); + ApplyArgOverrides(api_def->mutable_out_arg(), op_override.output_rename(), + op.output_arg(), api_def->graph_op_name()); + ApplyAttrOverrides(api_def, op_override, op); } -// Check that OpDef's do not have descriptions and summaries. -// Descriptions and summaries must be in corresponding ApiDefs. -TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { - auto* excluded_ops = GetExcludedOps(); - for (const auto& op : ops_.op()) { - if (excluded_ops->find(op.name()) != excluded_ops->end()) { +// Get map from ApiDef file path to corresponding ApiDefs proto. +std::unordered_map<string, ApiDefs> GenerateApiDef( + const string& api_def_dir, const OpList& ops, + const OpGenOverrides& overrides) { + std::unordered_map<string, OpGenOverride> name_to_override; + for (const auto& op_override : overrides.op()) { + name_to_override[op_override.name()] = op_override; + } + + std::unordered_map<string, ApiDefs> api_defs_map; + + // These ops are included in OpList only if TF_NEED_GCP + // is set to true. So, we skip them for now so that this test passes + // whether TF_NEED_GCP is set or not. + const std::unordered_set<string> ops_to_exclude = { + "BigQueryReader", "GenerateBigQueryReaderPartitions"}; + for (const auto& op : ops.op()) { + CHECK(!op.name().empty()) + << "Encountered empty op name: %s" << op.DebugString(); + if (ops_to_exclude.find(op.name()) != ops_to_exclude.end()) { + LOG(INFO) << "Skipping " << op.name(); continue; } - ASSERT_TRUE(op.summary().empty()) << GetOpDefHasDocStringError(op.name()); - ASSERT_TRUE(op.description().empty()) - << GetOpDefHasDocStringError(op.name()); - for (const auto& arg : op.input_arg()) { - ASSERT_TRUE(arg.description().empty()) - << GetOpDefHasDocStringError(op.name()); - } - for (const auto& arg : op.output_arg()) { - ASSERT_TRUE(arg.description().empty()) - << GetOpDefHasDocStringError(op.name()); - } - for (const auto& attr : op.attr()) { - ASSERT_TRUE(attr.description().empty()) - << GetOpDefHasDocStringError(op.name()); + string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat); + file_path = strings::Printf(file_path.c_str(), op.name().c_str()); + ApiDef* api_def = api_defs_map[file_path].add_op(); + FillBaseApiDef(api_def, op); + + if (name_to_override.find(op.name()) != name_to_override.end()) { + ApplyOverridesToApiDef(api_def, op, name_to_override[op.name()]); } } + return api_defs_map; } -// Checks that input arg names in an ApiDef match input -// arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefInputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.in_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.input_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Input argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } +// Reads golden ApiDef files and returns a map from file name to ApiDef file +// contents. +std::unordered_map<string, string> GetGoldenApiDefs( + Env* env, const string& api_files_dir) { + std::vector<string> matching_paths; + TF_CHECK_OK(env->GetMatchingPaths( + io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths)); + + std::unordered_map<string, string> file_path_to_api_def; + for (auto& file_path : matching_paths) { + string file_contents; + TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents)); + file_path_to_api_def[file_path] = file_contents; } + return file_path_to_api_def; } -// Checks that output arg names in an ApiDef match output -// arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefOutputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.out_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.output_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Output argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); +void RunApiTest(bool update_api_def, const string& api_files_dir) { + // Read C++ overrides file + OpGenOverrides overrides; + Env* env = Env::Default(); + TF_EXPECT_OK(ReadTextProto(env, kOverridesFilePath, &overrides)); + + // Read all ops + OpList ops; + OpRegistry::Global()->Export(false, &ops); + const std::vector<string> multi_line_fields = {"description"}; + + // Get expected ApiDefs + const auto new_api_defs_map = GenerateApiDef(api_files_dir, ops, overrides); + + bool updated_at_least_one_file = false; + const auto golden_api_defs_map = GetGoldenApiDefs(env, api_files_dir); + + for (auto new_api_entry : new_api_defs_map) { + const auto& file_path = new_api_entry.first; + std::string golden_api_defs_str = ""; + if (golden_api_defs_map.find(file_path) != golden_api_defs_map.end()) { + golden_api_defs_str = golden_api_defs_map.at(file_path); + } + string new_api_defs_str = new_api_entry.second.DebugString(); + new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); + if (golden_api_defs_str == new_api_defs_str) { + continue; + } + if (update_api_def) { + std::cout << "Updating " << file_path << "..." << std::endl; + TF_EXPECT_OK(WriteStringToFile(env, file_path, new_api_defs_str)); + updated_at_least_one_file = true; + } else { + EXPECT_EQ(golden_api_defs_str, new_api_defs_str) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; } } -} -// Checks that attribute names in an ApiDef match attribute -// names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefAttributeNamesAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_attr : api_def.attr()) { - bool found_attr = false; - for (const auto& op_attr : op.attr()) { - if (api_def_attr.name() == op_attr.name()) { - found_attr = true; - } + for (const auto& golden_api_entry : golden_api_defs_map) { + const auto& file_path = golden_api_entry.first; + if (new_api_defs_map.find(file_path) == new_api_defs_map.end()) { + if (update_api_def) { + std::cout << "Deleting " << file_path << "..." << std::endl; + TF_EXPECT_OK(env->DeleteFile(file_path)); + updated_at_least_one_file = true; + } else { + EXPECT_EQ("", golden_api_entry.second) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; } - ASSERT_TRUE(found_attr) - << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" - << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); } } + + if (update_api_def && !updated_at_least_one_file) { + std::cout << "Api def files are already up to date." << std::endl; + } } + +TEST(ApiTest, GenerateBaseAPIDef) { RunApiTest(false, kDefaultApiDefDir); } +} // namespace } // namespace tensorflow + +int main(int argc, char** argv) { + bool update_api_def = false; + tensorflow::string api_files_dir = tensorflow::kDefaultApiDefDir; + std::vector<tensorflow::Flag> flag_list = { + tensorflow::Flag( + "update_api_def", &update_api_def, + "Whether to update tensorflow/core/api_def/base_api/api_def*.pbtxt " + "files if they differ from expected API."), + tensorflow::Flag("api_def_dir", &api_files_dir, + "Base directory of api_def*.pbtxt files.")}; + std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parsed_values_ok) { + std::cerr << usage << std::endl; + return 2; + } + if (update_api_def) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::RunApiTest(update_api_def, api_files_dir); + return 0; + } + testing::InitGoogleTest(&argc, argv); + // Run tests + return RUN_ALL_TESTS(); +} |