/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include #include #include #include "tensorflow/cc/framework/cc_op_gen.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb_text.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/version.h" namespace tensorflow { namespace { const int kRightMargin = 79; // Converts: // bazel-out/.../genfiles/(external/YYY/)?XX // to: XX. string GetPath(const string& dot_h_fname) { auto pos = dot_h_fname.find("/genfiles/"); string result = dot_h_fname; if (pos != string::npos) { // - 1 account for the terminating null character (\0) in "/genfiles/". result = dot_h_fname.substr(pos + sizeof("/genfiles/") - 1); } if (result.size() > sizeof("external/") && result.compare(0, sizeof("external/") - 1, "external/") == 0) { result = result.substr(sizeof("external/") - 1); pos = result.find("/"); if (pos != string::npos) { result = result.substr(pos + 1); } } return result; } // Converts: some/path/to/file.xx // to: file // (note that suffix is removed) string GetFilename(const string& path) { size_t slash_pos = path.rfind('/'); if (slash_pos == path.npos) slash_pos = -1; size_t dot_pos = path.rfind('.'); return path.substr(slash_pos + 1, dot_pos - (slash_pos + 1)); } // Converts: // cc/ops/gen_foo_ops.h // to: // CC_OPS_GEN_FOO_OPS_H_ string ToGuard(const string& path) { string guard; guard.reserve(path.size() + 1); // + 1 -> trailing _ for (const char c : path) { if (c >= 'A' && c <= 'Z') { guard += c; } else if (c >= 'a' && c <= 'z') { guard += c + 'A' - 'a'; } else { guard += '_'; } } guard += '_'; return guard; } // Converts: some_name_xyz // to: Some Name Xyz string ToTitle(const string& name) { string title = name; for (int i = 0; i < title.size(); ++i) { if (title[i] == '_') title[i] = ' '; } str_util::TitlecaseString(&title, " "); return title; } // Change: Into: // ABC /// ABC // /// // DEF /// DEF string MakeComment(StringPiece text, StringPiece indent) { string ret; while (!text.empty()) { int last_non_space = -1; int newline; for (newline = 0; newline < static_cast(text.size()); ++newline) { if (text[newline] == '\n') break; if (text[newline] != ' ') last_non_space = newline; } if (last_non_space == -1) { strings::StrAppend(&ret, indent, "///\n"); } else { strings::StrAppend(&ret, indent, "/// ", text.substr(0, last_non_space + 1), "\n"); } text.remove_prefix(newline + 1); } return ret; } string PrintString(const string& str) { return strings::StrCat("\"", str_util::CEscape(str), "\""); } string PrintTensorShape(const TensorShapeProto& shape_proto) { PartialTensorShape shape(shape_proto); if (shape.IsIdenticalTo(PartialTensorShape())) { return "::tensorflow::PartialTensorShape() /* unknown */"; } string ret = "{"; for (int d = 0; d < shape.dims(); ++d) { if (d > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, shape.dim_size(d)); } strings::StrAppend(&ret, "}"); return ret; } template string PrintArray(int64 num_elts, const T* array) { string ret; for (int64 i = 0; i < num_elts; ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, array[i]); } return ret; } string PrintTensor(const TensorProto& tensor_proto) { Tensor t(tensor_proto.dtype()); CHECK(t.FromProto(tensor_proto)); const int64 num_elts = t.NumElements(); switch (t.dtype()) { case DT_FLOAT: return PrintArray(num_elts, t.flat().data()); case DT_DOUBLE: return PrintArray(num_elts, t.flat().data()); case DT_INT32: return PrintArray(num_elts, t.flat().data()); case DT_UINT8: case DT_QUINT8: return PrintArray(num_elts, t.flat().data()); case DT_UINT16: case DT_QUINT16: return PrintArray(num_elts, t.flat().data()); case DT_INT16: case DT_QINT16: return PrintArray(num_elts, t.flat().data()); case DT_INT8: case DT_QINT8: return PrintArray(num_elts, t.flat().data()); case DT_INT64: return PrintArray(num_elts, t.flat().data()); case DT_BOOL: return PrintArray(num_elts, t.flat().data()); case DT_STRING: { string ret; for (int64 i = 0; i < num_elts; ++i) { if (i > 0) strings::StrAppend(&ret, " "); strings::StrAppend(&ret, str_util::CEscape(t.flat()(i))); } return ret; } default: { LOG(FATAL) << "Not handling type " << EnumName_DataType(t.dtype()); return string(); } } } string PrintTensorProto(const TensorProto& proto) { return strings::StrCat("Input::Initializer(", "{", PrintTensor(proto), "}, ", PrintTensorShape(proto.tensor_shape()), ").AsTensorProto()"); } string PrintAttrValue(const string& op, const AttrValue& attr_value) { switch (attr_value.value_case()) { case AttrValue::kS: return PrintString(attr_value.s()); case AttrValue::kI: return strings::StrCat(attr_value.i()); case AttrValue::kF: { const float f = attr_value.f(); return strings::StrCat(attr_value.f(), floorf(f) == f ? ".0" : "", "f"); } case AttrValue::kB: return attr_value.b() ? "true" : "false"; case AttrValue::kType: return EnumName_DataType(attr_value.type()); case AttrValue::kShape: return PrintTensorShape(attr_value.shape()); case AttrValue::kTensor: return PrintTensorProto(attr_value.tensor()); case AttrValue::kList: { string ret = "{"; if (attr_value.list().s_size() > 0) { for (int i = 0; i < attr_value.list().s_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, PrintString(attr_value.list().s(i))); } } else if (attr_value.list().i_size() > 0) { for (int i = 0; i < attr_value.list().i_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, attr_value.list().i(i)); } } else if (attr_value.list().f_size() > 0) { for (int i = 0; i < attr_value.list().f_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); const float f = attr_value.list().f(i); strings::StrAppend(&ret, f, floorf(f) == f ? ".0" : "", "f"); } } else if (attr_value.list().b_size() > 0) { for (int i = 0; i < attr_value.list().b_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, attr_value.list().b(i) ? "true" : "false"); } } else if (attr_value.list().type_size() > 0) { for (int i = 0; i < attr_value.list().type_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, EnumName_DataType(attr_value.list().type(i))); } } else if (attr_value.list().shape_size() > 0) { for (int i = 0; i < attr_value.list().shape_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, PrintTensorShape(attr_value.list().shape(i))); } } else if (attr_value.list().tensor_size() > 0) { for (int i = 0; i < attr_value.list().tensor_size(); ++i) { if (i > 0) strings::StrAppend(&ret, ", "); strings::StrAppend(&ret, PrintTensorProto(attr_value.list().tensor(i))); } } strings::StrAppend(&ret, "}"); return ret; } default: LOG(FATAL) << "Unsupported Attr type: " << op << " " << attr_value.value_case(); } return ""; // Prevent missing return warning } bool IsEmptyList(const AttrValue::ListValue& list) { return list.s_size() == 0 && list.i_size() == 0 && list.f_size() == 0 && list.b_size() == 0 && list.type_size() == 0 && list.shape_size() == 0 && list.tensor_size() == 0; } string ToCamelCase(const string& str) { string result; const char joiner = '_'; size_t i = 0; bool cap = true; while (i < str.size()) { const char c = str[i++]; if (c == joiner) { cap = true; } else if (cap) { result += toupper(c); cap = false; } else { result += c; } } return result; } // Returns a pair. The string is the C++ type name to be used for // attr_type when defining an object of that type. The bool is a flag to // indicate whether to treat the type as const when accepting the C++ type as an // argument to a function. std::pair AttrTypeName(StringPiece attr_type) { static const auto* attr_type_map = new std::unordered_map, StringPieceHasher>{ {"string", {"StringPiece", false}}, {"list(string)", {"gtl::ArraySlice", true}}, {"int", {"int64", false}}, {"list(int)", {"gtl::ArraySlice", true}}, {"float", {"float", false}}, {"list(float)", {"gtl::ArraySlice", true}}, {"bool", {"bool", false}}, {"list(bool)", {"gtl::ArraySlice", true}}, {"type", {"DataType", false}}, {"list(type)", {"DataTypeSlice", true}}, {"shape", {"PartialTensorShape", false}}, {"list(shape)", {"gtl::ArraySlice", true}}, {"tensor", {"TensorProto", true}}, {"list(tensor)", {"gtl::ArraySlice", true}}, {"func", {"NameAttrList", true}}, }; auto entry = attr_type_map->find(attr_type); if (entry == attr_type_map->end()) { LOG(FATAL) << "Unsupported Attr type: " << attr_type; return {"", false}; } return entry->second; } const char* ListElementTypeName(StringPiece attr_type) { static const auto* attr_list_type_map = new std::unordered_map{ {"list(string)", "string"}, {"list(int)", "int"}, {"list(float)", "float"}, {"list(bool)", "bool"}, {"list(type)", "DataType"}, {"list(shape)", "PartialTensorShape"}, {"list(tensor)", "TensorProto"}, }; auto entry = attr_list_type_map->find(attr_type); if (entry == attr_list_type_map->end()) { LOG(FATAL) << "Unsupported or non-list Attr type: " << attr_type; return ""; } return entry->second; } bool IsCPPKeyword(StringPiece name) { static const std::unordered_set // Keywords obtained from http://en.cppreference.com/w/cpp/keyword kCPPReserved{ "alignas", "alignof", "and", "and_eq", "asm", "atomic_cancel", "atomic_commit", "atomic_noexcept", "auto", "bitand", "bitor", "bool", "break", "case", "catch", "char", "char16_t", "char32_t", "class", "compl", "concept", "const", "const_cast", "constexpr", "continue", "decltype", "default", "delete", "do", "double", "dynamic_cast", "else", "enum", "explicit", "export", "extern", "false", "final", "float", "for", "friend", "goto", "if", "import", "inline", "int", "long", "module", "mutable", "namespace", "new", "noexcept", "not", "not_eq", "nullptr", "operator", "or", "or_eq", "override", "private", "protected", "public", "register", "reinterpret_cast", "requires", "return", "short", "signed", "sizeof", "static", "static_assert", "static_cast", "struct", "switch", "synchronized", "template", "this", "thread_local", "throw", "true", "try", "typedef", "typeid", "typename", "union", "unsigned", "using", "virtual", "void", "volatile", "wchar_t", "while", "xor", "xor_eq", // The following are not C++ keywords, but names of local variables // and parameters used in the op constructor. Treating them as // keywords, so that other parameter names don't conflict with these. "builder", "node", "ret", "scope", "unique_name", }; return kCPPReserved.count(name) > 0; } string AvoidCPPKeywords(StringPiece name) { if (IsCPPKeyword(name)) { return strings::StrCat(name, "_"); } return string(name); } void InferArgAttributes(const OpDef::ArgDef& arg, std::unordered_map* inferred_attrs) { if (!arg.type_attr().empty()) { gtl::InsertIfNotPresent(inferred_attrs, arg.type_attr(), arg.name()); } else if (!arg.type_list_attr().empty()) { gtl::InsertIfNotPresent(inferred_attrs, arg.type_list_attr(), arg.name()); } if (!arg.number_attr().empty()) { gtl::InsertIfNotPresent(inferred_attrs, arg.number_attr(), arg.name()); } } void InferOpAttributes( const OpDef& op_def, std::unordered_map* inferred_input_attrs) { for (int i = 0; i < op_def.input_arg_size(); ++i) { const auto& arg(op_def.input_arg(i)); InferArgAttributes(arg, inferred_input_attrs); } } bool ArgIsList(const OpDef::ArgDef& arg) { return !arg.type_list_attr().empty() || !arg.number_attr().empty(); } bool HasOptionalAttrs( const ApiDef& api_def, const std::unordered_map& inferred_input_attrs) { for (int i = 0; i < api_def.attr_size(); ++i) { const auto& attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) == inferred_input_attrs.end()) && attr.has_default_value()) { return true; } } return false; } struct OpInfo { // graph_op_def: The OpDef used by the runtime, has the names that // must be used when calling NodeBuilder. // interface_op_def: The OpDef used in the interface in the generated // code, with possibly overridden names and defaults. explicit OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases); string GetOpAttrStruct() const; string GetConstructorDecl(StringPiece op_name_prefix, bool include_attr) const; void WriteClassDecl(WritableFile* h) const; void GetOutput(string* out) const; string GetConstructorBody() const; void WriteClassDef(WritableFile* cc) const; string op_name; std::vector arg_types; std::vector arg_names; std::vector output_types; std::vector output_names; std::vector is_list_output; bool has_optional_attrs; string comment; const OpDef& graph_op_def; const ApiDef& api_def; const std::vector& aliases; // Map from type attribute to corresponding original argument name. std::unordered_map inferred_input_attrs; }; OpInfo::OpInfo(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases) : graph_op_def(graph_op_def), api_def(api_def), aliases(aliases) { op_name = api_def.endpoint(0).name(); InferOpAttributes(graph_op_def, &inferred_input_attrs); has_optional_attrs = HasOptionalAttrs(api_def, inferred_input_attrs); arg_types.push_back("const ::tensorflow::Scope&"); arg_names.push_back("scope"); if (graph_op_def.has_deprecation()) { if (!api_def.summary().empty()) { comment = strings::StrCat(api_def.summary(), "\n"); } strings::StrAppend(&comment, "DEPRECATED at GraphDef version ", graph_op_def.deprecation().version(), ":\n", graph_op_def.deprecation().explanation(), ".\n"); } else if (api_def.summary().empty()) { comment = "TODO: add doc.\n"; } else { comment = strings::StrCat(api_def.summary(), "\n"); } if (!api_def.description().empty()) { strings::StrAppend(&comment, "\n", api_def.description(), "\n"); } strings::StrAppend(&comment, "\nArguments:\n* scope: A Scope object\n"); // Process inputs for (int i = 0; i < api_def.arg_order_size(); ++i) { const auto& arg = *FindInputArg(api_def.arg_order(i), graph_op_def); const auto& api_def_arg = *FindInputArg(api_def.arg_order(i), api_def); arg_types.push_back(strings::StrCat( "::tensorflow::", ArgIsList(arg) ? "InputList" : "Input")); arg_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); // TODO(keveman): Include input type information. StringPiece description = api_def_arg.description(); if (!description.empty()) { ConsumeEquals(&description); strings::StrAppend(&comment, "* ", AvoidCPPKeywords(api_def_arg.rename_to()), ": ", api_def_arg.description(), "\n"); } } // Process attrs string required_attrs_comment; string optional_attrs_comment; for (int i = 0; i < graph_op_def.attr_size(); ++i) { // ApiDef attributes must be in the same order as in OpDef since // we initialize ApiDef based on OpDef. const auto& attr(graph_op_def.attr(i)); const auto& api_def_attr(api_def.attr(i)); CHECK_EQ(attr.name(), api_def_attr.name()); // Skip inferred arguments if (inferred_input_attrs.count(attr.name()) > 0) continue; const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; string attr_name = AvoidCPPKeywords(api_def_attr.rename_to()); string attr_comment; if (!api_def_attr.description().empty()) { // TODO(keveman): Word wrap and indent this, to handle multi-line // descriptions. strings::StrAppend(&attr_comment, "* ", attr_name, ": ", api_def_attr.description(), "\n"); } if (api_def_attr.has_default_value()) { strings::StrAppend(&optional_attrs_comment, attr_comment); } else { strings::StrAppend(&required_attrs_comment, attr_comment); arg_types.push_back(strings::StrCat( use_const ? "const " : "", attr_type_name, use_const ? "&" : "")); arg_names.push_back(attr_name); } } strings::StrAppend(&comment, required_attrs_comment); if (!optional_attrs_comment.empty()) { strings::StrAppend(&comment, "\nOptional attributes (see `Attrs`):\n"); strings::StrAppend(&comment, optional_attrs_comment); } // Process outputs for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { // ApiDef arguments must be in the same order as in OpDef since // we initialize ApiDef based on OpDef. const auto& arg = graph_op_def.output_arg(i); const auto& api_def_arg(api_def.out_arg(i)); CHECK_EQ(arg.name(), api_def_arg.name()); bool is_list = ArgIsList(arg); output_types.push_back( strings::StrCat("::tensorflow::", is_list ? "OutputList" : "Output")); output_names.push_back(AvoidCPPKeywords(api_def_arg.rename_to())); is_list_output.push_back(is_list); } strings::StrAppend(&comment, "\nReturns:\n"); if (graph_op_def.output_arg_size() == 0) { // No outputs. strings::StrAppend(&comment, "* the created `Operation`\n"); } else if (graph_op_def.output_arg_size() == 1) { // One output if (is_list_output[0]) { strings::StrAppend(&comment, "* `OutputList`: "); } else { strings::StrAppend(&comment, "* `Output`: "); } if (api_def.out_arg(0).description().empty()) { strings::StrAppend(&comment, "The ", api_def.out_arg(0).name(), " tensor.\n"); } else { // TODO(josh11b): Word wrap this. strings::StrAppend(&comment, api_def.out_arg(0).description(), "\n"); } } else { // Multiple outputs. for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { if (is_list_output[i]) { strings::StrAppend(&comment, "* `OutputList`"); } else { strings::StrAppend(&comment, "* `Output`"); } strings::StrAppend(&comment, " ", output_names[i]); if (api_def.out_arg(i).description().empty()) { strings::StrAppend(&comment, "\n"); } else { // TODO(josh11b): Word wrap this. strings::StrAppend(&comment, ": ", api_def.out_arg(i).description(), "\n"); } } } if (!aliases.empty()) { strings::StrAppend(&comment, "\nAliases:\n"); for (const auto& alias : aliases) { strings::StrAppend(&comment, "* ", alias, "\n"); } } comment = MakeComment(comment, ""); } string OpInfo::GetOpAttrStruct() const { string struct_fields; string setters; string defaults_static_storage; for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); const auto& api_def_attr(api_def.attr(i)); // If attr will be inferred or it doesn't have a default value, don't // add it to the struct. if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = strings::StrCat(camel_case_name, suffix, "(", use_const ? "const " : "", attr_type_name, use_const ? "&" : ""); string attr_comment; if (!api_def_attr.description().empty()) { strings::StrAppend(&attr_comment, api_def_attr.description(), "\n\n"); } strings::StrAppend(&attr_comment, "Defaults to ", SummarizeAttrValue(api_def_attr.default_value()), "\n"); attr_comment = MakeComment(attr_comment, " "); strings::StrAppend(&setters, attr_comment); strings::StrAppend(&setters, " TF_MUST_USE_RESULT Attrs ", attr_func_def, " x) {\n"); strings::StrAppend(&setters, " Attrs ret = *this;\n"); strings::StrAppend(&setters, " ret.", api_def_attr.rename_to(), "_ = x;\n"); strings::StrAppend(&setters, " return ret;\n }\n\n"); string field_initiliazer; auto& default_value = api_def_attr.default_value(); if (default_value.value_case() == AttrValue::kList && !IsEmptyList(default_value.list())) { // Non-empty lists need static storage for their defaults. Define a // function with static local variable that stores the array. strings::StrAppend(&defaults_static_storage, " static ", attr_type_name, " Default_", api_def_attr.rename_to(), "() {\n"); strings::StrAppend( &defaults_static_storage, " static const ", ListElementTypeName(attr.type()), " kStorage[] = ", PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()), ";\n"); strings::StrAppend(&defaults_static_storage, " return ", attr_type_name, "(kStorage);\n }\n"); // Set the field_initializer to call the defined function. strings::StrAppend(&field_initiliazer, "Default_", api_def_attr.rename_to(), "()"); } else { field_initiliazer = PrintAttrValue(graph_op_def.name(), api_def_attr.default_value()); } strings::StrAppend(&struct_fields, " ", attr_type_name, " ", api_def_attr.rename_to(), "_ = ", field_initiliazer, ";\n"); } if (struct_fields.empty()) { return ""; } string attrs_comment = strings::StrCat("Optional attribute setters for ", op_name, "\n"); string struct_decl = MakeComment(attrs_comment, " "); strings::StrAppend(&struct_decl, " struct Attrs {\n"); strings::StrAppend(&struct_decl, setters, struct_fields); if (!defaults_static_storage.empty()) { strings::StrAppend(&struct_decl, " private:\n", defaults_static_storage); } strings::StrAppend(&struct_decl, " };\n"); return struct_decl; } string OpInfo::GetConstructorDecl(StringPiece op_name_prefix, bool include_attr) const { const string prefix = strings::StrCat(op_name_prefix, op_name, "("); string c_decl; for (int i = 0; i < arg_types.size(); ++i) { if (i > 0) strings::StrAppend(&c_decl, ", "); strings::StrAppend(&c_decl, arg_types[i], " ", arg_names[i]); } if (include_attr && has_optional_attrs) { strings::StrAppend(&c_decl, ", const ", op_name, "::Attrs& attrs"); } strings::StrAppend(&c_decl, ")"); return WordWrap(prefix, c_decl, kRightMargin); } void OpInfo::WriteClassDecl(WritableFile* h) const { string class_decl = comment; strings::StrAppend(&class_decl, "class ", op_name, " {\n"); strings::StrAppend(&class_decl, " public:\n"); if (has_optional_attrs) { strings::StrAppend(&class_decl, GetOpAttrStruct()); } strings::StrAppend(&class_decl, " ", GetConstructorDecl("", /* include_attr */ false), ";\n"); if (has_optional_attrs) { strings::StrAppend(&class_decl, " ", GetConstructorDecl("", /* include_attr */ true), ";\n"); } if (output_types.empty()) { // Allow casting this class to Operation. strings::StrAppend(&class_decl, " operator ::tensorflow::Operation() const { " "return operation; }\n"); } else if (output_types.size() == 1) { if (is_list_output[0]) { // Write the subscript operator, allowing out[i] for the list-typed // output. strings::StrAppend(&class_decl, " ::tensorflow::Output operator[](size_t index) " "const { return ", output_names[0], "[index]; }\n\n"); } else { // Write type cast functions, allowing casting this class to Input and // Output. strings::StrAppend(&class_decl, " operator ::tensorflow::Output() const { return ", output_names[0], "; }\n"); strings::StrAppend(&class_decl, " operator ::tensorflow::Input() const { return ", output_names[0], "; }\n"); // Write node() to get the Node* directly. strings::StrAppend(&class_decl, " ::tensorflow::Node* node() const { return ", output_names[0], ".node(); }\n"); } } // Add the static functions to set optional attrs if (has_optional_attrs) { strings::StrAppend(&class_decl, "\n"); for (int i = 0; i < graph_op_def.attr_size(); ++i) { const auto& attr(graph_op_def.attr(i)); const auto& api_def_attr(api_def.attr(i)); if ((inferred_input_attrs.find(attr.name()) != inferred_input_attrs.end()) || !api_def_attr.has_default_value()) { continue; } const auto entry = AttrTypeName(attr.type()); const auto attr_type_name = entry.first; const bool use_const = entry.second; const string camel_case_name = ToCamelCase(api_def_attr.rename_to()); const string suffix = (camel_case_name == op_name || camel_case_name == "Attrs") ? "_" : ""; const string attr_func_def = strings::StrCat( camel_case_name, suffix, "(", use_const ? "const " : "", attr_type_name, use_const ? "&" : ""); strings::StrAppend(&class_decl, " static Attrs ", attr_func_def, " x) {\n"); strings::StrAppend(&class_decl, " return Attrs().", camel_case_name, suffix, "(x);\n"); strings::StrAppend(&class_decl, " }\n"); } } strings::StrAppend(&class_decl, "\n Operation operation;\n"); for (int i = 0; i < output_types.size(); ++i) { strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i], ";\n"); } strings::StrAppend(&class_decl, "};\n"); if (!aliases.empty()) { for (const auto& alias : aliases) { strings::StrAppend(&class_decl, "typedef ", op_name, " ", alias, ";\n"); } } strings::StrAppend(&class_decl, "\n"); TF_CHECK_OK(h->Append(class_decl)); } void OpInfo::GetOutput(string* out) const { const string scope_str = arg_names[0]; string return_on_error = strings::StrCat("if (!", scope_str, ".ok()) return;"); strings::StrAppend(out, " this->operation = Operation(ret);\n"); // No outputs. if (graph_op_def.output_arg_size() == 0) { strings::StrAppend(out, " return;\n"); return; } if (graph_op_def.output_arg_size() == 1) { // One output, no need for NameRangeMap if (is_list_output[0]) { strings::StrAppend(out, " for (int32 i = 0; i < ret->num_outputs(); ++i)\n"); strings::StrAppend(out, " this->", output_names[0], ".push_back(Output(ret, i));\n"); } else { strings::StrAppend(out, " this->", output_names[0], " = Output(ret, 0);\n"); } return; } strings::StrAppend(out, " ::tensorflow::NameRangeMap _outputs_range;\n"); strings::StrAppend(out, " ::tensorflow::Status _status_ = " "::tensorflow::NameRangesForNode(*ret, ret->op_def(), " "nullptr, &_outputs_range);\n"); strings::StrAppend(out, " if (!_status_.ok()) {\n", " ", scope_str, ".UpdateStatus(_status_);\n", " return;\n"); strings::StrAppend(out, " }\n\n"); for (int i = 0; i < graph_op_def.output_arg_size(); ++i) { const string arg_range = strings::StrCat( "_outputs_range[\"", graph_op_def.output_arg(i).name(), "\"]"); if (is_list_output[i]) { strings::StrAppend(out, " for (int32 i = ", arg_range, ".first; i < ", arg_range, ".second; ++i)\n"); strings::StrAppend(out, " this->", output_names[i], ".push_back(Output(ret, i));\n"); } else { strings::StrAppend(out, " this->", output_names[i], " = Output(ret, ", arg_range, ".first);\n"); } } } string OpInfo::GetConstructorBody() const { const string scope_str = arg_names[0]; string body; string return_on_error = strings::StrCat("if (!", scope_str, ".ok()) return;"); strings::StrAppend(&body, " ", return_on_error, "\n"); for (int i = 0; i < graph_op_def.input_arg_size(); ++i) { const auto& arg(graph_op_def.input_arg(i)); const auto& api_def_arg(api_def.in_arg(i)); strings::StrAppend( &body, " auto _", api_def_arg.rename_to(), " = ::tensorflow::ops::", ArgIsList(arg) ? "AsNodeOutList" : "AsNodeOut", "(", scope_str, ", ", AvoidCPPKeywords(api_def_arg.rename_to()), ");\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); } strings::StrAppend(&body, " ::tensorflow::Node* ret;\n"); strings::StrAppend(&body, " const auto unique_name = ", scope_str, ".GetUniqueNameForOp(\"", op_name, "\");\n"); strings::StrAppend( &body, " auto builder = ::tensorflow::NodeBuilder(unique_name, \"", graph_op_def.name(), "\")\n"); const string spaces = " "; for (int i = 0; i < api_def.in_arg_size(); ++i) { const auto& arg(api_def.in_arg(i)); strings::StrAppend(&body, spaces, ".Input(_", arg.rename_to(), ")\n"); } for (int i = 0; i < api_def.attr_size(); ++i) { const auto& graph_attr(graph_op_def.attr(i)); const auto& api_def_attr(api_def.attr(i)); if (inferred_input_attrs.find(api_def_attr.name()) != inferred_input_attrs.end()) { continue; } const string attr_name = api_def_attr.has_default_value() ? strings::StrCat("attrs.", api_def_attr.rename_to(), "_") : AvoidCPPKeywords(api_def_attr.rename_to()); strings::StrAppend(&body, spaces, ".Attr(\"", graph_attr.name(), "\", ", attr_name, ")\n"); } strings::StrAppend(&body, " ;\n"); strings::StrAppend(&body, " ", scope_str, ".UpdateBuilder(&builder);\n"); strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(builder.Finalize(", scope_str, ".graph(), &ret));\n"); strings::StrAppend(&body, " ", return_on_error, "\n"); strings::StrAppend(&body, " ", scope_str, ".UpdateStatus(", scope_str, ".DoShapeInference(ret));\n"); GetOutput(&body); return body; } void OpInfo::WriteClassDef(WritableFile* cc) const { string class_def; strings::StrAppend(&class_def, GetConstructorDecl(strings::StrCat(op_name, "::"), /* include_attr */ true), " {\n"); strings::StrAppend(&class_def, GetConstructorBody()); strings::StrAppend(&class_def, "}\n\n"); if (has_optional_attrs) { strings::StrAppend(&class_def, GetConstructorDecl(strings::StrCat(op_name, "::"), /* include_attr */ false)); strings::StrAppend(&class_def, "\n : ", op_name, "("); int i = 0; for (; i < arg_names.size(); ++i) { if (i > 0) strings::StrAppend(&class_def, ", "); strings::StrAppend(&class_def, arg_names[i]); } if (i > 0) strings::StrAppend(&class_def, ", "); strings::StrAppend(&class_def, op_name, "::Attrs()"); strings::StrAppend(&class_def, ") {}\n\n"); } TF_CHECK_OK(cc->Append(class_def)); } void WriteCCOp(const OpDef& graph_op_def, const ApiDef& api_def, const std::vector& aliases, WritableFile* h, WritableFile* cc) { OpInfo op_info(graph_op_def, api_def, aliases); op_info.WriteClassDecl(h); op_info.WriteClassDef(cc); } void StartFiles(bool internal, const string& dot_h_fname, WritableFile* h, WritableFile* cc, string* op_header_guard) { const string header = R"header(// This file is MACHINE GENERATED! Do not edit. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" )header"; // TODO(keveman): Make namespaces configurable. const string namespace_begin = internal ? R"namespace( namespace tensorflow { namespace ops { namespace internal { // NOTE: This namespace has internal TensorFlow details that // are not part of TensorFlow's public API. )namespace" : R"namespace( namespace tensorflow { namespace ops { )namespace"; const string op_header = GetPath(dot_h_fname); *op_header_guard = ToGuard(op_header); const string cc_header = strings::StrCat( R"include(// This file is MACHINE GENERATED! Do not edit. #include "tensorflow/cc/ops/const_op.h" )include", "#include \"", op_header, "\"\n", namespace_begin); const string filename = GetFilename(dot_h_fname); const string doxygen = strings::StrCat("/// @defgroup ", filename, " ", ToTitle(filename), "\n", "/// @{\n\n"); TF_CHECK_OK(h->Append( strings::StrCat("// This file is MACHINE GENERATED! Do not edit.\n\n" "#ifndef ", *op_header_guard, "\n" "#define ", *op_header_guard, "\n\n"))); TF_CHECK_OK(h->Append(header)); TF_CHECK_OK(h->Append(namespace_begin)); TF_CHECK_OK(h->Append(doxygen)); TF_CHECK_OK(cc->Append(cc_header)); } void FinishFiles(bool internal, WritableFile* h, WritableFile* cc, const string& op_header_guard) { const string footer = internal ? R"footer(} // namespace internal } // namespace ops } // namespace tensorflow )footer" : R"footer(/// @} } // namespace ops } // namespace tensorflow )footer"; TF_CHECK_OK(h->Append(footer)); TF_CHECK_OK( h->Append(strings::StrCat("\n#endif ", "// ", op_header_guard, "\n"))); TF_CHECK_OK(cc->Append(footer)); TF_CHECK_OK(cc->Close()); TF_CHECK_OK(h->Close()); } string MakeInternal(const string& fname) { auto dot_pos = fname.rfind('.'); if (dot_pos == string::npos) { return strings::StrCat(fname, "_internal"); } else { return strings::StrCat(fname.substr(0, dot_pos), "_internal", fname.substr(dot_pos)); } } } // namespace void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, const string& dot_h_fname, const string& dot_cc_fname) { Env* env = Env::Default(); // Write the initial boilerplate to the .h and .cc files. std::unique_ptr h = nullptr; std::unique_ptr cc = nullptr; TF_CHECK_OK(env->NewWritableFile(dot_h_fname, &h)); TF_CHECK_OK(env->NewWritableFile(dot_cc_fname, &cc)); string op_header_guard; StartFiles(false, dot_h_fname, h.get(), cc.get(), &op_header_guard); // Create the internal versions of these files for the hidden ops. std::unique_ptr internal_h = nullptr; std::unique_ptr internal_cc = nullptr; const string internal_dot_h_fname = MakeInternal(dot_h_fname); TF_CHECK_OK(env->NewWritableFile(internal_dot_h_fname, &internal_h)); TF_CHECK_OK(env->NewWritableFile(MakeInternal(dot_cc_fname), &internal_cc)); string internal_op_header_guard; StartFiles(true /* internal */, internal_dot_h_fname, internal_h.get(), internal_cc.get(), &internal_op_header_guard); for (const auto& graph_op_def : ops.op()) { // Skip deprecated ops. // TODO(josh11b): If needed, can put them into a "deprecated" namespace // instead of skipping. if (graph_op_def.has_deprecation() && graph_op_def.deprecation().version() <= TF_GRAPH_DEF_VERSION) { continue; } // We use a hand-written wrapper for "Const", since the generated // code depends on it. if (graph_op_def.name() == "Const") continue; const auto* api_def = api_def_map.GetApiDef(graph_op_def.name()); std::vector aliases; if (api_def->visibility() == ApiDef::SKIP) continue; // First endpoint is canonical, the rest are aliases. for (int endpoint_i = 1; endpoint_i < api_def->endpoint_size(); ++endpoint_i) { aliases.push_back(api_def->endpoint(endpoint_i).name()); } if (api_def->visibility() == ApiDef::HIDDEN) { // Write hidden ops to _internal.h and _internal.cc. WriteCCOp(graph_op_def, *api_def, aliases, internal_h.get(), internal_cc.get()); continue; } // This isn't a hidden op, write it to the main files. WriteCCOp(graph_op_def, *api_def, aliases, h.get(), cc.get()); } FinishFiles(false, h.get(), cc.get(), op_header_guard); FinishFiles(true /* internal */, internal_h.get(), internal_cc.get(), internal_op_header_guard); } } // namespace tensorflow