aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/framework
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2017-11-17 15:20:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-17 15:27:48 -0800
commitcb12ebe044ad8fb8515bc9d95d27c0ab19ec314b (patch)
treebfb473429b96efca65ac39a7b0bc0476c7ce1551 /tensorflow/cc/framework
parent3cc43816cda27c497399bf94429b174db5ed6d6b (diff)
Provide an option to use ApiDef instead of OpGenOverrides when generating C++ API. Also, updating UpdateDocs method to ApiDef to replace names in docs.
PiperOrigin-RevId: 176167953
Diffstat (limited to 'tensorflow/cc/framework')
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc218
-rw-r--r--tensorflow/cc/framework/cc_op_gen.h6
-rw-r--r--tensorflow/cc/framework/cc_op_gen_main.cc38
-rw-r--r--tensorflow/cc/framework/cc_op_gen_test.cc195
4 files changed, 357 insertions, 100 deletions
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index 38a17598b8..6f2b7acb82 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include <vector>
#include "tensorflow/cc/framework/cc_op_gen.h"
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/framework/op_gen_overrides.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
@@ -385,10 +387,10 @@ bool ArgIsList(const OpDef::ArgDef& arg) {
}
bool HasOptionalAttrs(
- const OpDef& op_def,
+ const ApiDef& api_def,
const std::unordered_map<string, string>& inferred_input_attrs) {
- for (int i = 0; i < op_def.attr_size(); ++i) {
- const auto& attr(op_def.attr(i));
+ for (int i = 0; i < api_def.attr_size(); ++i) {
+ const auto& attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) ==
inferred_input_attrs.end()) &&
attr.has_default_value()) {
@@ -398,12 +400,21 @@ bool HasOptionalAttrs(
return false;
}
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
+ }
+ return nullptr;
+}
+
struct OpInfo {
// graph_op_def: The OpDef used by the runtime, has the names that
// must be used when calling NodeBuilder.
// interface_op_def: The OpDef used in the interface in the generated
// code, with possibly overridden names and defaults.
- explicit OpInfo(const OpDef& graph_op_def, const OpDef& inteface_op_def,
+ explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases);
string GetOpAttrStruct() const;
string GetConstructorDecl(StringPiece op_name_prefix,
@@ -423,74 +434,81 @@ struct OpInfo {
string comment;
const OpDef& graph_op_def;
- const OpDef& op_def;
+ const ApiDef& api_def;
const std::vector<string>& aliases;
+ // Map from type attribute to corresponding original argument name.
std::unordered_map<string, string> inferred_input_attrs;
};
-OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
- const std::vector<string>& a)
- : graph_op_def(g_op_def), op_def(i_op_def), aliases(a) {
- op_name = op_def.name();
- InferOpAttributes(op_def, &inferred_input_attrs);
- has_optional_attrs = HasOptionalAttrs(op_def, inferred_input_attrs);
+OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def,
+ const std::vector<string>& aliases)
+ : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) {
+ op_name = api_def.endpoint(0).name();
+ InferOpAttributes(graph_op_def, &inferred_input_attrs);
+ has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs);
arg_types.push_back("const ::tensorflow::Scope&");
arg_names.push_back("scope");
- if (op_def.has_deprecation()) {
- if (!op_def.summary().empty()) {
- comment = strings::StrCat(op_def.summary(), "\n");
+ if (graph_op_def.has_deprecation()) {
+ if (!api_def.summary().empty()) {
+ comment = strings::StrCat(api_def.summary(), "\n");
}
strings::StrAppend(&comment, "DEPRECATED at GraphDef version ",
- op_def.deprecation().version(), ":\n",
- op_def.deprecation().explanation(), ".\n");
- } else if (op_def.summary().empty()) {
+ graph_op_def.deprecation().version(), ":\n",
+ graph_op_def.deprecation().explanation(), ".\n");
+ } else if (api_def.summary().empty()) {
comment = "TODO: add doc.\n";
} else {
- comment = strings::StrCat(op_def.summary(), "\n");
+ comment = strings::StrCat(api_def.summary(), "\n");
}
- if (!op_def.description().empty()) {
- strings::StrAppend(&comment, "\n", op_def.description(), "\n");
+ if (!api_def.description().empty()) {
+ strings::StrAppend(&comment, "\n", api_def.description(), "\n");
}
strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n");
// Process inputs
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- const auto& arg(op_def.input_arg(i));
+ for (int i = 0; i < api_def.arg_order_size(); ++i) {
+ const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def);
+ const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def);
arg_types.push_back(strings::StrCat(
"::tensorflow::", ArgIsList(arg) ? "InputList" : "Input"));
- arg_names.push_back(AvoidCPPKeywords(arg.name()));
+ arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
// TODO(keveman): Include input type information.
- StringPiece description = arg.description();
+ StringPiece description = api_def_arg.description();
if (!description.empty()) {
ConsumeEquals(&description);
- strings::StrAppend(&comment, "* ", AvoidCPPKeywords(arg.name()), ": ",
- arg.description(), "\n");
+ strings::StrAppend(&comment, "* ",
+ AvoidCPPKeywords(api_def_arg.rename_to()), ": ",
+ api_def_arg.description(), "\n");
}
}
// Process attrs
string required_attrs_comment;
string optional_attrs_comment;
- for (int i = 0; i < op_def.attr_size(); ++i) {
- const auto& attr(op_def.attr(i));
+ for (int i = 0; i < graph_op_def.attr_size(); ++i) {
+ // ApiDef attributes must be in the same order as in OpDef since
+ // we initialize ApiDef based on OpDef.
+ const auto& attr(graph_op_def.attr(i));
+ const auto& api_def_attr(api_def.attr(i));
+ CHECK_EQ(attr.name(), api_def_attr.name());
// Skip inferred arguments
if (inferred_input_attrs.count(attr.name()) > 0) continue;
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
- string attr_name = AvoidCPPKeywords(attr.name());
+ string attr_name = AvoidCPPKeywords(api_def_attr.rename_to());
string attr_comment;
- if (!attr.description().empty()) {
+ if (!api_def_attr.description().empty()) {
// TODO(keveman): Word wrap and indent this, to handle multi-line
// descriptions.
strings::StrAppend(&attr_comment, "* ", attr_name, ": ",
- attr.description(), "\n");
+ api_def_attr.description(), "\n");
}
- if (attr.has_default_value()) {
+ if (api_def_attr.has_default_value()) {
strings::StrAppend(&optional_attrs_comment, attr_comment);
} else {
strings::StrAppend(&required_attrs_comment, attr_comment);
@@ -508,44 +526,49 @@ OpInfo::OpInfo(const OpDef& g_op_def, const OpDef& i_op_def,
}
// Process outputs
- for (int i = 0; i < op_def.output_arg_size(); ++i) {
- const auto& arg = op_def.output_arg(i);
+ for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
+ // ApiDef arguments must be in the same order as in OpDef since
+ // we initialize ApiDef based on OpDef.
+ const auto& arg = graph_op_def.output_arg(i);
+ const auto& api_def_arg(api_def.out_arg(i));
+ CHECK_EQ(arg.name(), api_def_arg.name());
+
bool is_list = ArgIsList(arg);
output_types.push_back(
strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output"));
- output_names.push_back(AvoidCPPKeywords(arg.name()));
+ output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to()));
is_list_output.push_back(is_list);
}
strings::StrAppend(&comment, "\nReturns:\n");
- if (op_def.output_arg_size() == 0) { // No outputs.
+ if (graph_op_def.output_arg_size() == 0) { // No outputs.
strings::StrAppend(&comment, "* the created `Operation`\n");
- } else if (op_def.output_arg_size() == 1) { // One output
+ } else if (graph_op_def.output_arg_size() == 1) { // One output
if (is_list_output[0]) {
strings::StrAppend(&comment, "* `OutputList`: ");
} else {
strings::StrAppend(&comment, "* `Output`: ");
}
- if (op_def.output_arg(0).description().empty()) {
- strings::StrAppend(&comment, "The ", op_def.output_arg(0).name(),
+ if (api_def.out_arg(0).description().empty()) {
+ strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(),
" tensor.\n");
} else {
// TODO(josh11b): Word wrap this.
- strings::StrAppend(&comment, op_def.output_arg(0).description(), "\n");
+ strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n");
}
} else { // Multiple outputs.
- for (int i = 0; i < op_def.output_arg_size(); ++i) {
+ for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
if (is_list_output[i]) {
strings::StrAppend(&comment, "* `OutputList`");
} else {
strings::StrAppend(&comment, "* `Output`");
}
strings::StrAppend(&comment, " ", output_names[i]);
- if (op_def.output_arg(i).description().empty()) {
+ if (api_def.out_arg(i).description().empty()) {
strings::StrAppend(&comment, "\n");
} else {
// TODO(josh11b): Word wrap this.
- strings::StrAppend(&comment, ": ", op_def.output_arg(i).description(),
+ strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(),
"\n");
}
}
@@ -564,19 +587,20 @@ string OpInfo::GetOpAttrStruct() const {
string struct_fields;
string setters;
- for (int i = 0; i < op_def.attr_size(); ++i) {
- const auto& attr(op_def.attr(i));
+ for (int i = 0; i < graph_op_def.attr_size(); ++i) {
+ const auto& attr(graph_op_def.attr(i));
+ const auto& api_def_attr(api_def.attr(i));
// If attr will be inferred or it doesn't have a default value, don't
// add it to the struct.
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
- !attr.has_default_value()) {
+ !api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
- const string camel_case_name = ToCamelCase(attr.name());
+ const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def =
@@ -584,22 +608,25 @@ string OpInfo::GetOpAttrStruct() const {
attr_type_name, use_const ? "&" : "");
string attr_comment;
- if (!attr.description().empty()) {
- strings::StrAppend(&attr_comment, attr.description(), "\n\n");
+ if (!api_def_attr.description().empty()) {
+ strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n");
}
strings::StrAppend(&attr_comment, "Defaults to ",
- SummarizeAttrValue(attr.default_value()), "\n");
+ SummarizeAttrValue(api_def_attr.default_value()), "\n");
attr_comment = MakeComment(attr_comment, " ");
strings::StrAppend(&setters, attr_comment);
strings::StrAppend(&setters, " Attrs ", attr_func_def, " x) {\n");
strings::StrAppend(&setters, " Attrs ret = *this;\n");
- strings::StrAppend(&setters, " ret.", attr.name(), "_ = x;\n");
+ strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(),
+ "_ = x;\n");
strings::StrAppend(&setters, " return ret;\n }\n\n");
strings::StrAppend(
- &struct_fields, " ", attr_type_name, " ", attr.name(), "_ = ",
- PrintAttrValue(op_def.name(), attr.default_value()), ";\n");
+ &struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(),
+ "_ = ",
+ PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()),
+ ";\n");
}
if (struct_fields.empty()) {
@@ -676,17 +703,18 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
// Add the static functions to set optional attrs
if (has_optional_attrs) {
strings::StrAppend(&class_decl, "\n");
- for (int i = 0; i < op_def.attr_size(); ++i) {
- const auto& attr(op_def.attr(i));
+ for (int i = 0; i < graph_op_def.attr_size(); ++i) {
+ const auto& attr(graph_op_def.attr(i));
+ const auto& api_def_attr(api_def.attr(i));
if ((inferred_input_attrs.find(attr.name()) !=
inferred_input_attrs.end()) ||
- !attr.has_default_value()) {
+ !api_def_attr.has_default_value()) {
continue;
}
const auto entry = AttrTypeName(attr.type());
const auto attr_type_name = entry.first;
const bool use_const = entry.second;
- const string camel_case_name = ToCamelCase(attr.name());
+ const string camel_case_name = ToCamelCase(api_def_attr.rename_to());
const string suffix =
(camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : "";
const string attr_func_def = strings::StrCat(
@@ -726,11 +754,11 @@ void OpInfo::GetOutput(string* out) const {
strings::StrCat("if (!", scope_str, ".ok()) return;");
// No outputs.
- if (op_def.output_arg_size() == 0) {
+ if (graph_op_def.output_arg_size() == 0) {
strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
return;
}
- if (op_def.output_arg_size() == 1) {
+ if (graph_op_def.output_arg_size() == 1) {
// One output, no need for NameRangeMap
if (is_list_output[0]) {
strings::StrAppend(out,
@@ -752,7 +780,7 @@ void OpInfo::GetOutput(string* out) const {
".UpdateStatus(_status_);\n", " return;\n");
strings::StrAppend(out, " }\n\n");
- for (int i = 0; i < op_def.output_arg_size(); ++i) {
+ for (int i = 0; i < graph_op_def.output_arg_size(); ++i) {
const string arg_range = strings::StrCat(
"_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]");
if (is_list_output[i]) {
@@ -776,11 +804,13 @@ string OpInfo::GetConstructorBody() const {
strings::StrAppend(&body, " ", return_on_error, "\n");
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- const auto& arg(op_def.input_arg(i));
- strings::StrAppend(&body, " auto _", arg.name(), " = ::tensorflow::ops::",
- ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(",
- scope_str, ", ", AvoidCPPKeywords(arg.name()), ");\n");
+ for (int i = 0; i < graph_op_def.input_arg_size(); ++i) {
+ const auto& arg(graph_op_def.input_arg(i));
+ const auto& api_def_arg(api_def.in_arg(i));
+ strings::StrAppend(
+ &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::",
+ ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ",
+ AvoidCPPKeywords(api_def_arg.rename_to()), ");\n");
strings::StrAppend(&body, " ", return_on_error, "\n");
}
@@ -791,19 +821,21 @@ string OpInfo::GetConstructorBody() const {
&body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"",
graph_op_def.name(), "\")\n");
const string spaces = " ";
- for (int i = 0; i < op_def.input_arg_size(); ++i) {
- const auto& arg(op_def.input_arg(i));
- strings::StrAppend(&body, spaces, ".Input(_", arg.name(), ")\n");
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ const auto& arg(api_def.in_arg(i));
+ strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n");
}
- for (int i = 0; i < op_def.attr_size(); ++i) {
+ for (int i = 0; i < api_def.attr_size(); ++i) {
const auto& graph_attr(graph_op_def.attr(i));
- const auto& attr(op_def.attr(i));
- if (inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) {
+ const auto& api_def_attr(api_def.attr(i));
+ if (inferred_input_attrs.find(api_def_attr.name()) !=
+ inferred_input_attrs.end()) {
continue;
}
- const string attr_name = attr.has_default_value()
- ? strings::StrCat("attrs.", attr.name(), "_")
- : AvoidCPPKeywords(attr.name());
+ const string attr_name =
+ api_def_attr.has_default_value()
+ ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_")
+ : AvoidCPPKeywords(api_def_attr.rename_to());
strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ",
attr_name, ")\n");
}
@@ -845,10 +877,10 @@ void OpInfo::WriteClassDef(WritableFile* cc) const {
TF_CHECK_OK(cc->Append(class_def));
}
-void WriteCCOp(const OpDef& graph_op_def, const OpDef& interface_op_def,
+void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def,
const std::vector<string>& aliases, WritableFile* h,
WritableFile* cc) {
- OpInfo op_info(graph_op_def, interface_op_def, aliases);
+ OpInfo op_info(graph_op_def, api_def, aliases);
op_info.WriteClassDecl(h);
op_info.WriteClassDef(cc);
@@ -943,8 +975,9 @@ string MakeInternal(const string& fname) {
} // namespace
-void WriteCCOps(const OpList& ops, const string& dot_h_fname,
- const string& dot_cc_fname, const string& overrides_fnames) {
+void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& dot_h_fname, const string& dot_cc_fname,
+ const string& overrides_fnames) {
Env* env = Env::Default();
// Load the override map.
@@ -984,24 +1017,23 @@ void WriteCCOps(const OpList& ops, const string& dot_h_fname,
// code depends on it.
if (graph_op_def.name() == "Const") continue;
- // Incorporate overrides from override_map.
- OpDef interface_op_def = graph_op_def;
- const OpGenOverride* op_override =
- override_map.ApplyOverride(&interface_op_def);
+ const auto* api_def = api_def_map.GetApiDef(graph_op_def.name());
+
std::vector<string> aliases;
- if (op_override) {
- if (op_override->skip()) continue;
- aliases.assign(op_override->alias().begin(), op_override->alias().end());
- if (op_override->hide()) {
- // Write hidden ops to _internal.h and _internal.cc.
- WriteCCOp(graph_op_def, interface_op_def, aliases, internal_h.get(),
- internal_cc.get());
- continue;
- }
+ if (api_def->visibility() == ApiDef::SKIP) continue;
+ // First endpoint is canonical, the rest are aliases.
+ for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size();
+ ++endpoint_i) {
+ aliases.push_back(api_def->endpoint(endpoint_i).name());
+ }
+ if (api_def->visibility() == ApiDef::HIDDEN) {
+ // Write hidden ops to _internal.h and _internal.cc.
+ WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(),
+ internal_cc.get());
+ continue;
}
-
// This isn't a hidden op, write it to the main files.
- WriteCCOp(graph_op_def, interface_op_def, aliases, h.get(), cc.get());
+ WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get());
}
FinishFiles(false, h.get(), cc.get(), op_header_guard);
diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h
index fa5e004f03..cea2899014 100644
--- a/tensorflow/cc/framework/cc_op_gen.h
+++ b/tensorflow/cc/framework/cc_op_gen.h
@@ -17,13 +17,15 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
/// Result is written to files dot_h and dot_cc.
-void WriteCCOps(const OpList& ops, const string& dot_h_fname,
- const string& dot_cc_fname, const string& overrides_fnames);
+void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map,
+ const string& dot_h_fname, const string& dot_cc_fname,
+ const string& overrides_fnames);
} // namespace tensorflow
diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc
index 3b80cf993e..326d5668b8 100644
--- a/tensorflow/cc/framework/cc_op_gen_main.cc
+++ b/tensorflow/cc/framework/cc_op_gen_main.cc
@@ -16,7 +16,11 @@ limitations under the License.
#include "tensorflow/cc/framework/cc_op_gen.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/types.h"
@@ -24,10 +28,28 @@ namespace tensorflow {
namespace {
void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
- const std::string& overrides_fnames, bool include_internal) {
+ const std::string& overrides_fnames, bool include_internal,
+ const std::vector<string>& api_def_dirs) {
OpList ops;
OpRegistry::Global()->Export(include_internal, &ops);
- WriteCCOps(ops, dot_h, dot_cc, overrides_fnames);
+ ApiDefMap api_def_map(ops);
+ if (!api_def_dirs.empty()) {
+ Env* env = Env::Default();
+ // Only load files that correspond to "ops".
+ for (const auto& op : ops.op()) {
+ for (const auto& api_def_dir : api_def_dirs) {
+ const std::string api_def_file_pattern =
+ io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
+ if (env->FileExists(api_def_file_pattern).ok()) {
+ TF_CHECK_OK(api_def_map.LoadFile(env, api_def_file_pattern));
+ }
+ }
+ }
+ }
+
+ api_def_map.UpdateDocs();
+
+ WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames);
}
} // namespace
@@ -35,18 +57,24 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc,
int main(int argc, char* argv[]) {
tensorflow::port::InitMain(argv[0], &argc, &argv);
- if (argc != 5) {
+ // TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt
+ // as an argument.
+ if (argc != 6) {
for (int i = 1; i < argc; ++i) {
fprintf(stderr, "Arg %d = %s\n", i, argv[i]);
}
fprintf(stderr,
- "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal\n"
+ "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal "
+ "api_def_dirs1,api_def_dir2 ...\n"
" include_internal: 1 means include internal ops\n",
argv[0]);
exit(1);
}
bool include_internal = tensorflow::StringPiece("1") == argv[4];
- tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal);
+ std::vector<tensorflow::string> api_def_dirs = tensorflow::str_util::Split(
+ argv[5], ",", tensorflow::str_util::SkipEmpty());
+ tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal,
+ api_def_dirs);
return 0;
}
diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc
new file mode 100644
index 0000000000..0b7e720a5c
--- /dev/null
+++ b/tensorflow/cc/framework/cc_op_gen_test.cc
@@ -0,0 +1,195 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/framework/cc_op_gen.h"
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// TODO(annarev): Remove this op_gen_overrides.pbtxt reference.
+// It is needed only because WriteCCOps takes it as an argument.
+constexpr char kOverridesFnames[] =
+ "tensorflow/cc/ops/op_gen_overrides.pbtxt";
+constexpr char kBaseOpDef[] = R"(
+op {
+ name: "Foo"
+ input_arg {
+ name: "images"
+ description: "Images to process."
+ }
+ input_arg {
+ name: "dim"
+ description: "Description for dim."
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "output"
+ description: "Description for output."
+ type: DT_FLOAT
+ }
+ attr {
+ name: "T"
+ type: "type"
+ description: "Type for images"
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_INT8
+ }
+ }
+ default_value {
+ i: 1
+ }
+ }
+ summary: "Summary for op Foo."
+ description: "Description for op Foo."
+}
+)";
+
+void ExpectHasSubstr(StringPiece s, StringPiece expected) {
+ EXPECT_TRUE(s.contains(expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+}
+
+void ExpectDoesNotHaveSubstr(StringPiece s, StringPiece expected) {
+ EXPECT_FALSE(s.contains(expected))
+ << "'" << s << "' contains '" << expected << "'";
+}
+
+void ExpectSubstrOrder(const string& s, const string& before,
+ const string& after) {
+ int before_pos = s.find(before);
+ int after_pos = s.find(after);
+ ASSERT_NE(std::string::npos, before_pos);
+ ASSERT_NE(std::string::npos, after_pos);
+ EXPECT_LT(before_pos, after_pos)
+ << before << " is not before " << after << " in " << s;
+}
+
+// Runs WriteCCOps and stores output in (internal_)cc_file_path and
+// (internal_)h_file_path.
+void GenerateCcOpFiles(Env* env, const OpList& ops,
+ const ApiDefMap& api_def_map, string* h_file_text,
+ string* internal_h_file_text) {
+ const string& tmpdir = testing::TmpDir();
+
+ const auto h_file_path = io::JoinPath(tmpdir, "test.h");
+ const auto cc_file_path = io::JoinPath(tmpdir, "test.cc");
+ const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h");
+ const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc");
+
+ WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames);
+
+ TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text));
+ TF_ASSERT_OK(
+ ReadFileToString(env, internal_h_file_path, internal_h_file_text));
+}
+
+TEST(CcOpGenTest, TestVisibilityChangedToHidden) {
+ const string api_def = R"(
+op {
+ graph_op_name: "Foo"
+ visibility: HIDDEN
+}
+)";
+ Env* env = Env::Default();
+ OpList op_defs;
+ protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
+ ApiDefMap api_def_map(op_defs);
+
+ string h_file_text, internal_h_file_text;
+ // Without ApiDef
+ GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
+ &internal_h_file_text);
+ ExpectHasSubstr(h_file_text, "class Foo");
+ ExpectDoesNotHaveSubstr(internal_h_file_text, "class Foo");
+
+ // With ApiDef
+ TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
+ GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
+ &internal_h_file_text);
+ ExpectHasSubstr(internal_h_file_text, "class Foo");
+ ExpectDoesNotHaveSubstr(h_file_text, "class Foo");
+}
+
+TEST(CcOpGenTest, TestArgNameChanges) {
+ const string api_def = R"(
+op {
+ graph_op_name: "Foo"
+ arg_order: "dim"
+ arg_order: "images"
+}
+)";
+ Env* env = Env::Default();
+ OpList op_defs;
+ protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
+
+ ApiDefMap api_def_map(op_defs);
+ string cc_file_text, h_file_text;
+ string internal_cc_file_text, internal_h_file_text;
+ // Without ApiDef
+ GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
+ &internal_h_file_text);
+ ExpectSubstrOrder(h_file_text, "Input images", "Input dim");
+
+ // With ApiDef
+ TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
+ GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
+ &internal_h_file_text);
+ ExpectSubstrOrder(h_file_text, "Input dim", "Input images");
+}
+
+TEST(CcOpGenTest, TestEndpoints) {
+ const string api_def = R"(
+op {
+ graph_op_name: "Foo"
+ endpoint {
+ name: "Foo1"
+ }
+ endpoint {
+ name: "Foo2"
+ }
+}
+)";
+ Env* env = Env::Default();
+ OpList op_defs;
+ protobuf::TextFormat::ParseFromString(kBaseOpDef, &op_defs); // NOLINT
+
+ ApiDefMap api_def_map(op_defs);
+ string cc_file_text, h_file_text;
+ string internal_cc_file_text, internal_h_file_text;
+ // Without ApiDef
+ GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
+ &internal_h_file_text);
+ ExpectHasSubstr(h_file_text, "class Foo {");
+ ExpectDoesNotHaveSubstr(h_file_text, "class Foo1");
+ ExpectDoesNotHaveSubstr(h_file_text, "class Foo2");
+
+ // With ApiDef
+ TF_ASSERT_OK(api_def_map.LoadApiDef(api_def));
+ GenerateCcOpFiles(env, op_defs, api_def_map, &h_file_text,
+ &internal_h_file_text);
+ ExpectHasSubstr(h_file_text, "class Foo1");
+ ExpectHasSubstr(h_file_text, "typedef Foo1 Foo2");
+ ExpectDoesNotHaveSubstr(h_file_text, "class Foo {");
+}
+} // namespace
+} // namespace tensorflow