Diffstat (limited to 'tensorflow/cc/ops/cc_op_gen.cc')
-rw-r--r--  tensorflow/cc/ops/cc_op_gen.cc  350
1 file changed, 350 insertions, 0 deletions
diff --git a/tensorflow/cc/ops/cc_op_gen.cc b/tensorflow/cc/ops/cc_op_gen.cc
new file mode 100644
index 0000000000..fdecf967f8
--- /dev/null
+++ b/tensorflow/cc/ops/cc_op_gen.cc
@@ -0,0 +1,350 @@
+// TODO(josh11b): Rewrite function parameter names to avoid C++ keywords
+// or "opts".
+
+#include "tensorflow/cc/ops/cc_op_gen.h"
+
+#include <unordered_map>
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+#include "tensorflow/core/public/env.h"
+
+namespace tensorflow {
+namespace {
+
+const int kRightMargin = 79;
+
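+// Maps an OpDef attr type string to the C++ parameter type used in the
+// generated wrapper, e.g. "list(int)" -> "gtl::ArraySlice<int>".
+// LOG(FATAL)s on attr types that have no entry in the table below.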
+const char* AttrTypeName(StringPiece attr_type) {
+ static const char* kAttrTypeName[][2] = {
+ {"string", "StringPiece"},
+ {"list(string)", "gtl::ArraySlice<string>"},
+ {"int", "int64"},
+ {"list(int)", "gtl::ArraySlice<int>"},
+ {"float", "float"},
+ {"list(float)", "gtl::ArraySlice<float>"},
+ {"bool", "bool"},
+ {"list(bool)", "gtl::ArraySlice<bool>"},
+ {"type", "DataType"},
+ {"list(type)", "DataTypeSlice"},
+ {"shape", "TensorShape"},
+ {"list(shape)", "gtl::ArraySlice<TensorShape>"},
+ {"tensor", "const Tensor&"},
+ {"list(tensor)", "gtl::ArraySlice<Tensor>"},
+ {"func", "const NameAttrList&"},
+ };
+ for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
+ if (attr_type == kAttrTypeName[i][0]) {
+ return kAttrTypeName[i][1];
+ }
+ }
+ LOG(FATAL) << "Unsupported Attr type: " << attr_type;
+ return "";
+}
+
+// Change:     Into:
+//   ABC         // ABC
+//               //
+//   DEF         // DEF
+string MakeComment(StringPiece text) {
+ string ret;
+ while (!text.empty()) {
+ int last_non_space = -1;
+ int newline;
+ for (newline = 0; newline < static_cast<int>(text.size()); ++newline) {
+ if (text[newline] == '\n') break;
+ if (text[newline] != ' ') last_non_space = newline;
+ }
+ if (last_non_space == -1) {
+ strings::StrAppend(&ret, "//\n");
+ } else {
+ strings::StrAppend(&ret, "// ", text.substr(0, last_non_space + 1), "\n");
+ }
+ text.remove_prefix(newline + 1);
+ }
+ return ret;
+}
+
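+// Emits the wrapper for a single op: a commented declaration is appended to
+// `h` and the matching definition to `cc`. As an illustration (hypothetical
+// op "Foo" with one input "x" and no required attrs), the header would get
+// roughly:
+//   Node* Foo(NodeOut x, const GraphDefBuilder::Options& opts);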
+void WriteCCOp(const OpDef& op_def, WritableFile* h, WritableFile* cc) {
+ // TODO(josh11b): Better wrapping of comments.
+ string comment;
+ if (op_def.summary().empty()) {
+ comment = "TODO: add doc.\n";
+ } else {
+ comment = strings::StrCat(op_def.summary(), "\n");
+ if (!op_def.description().empty()) {
+ strings::StrAppend(&comment, "\n", op_def.description(), "\n");
+ }
+ }
+
+ static const string kSingleInputType = "NodeOut";
+ static const string kListInputType = "gtl::ArraySlice<NodeOut>";
+
+ std::vector<string> arg_types;
+ std::vector<string> arg_names;
+
+ strings::StrAppend(&comment, "\nArguments:\n");
+
+ // Map from attr name to the first input arg it is inferred from.
+ std::unordered_map<string, string> inferred_attrs;
+ for (int i = 0; i < op_def.input_arg_size(); ++i) {
+ const auto& arg(op_def.input_arg(i));
+ arg_names.emplace_back(arg.name());
+ bool is_list = false;
+
+ if (!arg.type_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
+ } else if (!arg.type_list_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
+ arg.name());
+ is_list = true;
+ }
+ if (!arg.number_attr().empty()) {
+ gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
+ is_list = true;
+ }
+ if (is_list) {
+ arg_types.emplace_back(kListInputType);
+ } else {
+ arg_types.emplace_back(kSingleInputType);
+ }
+
+ // TODO(josh11b): Include input type information.
+ StringPiece description = arg.description();
+ if (!description.empty()) {
+ ConsumeEquals(&description);
+ strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
+ arg.description(), "\n");
+ }
+ }
+
+ string options_comment;
+ for (int i = 0; i < op_def.attr_size(); ++i) {
+ const auto& attr(op_def.attr(i));
+ // Do not add inferred attrs or attrs with defaults to the C++
+ // function signature.
+ if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
+ if (!attr.has_default_value()) {
+ arg_names.emplace_back(attr.name());
+ arg_types.emplace_back(AttrTypeName(attr.type()));
+ if (!attr.description().empty()) {
+ strings::StrAppend(&comment, "* ", arg_names.back(), ": ",
+ attr.description(), "\n");
+ }
+ } else {
+ strings::StrAppend(&options_comment, " .WithAttr(\"", attr.name(),
+ "\", ", AttrTypeName(attr.type()), "): Defaults to ",
+ SummarizeAttrValue(attr.default_value()), ".\n");
+ if (!attr.description().empty()) {
+ strings::StrAppend(&options_comment, " ", attr.description(),
+ "\n");
+ }
+ }
+ }
+ }
+ CHECK_EQ(arg_names.size(), arg_types.size());
+ strings::StrAppend(&comment, "* opts:\n", options_comment,
+ R"comment( .WithName(StringPiece): Set the Node's name
+ .WithDevice(StringPiece): Set the Node's requested device
+ .WithControlInput(Node*) / .WithControlInputs({Node*, ...}):
+ Add control dependencies on the specified Node(s).
+
+Returns a pointer to the created Node)comment");
+
+ // TODO(josh11b): Include output type information.
+ if (op_def.output_arg_size() == 0) {
+ strings::StrAppend(&comment, ".\n");
+ } else if (op_def.output_arg_size() == 1) {
+ StringPiece description = op_def.output_arg(0).description();
+ ConsumeEquals(&description);
+ if (description.empty()) {
+ strings::StrAppend(&comment, ".\n");
+ } else {
+ strings::StrAppend(&comment, ", with output:\n", description, "\n");
+ }
+ } else {
+ strings::StrAppend(&comment, ", with outputs:\n");
+ for (int o = 0; o < op_def.output_arg_size(); ++o) {
+ StringPiece description = op_def.output_arg(o).description();
+ ConsumeEquals(&description);
+ if (description.empty()) {
+ strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), "\n");
+ } else {
+ strings::StrAppend(&comment, "* ", op_def.output_arg(o).name(), ": ",
+ description, "\n");
+ }
+ }
+ }
+
+ // Write the header comment.
+ TF_CHECK_OK(h->Append(MakeComment(comment)));
+
+ // Declare the function wrapper.
+ const string prefix = strings::StrCat("Node* ", op_def.name(), "(");
+ string h_rest;
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ strings::StrAppend(&h_rest, arg_types[i], " ", arg_names[i], ", ");
+ }
+ strings::StrAppend(&h_rest, "const GraphDefBuilder::Options& opts");
+ string cc_decl = h_rest;
+ strings::StrAppend(&h_rest, ");");
+ TF_CHECK_OK(h->Append(WordWrap(prefix, h_rest, kRightMargin) + "\n\n"));
+
+ // Define the function wrapper.
+ strings::StrAppend(&cc_decl, ") {");
+ TF_CHECK_OK(cc->Append(WordWrap(prefix, cc_decl, kRightMargin) + "\n"));
+ const string op_name = strings::StrCat(" static const string kOpName = \"",
+ op_def.name(), "\";\n");
+
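+ // Ops that take only zero, one, or two single-tensor inputs (and no
+ // required attrs) delegate to the SourceOp/UnaryOp/BinaryOp helpers;
+ // everything else is built explicitly via NodeBuilder below.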
+ if (arg_types.empty()) {
+ TF_CHECK_OK(cc->Append(op_name));
+ TF_CHECK_OK(cc->Append(" return SourceOp(kOpName, opts);\n}\n\n"));
+ } else if (arg_types == std::vector<string>({kSingleInputType})) {
+ TF_CHECK_OK(cc->Append(op_name));
+ TF_CHECK_OK(cc->Append(strings::StrCat(" return UnaryOp(kOpName, ",
+ arg_names[0], ", opts);\n}\n\n")));
+ } else if (arg_types ==
+ std::vector<string>({kSingleInputType, kSingleInputType})) {
+ TF_CHECK_OK(cc->Append(op_name));
+ // TODO(josh11b): Word wrap this if it ever becomes necessary.
+ TF_CHECK_OK(
+ cc->Append(strings::StrCat(" return BinaryOp(kOpName, ", arg_names[0],
+ ", ", arg_names[1], ", opts);\n}\n\n")));
+ } else {
+ TF_CHECK_OK(cc->Append(" if (opts.HaveError()) return nullptr;\n"));
+ TF_CHECK_OK(cc->Append(op_name));
+ TF_CHECK_OK(cc->Append(
+ " NodeBuilder node_builder(opts.GetNameForOp(kOpName), kOpName,\n"
+ " opts.op_registry());\n"));
+ for (size_t i = 0; i < arg_names.size(); ++i) {
+ if (i < static_cast<size_t>(op_def.input_arg_size())) {
+ TF_CHECK_OK(cc->Append(
+ strings::StrCat(" node_builder.Input(", arg_names[i], ");\n")));
+ } else {
+ TF_CHECK_OK(
+ cc->Append(strings::StrCat(" node_builder.Attr(\"", arg_names[i],
+ "\", ", arg_names[i], ");\n")));
+ }
+ }
+ TF_CHECK_OK(
+ cc->Append(" return opts.FinalizeBuilder(&node_builder);\n"
+ "}\n\n"));
+ }
+}
+
+// Converts:
+// bazel-out/.../genfiles/XX
+// to: XX.
+string GetPath(const std::string& dot_h_fname) {
+ auto pos = dot_h_fname.find("/genfiles/");
+ if (pos == string::npos) return dot_h_fname;
+ // The - 1 accounts for the terminating '\0' in sizeof("/genfiles/").
+ return dot_h_fname.substr(pos + sizeof("/genfiles/") - 1);
+}
+
+// Converts:
+// cc/ops/gen_foo_ops.h
+// to:
+// CC_OPS_GEN_FOO_OPS_H_
+string ToGuard(const std::string& path) {
+ string guard;
+ guard.reserve(path.size() + 1); // + 1 -> trailing _
+ for (const char c : path) {
+ if (c >= 'A' && c <= 'Z') {
+ guard += c;
+ } else if (c >= 'a' && c <= 'z') {
+ guard += c + 'A' - 'a';
+ } else {
+ guard += '_';
+ }
+ }
+ guard += '_';
+ return guard;
+}
+
+} // namespace
+
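+// Writes one wrapper per op in `ops`: declarations (plus the boilerplate
+// header/footer) go to `dot_h_fname`, definitions to `dot_cc_fname`.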
+void WriteCCOps(const OpList& ops, const std::string& dot_h_fname,
+ const std::string& dot_cc_fname) {
+ Env* env = Env::Default();
+ WritableFile* h = nullptr;
+ WritableFile* cc = nullptr;
+ TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h));
+ TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc));
+
+ // .h Header
+ const string include = GetPath(dot_h_fname);
+ const string guard = ToGuard(include);
+ // TODO(josh11b): Mention the library for which wrappers are being generated.
+ Status s;
+ s = h->Append(
+ strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
+ "#ifndef ",
+ guard,
+ "\n"
+ "#define ",
+ guard, R"header(
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/public/tensor.h"
+#include "tensorflow/core/public/tensor_shape.h"
+
+namespace tensorflow {
+namespace ops {
+
+// These add a node to the graph from opts.
+//
+// Note for "NodeOut" inputs, you will typically either pass
+// * a {Node*, int index} (to pass the index-th output of that node), or
+// * a Node* (to pass the first output of that node).
+
+
+)header"));
+ TF_CHECK_OK(s);
+ // .cc Header
+ s = cc->Append(
+ strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n"
+ "#include \"",
+ include, R"header("
+
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+namespace ops {
+
+)header"));
+ TF_CHECK_OK(s);
+
+ for (const auto& op_def : ops.op()) {
+ WriteCCOp(op_def, h, cc);
+ }
+
+ // .h Footer
+
+ s = h->Append(strings::StrCat(R"footer(} // namespace ops
+} // namespace tensorflow
+
+#endif // )footer",
+ guard, "\n"));
+ TF_CHECK_OK(s);
+
+ // .cc Footer
+
+ s = cc->Append(R"footer(} // namespace ops
+} // namespace tensorflow
+)footer");
+ TF_CHECK_OK(s);
+
+ TF_CHECK_OK(cc->Close());
+ TF_CHECK_OK(h->Close());
+}
+
+} // namespace tensorflow