aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/api_def/api_test.cc
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-01-03 11:00:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-03 11:04:25 -0800
commit613a160b55bf23d886e20f7512a79d57b9d2f83f (patch)
tree16d579355fffc85cecc9ee0e257c6aea30c0e69f /tensorflow/core/api_def/api_test.cc
parent0f2fa9daa6b36e7dcad0b739ef4d08944e69ecce (diff)
Automated g4 rollback of changelist 180670333
PiperOrigin-RevId: 180691955
Diffstat (limited to 'tensorflow/core/api_def/api_test.cc')
-rw-r--r--tensorflow/core/api_def/api_test.cc406
1 files changed, 272 insertions, 134 deletions
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc
index d689bf0480..2cdc14843f 100644
--- a/tensorflow/core/api_def/api_test.cc
+++ b/tensorflow/core/api_def/api_test.cc
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Test that validates tensorflow/core/api_def/base_api/api_def*.pbtxt files.
+// Test that verifies tensorflow/core/api_def/base_api/api_def*.pbtxt files
+// are correct. If api_def*.pbtxt do not match expected contents, run
+// tensorflow/core/api_def/base_api/update_api_def.sh script to update them.
#include <ctype.h>
#include <algorithm>
@@ -42,173 +44,309 @@ namespace tensorflow {
namespace {
constexpr char kDefaultApiDefDir[] =
"tensorflow/core/api_def/base_api";
+constexpr char kOverridesFilePath[] =
+ "tensorflow/cc/ops/op_gen_overrides.pbtxt";
+constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt";
constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";
-} // namespace
-// Returns a list of ops excluded from ApiDef.
-// TODO(annarev): figure out if we should keep ApiDefs for these ops as well.
-const std::unordered_set<string>* GetExcludedOps() {
- static std::unordered_set<string>* excluded_ops =
- new std::unordered_set<string>(
- {"BigQueryReader", "GenerateBigQueryReaderPartitions"});
- return excluded_ops;
+void FillBaseApiDef(ApiDef* api_def, const OpDef& op) {
+ api_def->set_graph_op_name(op.name());
+ // Add arg docs
+ for (auto& input_arg : op.input_arg()) {
+ if (!input_arg.description().empty()) {
+ auto* api_def_in_arg = api_def->add_in_arg();
+ api_def_in_arg->set_name(input_arg.name());
+ api_def_in_arg->set_description(input_arg.description());
+ }
+ }
+ for (auto& output_arg : op.output_arg()) {
+ if (!output_arg.description().empty()) {
+ auto* api_def_out_arg = api_def->add_out_arg();
+ api_def_out_arg->set_name(output_arg.name());
+ api_def_out_arg->set_description(output_arg.description());
+ }
+ }
+ // Add attr docs
+ for (auto& attr : op.attr()) {
+ if (!attr.description().empty()) {
+ auto* api_def_attr = api_def->add_attr();
+ api_def_attr->set_name(attr.name());
+ api_def_attr->set_description(attr.description());
+ }
+ }
+ // Add docs
+ api_def->set_summary(op.summary());
+ api_def->set_description(op.description());
}
-// Reads golden ApiDef files and returns a map from file name to ApiDef file
-// contents.
-void GetGoldenApiDefs(Env* env, const string& api_files_dir,
- std::unordered_map<string, ApiDef>* name_to_api_def) {
- std::vector<string> matching_paths;
- TF_CHECK_OK(env->GetMatchingPaths(
- io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths));
-
- for (auto& file_path : matching_paths) {
- string file_contents;
- TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents));
- file_contents = PBTxtFromMultiline(file_contents);
-
- ApiDefs api_defs;
- CHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents,
- &api_defs))
- << "Failed to load " << file_path;
- CHECK_EQ(api_defs.op_size(), 1);
- (*name_to_api_def)[api_defs.op(0).graph_op_name()] = api_defs.op(0);
+// Checks if arg1 should be before arg2 according to ordering in args.
+bool CheckArgBefore(const ApiDef::Arg* arg1, const ApiDef::Arg* arg2,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args) {
+ for (auto& arg : args) {
+ if (arg.name() == arg2->name()) {
+ return false;
+ } else if (arg.name() == arg1->name()) {
+ return true;
+ }
}
+ return false;
}
-class ApiTest : public ::testing::Test {
- protected:
- ApiTest() {
- OpRegistry::Global()->Export(false, &ops_);
- const std::vector<string> multi_line_fields = {"description"};
+// Checks if attr1 should be before attr2 according to ordering in op_def.
+bool CheckAttrBefore(const ApiDef::Attr* attr1, const ApiDef::Attr* attr2,
+ const OpDef& op_def) {
+ for (auto& attr : op_def.attr()) {
+ if (attr.name() == attr2->name()) {
+ return false;
+ } else if (attr.name() == attr1->name()) {
+ return true;
+ }
+ }
+ return false;
+}
- Env* env = Env::Default();
- GetGoldenApiDefs(env, kDefaultApiDefDir, &api_defs_map_);
+// Applies renames to args.
+void ApplyArgOverrides(
+ protobuf::RepeatedPtrField<ApiDef::Arg>* args,
+ const protobuf::RepeatedPtrField<OpGenOverride::Rename>& renames,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& op_args,
+ const string& op_name) {
+ for (auto& rename : renames) {
+ // First check if rename is valid.
+ bool valid = false;
+ for (const auto& op_arg : op_args) {
+ if (op_arg.name() == rename.from()) {
+ valid = true;
+ }
+ }
+ QCHECK(valid) << rename.from() << " is not a valid argument for "
+ << op_name;
+ bool found_arg = false;
+ // If Arg is already in ApiDef, just update it.
+ for (int i = 0; i < args->size(); ++i) {
+ auto* arg = args->Mutable(i);
+ if (arg->name() == rename.from()) {
+ arg->set_rename_to(rename.to());
+ found_arg = true;
+ break;
+ }
+ }
+ if (!found_arg) { // not in ApiDef, add a new arg.
+ auto* new_arg = args->Add();
+ new_arg->set_name(rename.from());
+ new_arg->set_rename_to(rename.to());
+ }
}
- OpList ops_;
- std::unordered_map<string, ApiDef> api_defs_map_;
-};
+ // We don't really need a specific order here right now.
+ // However, it is clearer if order follows OpDef.
+ std::sort(args->pointer_begin(), args->pointer_end(),
+ [&](ApiDef::Arg* arg1, ApiDef::Arg* arg2) {
+ return CheckArgBefore(arg1, arg2, op_args);
+ });
+}
-// Check that all ops have an ApiDef.
-TEST_F(ApiTest, AllOpsAreInApiDef) {
- auto* excluded_ops = GetExcludedOps();
- for (const auto& op : ops_.op()) {
- if (excluded_ops->find(op.name()) != excluded_ops->end()) {
- continue;
+// Returns existing attribute with the given name if such
+// attribute exists. Otherwise, adds a new attribute and returns it.
+ApiDef::Attr* FindOrAddAttr(ApiDef* api_def, const string attr_name) {
+ // If Attr is already in ApiDef, just update it.
+ for (int i = 0; i < api_def->attr_size(); ++i) {
+ auto* attr = api_def->mutable_attr(i);
+ if (attr->name() == attr_name) {
+ return attr;
}
- ASSERT_TRUE(api_defs_map_.find(op.name()) != api_defs_map_.end())
- << op.name() << " op does not have api_def_*.pbtxt file. "
- << "Please add api_def_" << op.name() << ".pbtxt file "
- << "under tensorflow/core/api_def/base_api/ directory.";
}
+ // Add a new Attr.
+ auto* new_attr = api_def->add_attr();
+ new_attr->set_name(attr_name);
+ return new_attr;
}
-// Check that ApiDefs have a corresponding op.
-TEST_F(ApiTest, AllApiDefsHaveCorrespondingOp) {
- std::unordered_set<string> op_names;
- for (const auto& op : ops_.op()) {
- op_names.insert(op.name());
+// Applies renames and default values to attributes.
+void ApplyAttrOverrides(ApiDef* api_def, const OpGenOverride& op_override,
+ const OpDef& op_def) {
+ for (auto& attr_rename : op_override.attr_rename()) {
+ auto* attr = FindOrAddAttr(api_def, attr_rename.from());
+ attr->set_rename_to(attr_rename.to());
}
- for (const auto& name_and_api_def : api_defs_map_) {
- ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end())
- << name_and_api_def.first << " op has ApiDef but missing from ops. "
- << "Does api_def_" << name_and_api_def.first << " need to be deleted?";
+
+ for (auto& attr_default : op_override.attr_default()) {
+ auto* attr = FindOrAddAttr(api_def, attr_default.name());
+ *(attr->mutable_default_value()) = attr_default.value();
}
+ // We don't really need a specific order here right now.
+ // However, it is clearer if order follows OpDef.
+ std::sort(api_def->mutable_attr()->pointer_begin(),
+ api_def->mutable_attr()->pointer_end(),
+ [&](ApiDef::Attr* attr1, ApiDef::Attr* attr2) {
+ return CheckAttrBefore(attr1, attr2, op_def);
+ });
}
-string GetOpDefHasDocStringError(const string& op_name) {
- return strings::Printf(
- "OpDef for %s has a doc string. "
- "Doc strings must be defined in ApiDef instead of OpDef. "
- "Please, add summary and descriptions in api_def_%s"
- ".pbtxt file instead",
- op_name.c_str(), op_name.c_str());
+void ApplyOverridesToApiDef(ApiDef* api_def, const OpDef& op,
+ const OpGenOverride& op_override) {
+ // Fill ApiDef with data based on op and op_override.
+ // Set visibility
+ if (op_override.skip()) {
+ api_def->set_visibility(ApiDef_Visibility_SKIP);
+ } else if (op_override.hide()) {
+ api_def->set_visibility(ApiDef_Visibility_HIDDEN);
+ }
+ // Add endpoints
+ if (!op_override.rename_to().empty()) {
+ api_def->add_endpoint()->set_name(op_override.rename_to());
+ } else if (!op_override.alias().empty()) {
+ api_def->add_endpoint()->set_name(op.name());
+ }
+
+ for (auto& alias : op_override.alias()) {
+ auto* endpoint = api_def->add_endpoint();
+ endpoint->set_name(alias);
+ }
+
+ ApplyArgOverrides(api_def->mutable_in_arg(), op_override.input_rename(),
+ op.input_arg(), api_def->graph_op_name());
+ ApplyArgOverrides(api_def->mutable_out_arg(), op_override.output_rename(),
+ op.output_arg(), api_def->graph_op_name());
+ ApplyAttrOverrides(api_def, op_override, op);
}
-// Check that OpDef's do not have descriptions and summaries.
-// Descriptions and summaries must be in corresponding ApiDefs.
-TEST_F(ApiTest, OpDefsShouldNotHaveDocs) {
- auto* excluded_ops = GetExcludedOps();
- for (const auto& op : ops_.op()) {
- if (excluded_ops->find(op.name()) != excluded_ops->end()) {
+// Get map from ApiDef file path to corresponding ApiDefs proto.
+std::unordered_map<string, ApiDefs> GenerateApiDef(
+ const string& api_def_dir, const OpList& ops,
+ const OpGenOverrides& overrides) {
+ std::unordered_map<string, OpGenOverride> name_to_override;
+ for (const auto& op_override : overrides.op()) {
+ name_to_override[op_override.name()] = op_override;
+ }
+
+ std::unordered_map<string, ApiDefs> api_defs_map;
+
+ // These ops are included in OpList only if TF_NEED_GCP
+ // is set to true. So, we skip them for now so that this test passes
+ // whether TF_NEED_GCP is set or not.
+ const std::unordered_set<string> ops_to_exclude = {
+ "BigQueryReader", "GenerateBigQueryReaderPartitions"};
+ for (const auto& op : ops.op()) {
+ CHECK(!op.name().empty())
+ << "Encountered empty op name: %s" << op.DebugString();
+ if (ops_to_exclude.find(op.name()) != ops_to_exclude.end()) {
+ LOG(INFO) << "Skipping " << op.name();
continue;
}
- ASSERT_TRUE(op.summary().empty()) << GetOpDefHasDocStringError(op.name());
- ASSERT_TRUE(op.description().empty())
- << GetOpDefHasDocStringError(op.name());
- for (const auto& arg : op.input_arg()) {
- ASSERT_TRUE(arg.description().empty())
- << GetOpDefHasDocStringError(op.name());
- }
- for (const auto& arg : op.output_arg()) {
- ASSERT_TRUE(arg.description().empty())
- << GetOpDefHasDocStringError(op.name());
- }
- for (const auto& attr : op.attr()) {
- ASSERT_TRUE(attr.description().empty())
- << GetOpDefHasDocStringError(op.name());
+ string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat);
+ file_path = strings::Printf(file_path.c_str(), op.name().c_str());
+ ApiDef* api_def = api_defs_map[file_path].add_op();
+ FillBaseApiDef(api_def, op);
+
+ if (name_to_override.find(op.name()) != name_to_override.end()) {
+ ApplyOverridesToApiDef(api_def, op, name_to_override[op.name()]);
}
}
+ return api_defs_map;
}
-// Checks that input arg names in an ApiDef match input
-// arg names in corresponding OpDef.
-TEST_F(ApiTest, AllApiDefInputArgsAreValid) {
- for (const auto& op : ops_.op()) {
- const auto& api_def = api_defs_map_[op.name()];
- for (const auto& api_def_arg : api_def.in_arg()) {
- bool found_arg = false;
- for (const auto& op_arg : op.input_arg()) {
- if (api_def_arg.name() == op_arg.name()) {
- found_arg = true;
- break;
- }
- }
- ASSERT_TRUE(found_arg)
- << "Input argument " << api_def_arg.name()
- << " (overwritten in api_def_" << op.name()
- << ".pbtxt) is not defined in OpDef for " << op.name();
- }
+// Reads golden ApiDef files and returns a map from file name to ApiDef file
+// contents.
+std::unordered_map<string, string> GetGoldenApiDefs(
+ Env* env, const string& api_files_dir) {
+ std::vector<string> matching_paths;
+ TF_CHECK_OK(env->GetMatchingPaths(
+ io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths));
+
+ std::unordered_map<string, string> file_path_to_api_def;
+ for (auto& file_path : matching_paths) {
+ string file_contents;
+ TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents));
+ file_path_to_api_def[file_path] = file_contents;
}
+ return file_path_to_api_def;
}
-// Checks that output arg names in an ApiDef match output
-// arg names in corresponding OpDef.
-TEST_F(ApiTest, AllApiDefOutputArgsAreValid) {
- for (const auto& op : ops_.op()) {
- const auto& api_def = api_defs_map_[op.name()];
- for (const auto& api_def_arg : api_def.out_arg()) {
- bool found_arg = false;
- for (const auto& op_arg : op.output_arg()) {
- if (api_def_arg.name() == op_arg.name()) {
- found_arg = true;
- break;
- }
- }
- ASSERT_TRUE(found_arg)
- << "Output argument " << api_def_arg.name()
- << " (overwritten in api_def_" << op.name()
- << ".pbtxt) is not defined in OpDef for " << op.name();
+void RunApiTest(bool update_api_def, const string& api_files_dir) {
+ // Read C++ overrides file
+ OpGenOverrides overrides;
+ Env* env = Env::Default();
+ TF_EXPECT_OK(ReadTextProto(env, kOverridesFilePath, &overrides));
+
+ // Read all ops
+ OpList ops;
+ OpRegistry::Global()->Export(false, &ops);
+ const std::vector<string> multi_line_fields = {"description"};
+
+ // Get expected ApiDefs
+ const auto new_api_defs_map = GenerateApiDef(api_files_dir, ops, overrides);
+
+ bool updated_at_least_one_file = false;
+ const auto golden_api_defs_map = GetGoldenApiDefs(env, api_files_dir);
+
+ for (auto new_api_entry : new_api_defs_map) {
+ const auto& file_path = new_api_entry.first;
+ std::string golden_api_defs_str = "";
+ if (golden_api_defs_map.find(file_path) != golden_api_defs_map.end()) {
+ golden_api_defs_str = golden_api_defs_map.at(file_path);
+ }
+ string new_api_defs_str = new_api_entry.second.DebugString();
+ new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields);
+ if (golden_api_defs_str == new_api_defs_str) {
+ continue;
+ }
+ if (update_api_def) {
+ std::cout << "Updating " << file_path << "..." << std::endl;
+ TF_EXPECT_OK(WriteStringToFile(env, file_path, new_api_defs_str));
+ updated_at_least_one_file = true;
+ } else {
+ EXPECT_EQ(golden_api_defs_str, new_api_defs_str)
+ << "To update golden API files, run "
+ << "tensorflow/core/api_def/update_api_def.sh.";
}
}
-}
-// Checks that attribute names in an ApiDef match attribute
-// names in corresponding OpDef.
-TEST_F(ApiTest, AllApiDefAttributeNamesAreValid) {
- for (const auto& op : ops_.op()) {
- const auto& api_def = api_defs_map_[op.name()];
- for (const auto& api_def_attr : api_def.attr()) {
- bool found_attr = false;
- for (const auto& op_attr : op.attr()) {
- if (api_def_attr.name() == op_attr.name()) {
- found_attr = true;
- }
+ for (const auto& golden_api_entry : golden_api_defs_map) {
+ const auto& file_path = golden_api_entry.first;
+ if (new_api_defs_map.find(file_path) == new_api_defs_map.end()) {
+ if (update_api_def) {
+ std::cout << "Deleting " << file_path << "..." << std::endl;
+ TF_EXPECT_OK(env->DeleteFile(file_path));
+ updated_at_least_one_file = true;
+ } else {
+ EXPECT_EQ("", golden_api_entry.second)
+ << "To update golden API files, run "
+ << "tensorflow/core/api_def/update_api_def.sh.";
}
- ASSERT_TRUE(found_attr)
- << "Attribute " << api_def_attr.name() << " (overwritten in api_def_"
- << op.name() << ".pbtxt) is not defined in OpDef for " << op.name();
}
}
+
+ if (update_api_def && !updated_at_least_one_file) {
+ std::cout << "Api def files are already up to date." << std::endl;
+ }
}
+
+TEST(ApiTest, GenerateBaseAPIDef) { RunApiTest(false, kDefaultApiDefDir); }
+} // namespace
} // namespace tensorflow
+
+int main(int argc, char** argv) {
+ bool update_api_def = false;
+ tensorflow::string api_files_dir = tensorflow::kDefaultApiDefDir;
+ std::vector<tensorflow::Flag> flag_list = {
+ tensorflow::Flag(
+ "update_api_def", &update_api_def,
+ "Whether to update tensorflow/core/api_def/base_api/api_def*.pbtxt "
+ "files if they differ from expected API."),
+ tensorflow::Flag("api_def_dir", &api_files_dir,
+ "Base directory of api_def*.pbtxt files.")};
+ std::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
+ bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+ if (!parsed_values_ok) {
+ std::cerr << usage << std::endl;
+ return 2;
+ }
+ if (update_api_def) {
+ tensorflow::port::InitMain(argv[0], &argc, &argv);
+ tensorflow::RunApiTest(update_api_def, api_files_dir);
+ return 0;
+ }
+ testing::InitGoogleTest(&argc, argv);
+ // Run tests
+ return RUN_ALL_TESTS();
+}