aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/python_op_gen.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/python_op_gen.cc')
-rw-r--r--tensorflow/python/framework/python_op_gen.cc678
1 files changed, 678 insertions, 0 deletions
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
new file mode 100644
index 0000000000..5c1b4462d5
--- /dev/null
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -0,0 +1,678 @@
+#include "tensorflow/python/framework/python_op_gen.h"
+
+#include <ctype.h>
+#include <stdio.h>
+
+#include <set>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/op_def_util.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/port.h"
+
+namespace tensorflow {
+namespace {
+
+// Column at which generated text is wrapped (see WordWrap and
+// AppendWithinWidth callers below).
+const int kRightMargin = 78;
+
+// Returns true if "s" is a Python keyword, a Python built-in, or one of the
+// module-level symbols imported/defined by the generated wrapper file; using
+// such a name for a parameter or function would shadow the symbol.
+bool IsPythonReserved(const string& s) {
+  // Intentionally leaked: the function-local static stays valid for the
+  // whole program lifetime and is never destroyed.
+  static const std::set<string>* const kPythonReserved = new std::set<string>(
+      {// Keywords in Python, from:
+       //   import keyword
+       //   print keyword.kwlist
+       "and", "as", "assert", "break", "class", "continue", "def", "del",
+       "elif", "else", "except", "exec", "finally", "for", "from", "global",
+       "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",
+       "raise", "return", "try", "while", "with", "yield",
+       // Built-in functions and types in Python, from:
+       //   [x for x in dir(__builtins__) if not x[0].islower()]
+       "ArithmeticError", "AssertionError", "AttributeError", "BaseException",
+       "BufferError", "BytesWarning", "DeprecationWarning", "EOFError",
+       "Ellipsis", "EnvironmentError", "Exception", "False",
+       "FloatingPointError", "FutureWarning", "GeneratorExit", "IOError",
+       "ImportError", "ImportWarning", "IndentationError", "IndexError",
+       "KeyError", "KeyboardInterrupt", "LookupError", "MemoryError",
+       "NameError", "None", "NotImplemented", "NotImplementedError", "OSError",
+       "OverflowError", "PendingDeprecationWarning", "ReferenceError",
+       "RuntimeError", "RuntimeWarning", "StandardError", "StopIteration",
+       "SyntaxError", "SyntaxWarning", "SystemError", "SystemExit", "TabError",
+       "True", "TypeError", "UnboundLocalError", "UnicodeDecodeError",
+       "UnicodeEncodeError", "UnicodeError", "UnicodeTranslateError",
+       "UnicodeWarning", "UserWarning", "ValueError", "Warning",
+       "ZeroDivisionError", "__debug__", "__doc__", "__import__", "__name__",
+       "__package__",
+       // Imports and symbols used in the generated code:
+       "_op_def_lib", "text_format", "op_def_pb2", "op_def_library", "ops"});
+
+  return kPythonReserved->count(s) > 0;
+}
+
+// Returns "s", with a trailing underscore appended when "s" would collide
+// with a Python keyword, built-in, or generated-code symbol.
+string AvoidPythonReserved(const string& s) {
+  return IsPythonReserved(s) ? strings::StrCat(s, "_") : s;
+}
+
+// Returns "in" reformatted so that its first line is indented by "initial"
+// spaces and every following non-empty line by "rest" spaces.  Trailing
+// whitespace is stripped first; every emitted line ends with '\n'.
+string Indent(int initial, int rest, StringPiece in) {
+  // TODO(josh11b): Also word-wrapping?
+  string trimmed(in.data(), in.size());
+  str_util::StripTrailingWhitespace(&trimmed);
+  const std::vector<string> lines = str_util::Split(trimmed, '\n');
+
+  string out;
+  for (size_t i = 0; i < lines.size(); ++i) {
+    if (i == 0) {
+      out = strings::StrCat(Spaces(initial), lines[i], "\n");
+    } else if (lines[i].empty()) {
+      // Keep interior blank lines truly blank (no trailing indent spaces).
+      strings::StrAppend(&out, "\n");
+    } else {
+      strings::StrAppend(&out, Spaces(rest), lines[i], "\n");
+    }
+  }
+  return out;
+}
+
+// Appends "append" to *dest, preceded by a space when the first line of
+// "append" would still fit within "width" columns, and by a newline
+// otherwise.
+void AppendWithinWidth(string* dest, StringPiece append, int width) {
+  size_t first_line = append.find('\n');
+  if (first_line == string::npos) first_line = append.size();
+  const bool fits_on_line =
+      dest->size() + first_line + 1 /* space */ <= static_cast<size_t>(width);
+  strings::StrAppend(dest, fits_on_line ? " " : "\n", append);
+}
+
+// Clears every human-readable description field from *op_def (input/output
+// arg and attr descriptions, plus the op summary and description), leaving
+// only the machine-relevant fields.
+void RemoveDescriptionsFromOpDef(OpDef* op_def) {
+  for (auto& input : *op_def->mutable_input_arg()) input.clear_description();
+  for (auto& output : *op_def->mutable_output_arg()) output.clear_description();
+  for (auto& attr : *op_def->mutable_attr()) attr.clear_description();
+  op_def->clear_summary();
+  op_def->clear_description();
+}
+
+// Like DataTypeString() but renders the floating-point types with their
+// Python names ("float32" / "float64").
+string PythonDataTypeString(DataType dtype) {
+  if (dtype == DT_FLOAT) return "float32";
+  if (dtype == DT_DOUBLE) return "float64";
+  return DataTypeString(dtype);
+}
+
+// Renders dtype for a docstring, e.g. "`int32`", or "mutable `int32`" when
+// "ref" is set.
+string TypeString(DataType dtype, bool ref) {
+  const char* const mutable_prefix = ref ? "mutable " : "";
+  return strings::StrCat(mutable_prefix, "`", PythonDataTypeString(dtype), "`");
+}
+
+// Renders the DataType list in "value" as a comma-separated docstring
+// fragment; reference types render as "<type> mutable".
+string TypeListString(const AttrValue& value) {
+  string out;
+  for (int raw_type : value.list().type()) {
+    if (!out.empty()) strings::StrAppend(&out, ", ");
+    const DataType dtype = static_cast<DataType>(raw_type);
+    if (IsRefType(dtype)) {
+      strings::StrAppend(&out, PythonDataTypeString(RemoveRefType(dtype)),
+                         " mutable");
+    } else {
+      strings::StrAppend(&out, "`", PythonDataTypeString(dtype), "`");
+    }
+  }
+  return out;
+}
+
+// Describes a single tensor of the given type, e.g.
+// "A `Tensor` of type `int32`.".
+string SingleTensorName(DataType dtype, bool is_ref) {
+  return strings::StrCat("A `Tensor` of type ", TypeString(dtype, is_ref), ".");
+}
+
+// Fallback output description used when the tensor's type cannot be derived
+// from the OpDef (see ArgTypeName and PrintReturns).
+const char kUnknownTensorType[] = {"A `Tensor`."};
+
+// Returns an English description of the type of an op input or output arg,
+// for the Args/Returns sections of the generated docstring.
+// "inferred_attrs" maps attr names to the name of the first input arg each
+// attr is inferred from (built in PrintPythonOp); "is_output" selects
+// output-flavored wording ("Has ..." instead of "Must ...").
+string ArgTypeName(const OpDef& op_def, const OpDef::ArgDef& arg,
+                   const std::unordered_map<string, string>& inferred_attrs,
+                   bool is_output) {
+  if (!arg.number_attr().empty()) {
+    // N Tensors with the same type
+    const string* original_arg =
+        gtl::FindOrNull(inferred_attrs, arg.number_attr());
+    string prefix;
+    if (original_arg == nullptr) {
+      prefix = strings::StrCat("A list of `", arg.number_attr(), "`");
+    } else if (*original_arg == arg.name()) {
+      const OpDef::AttrDef* attr = FindAttr(arg.number_attr(), op_def);
+      if (attr->has_minimum() && attr->minimum() > 0) {
+        prefix = strings::StrCat("A list of at least ", attr->minimum());
+      } else {
+        prefix = "A list of";
+      }
+    } else {
+      prefix = strings::StrCat(
+          "A list with the same number of `Tensor` objects as `",
+          AvoidPythonReserved(*original_arg), "` of");
+    }
+
+    if (arg.type() != DT_INVALID) {
+      // The element type is fixed in the OpDef.
+      return strings::StrCat(prefix, " `Tensor` objects of type ",
+                             TypeString(arg.type(), arg.is_ref()), ".");
+    } else {
+      // The element type comes from a "type" attr.
+      original_arg = gtl::FindOrNull(inferred_attrs, arg.type_attr());
+      if (arg.is_ref()) {
+        strings::StrAppend(&prefix, " mutable");
+      }
+      if (original_arg == nullptr) {
+        return strings::StrCat(prefix, " `Tensor` objects of type ",
+                               arg.type_attr(), ".");
+      } else if (*original_arg == arg.name()) {
+        const OpDef::AttrDef* attr = FindAttr(arg.type_attr(), op_def);
+        if (attr->has_allowed_values()) {
+          return strings::StrCat(prefix,
+                                 " `Tensor` objects of the same type in: ",
+                                 TypeListString(attr->allowed_values()), ".");
+        } else {
+          return strings::StrCat(prefix, " `Tensor` objects of the same type.");
+        }
+      } else {
+        return strings::StrCat(prefix, " `Tensor` objects of the same type as ",
+                               AvoidPythonReserved(*original_arg), ".");
+      }
+    }
+  } else if (!arg.type_attr().empty() || !arg.type_list_attr().empty()) {
+    // A single tensor (type_attr) or a list of tensors with possibly
+    // different types (type_list_attr), typed via an attr.
+    const bool is_list = !arg.type_list_attr().empty();
+    const string attr_name = is_list ? arg.type_list_attr() : arg.type_attr();
+    const OpDef::AttrDef* attr = FindAttr(attr_name, op_def);
+    const string mutable_str = arg.is_ref() ? "mutable " : "";
+    const string prefix =
+        is_list ? strings::StrCat("A list of ", mutable_str, "`Tensor` objects")
+                : strings::StrCat("A ", mutable_str, "`Tensor`");
+    const string* original_arg = gtl::FindOrNull(inferred_attrs, attr_name);
+    if (original_arg == nullptr) {
+      return strings::StrCat(prefix, " of type `", attr_name, "`.");
+    } else if (*original_arg == arg.name()) {
+      if (attr->has_allowed_values()) {
+        if (is_list) {
+          return strings::StrCat(prefix, " with types from: ",
+                                 TypeListString(attr->allowed_values()), ".");
+        } else {
+          return strings::StrCat(
+              prefix, is_output ? ". Has one of the following types: "
+                                : ". Must be one of the following types: ",
+              TypeListString(attr->allowed_values()), ".");
+        }
+      } else {
+        return strings::StrCat(prefix, ".");
+      }
+    } else {
+      return strings::StrCat(prefix,
+                             is_output ? ". Has the same type as `"
+                                       : ". Must have the same type as `",
+                             AvoidPythonReserved(*original_arg), "`.");
+    }
+  } else {
+    // Plain fixed-type tensor.
+    return SingleTensorName(arg.type(), arg.is_ref());
+  }
+}
+
+// Prints the "Returns:" section of a wrapper's docstring to stdout.
+// output_type_string holds one ArgTypeName()-style description per output;
+// entries equal to kUnknownTensorType get fallback handling below.
+void PrintReturns(const OpDef& op_def,
+                  const std::vector<string>& output_type_string) {
+  DCHECK_EQ(op_def.output_arg_size(), output_type_string.size());
+  const int num_outs = op_def.output_arg_size();
+  printf("\n  Returns:\n");
+  if (num_outs == 0) {
+    printf("    The created Operation.\n");
+  } else {
+    if (num_outs == 1) {
+      StringPiece description = op_def.output_arg(0).description();
+      if (ConsumeEquals(&description)) {  // Skip the generated type info.
+        printf("%s", Indent(4, 4, description).c_str());
+      } else {
+        // Special case of one output, don't use the name of the output unless
+        // there is no description.
+        string desc = output_type_string.empty() ? kUnknownTensorType
+                                                 : output_type_string[0];
+        if (desc == kUnknownTensorType) {
+          // Special case where we don't understand how the output tensor type
+          // depends on the input tensor types, just use the output arg
+          // description if we can.
+          if (!description.empty()) {
+            desc = op_def.output_arg(0).description();
+          } else if (!op_def.output_arg(0).name().empty()) {
+            desc = strings::StrCat(" The ", op_def.output_arg(0).name(),
+                                   " `Tensor`.");
+          }
+        } else if (!description.empty()) {
+          AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+        }
+        printf("%s", Indent(4, 4, desc).c_str());
+      }
+    } else {
+      // Multiple outputs: list the tuple first, then one entry per output,
+      // falling back to "outputN" for unnamed outputs.
+      std::vector<string> out_names(num_outs);
+      for (int i = 0; i < num_outs; ++i) {
+        if (!op_def.output_arg(i).name().empty()) {
+          out_names[i] = op_def.output_arg(i).name();
+        } else {
+          out_names[i] = strings::StrCat("output", i);
+        }
+      }
+      printf("    A tuple of `Tensor` objects (%s).\n",
+             str_util::Join(out_names, ", ").c_str());
+      for (int i = 0; i < num_outs; ++i) {
+        string desc = strings::StrCat(out_names[i], ": ");
+        StringPiece description = op_def.output_arg(i).description();
+        if (ConsumeEquals(&description)) {  // Skip the generated type info.
+          strings::StrAppend(&desc, description);
+        } else {
+          const string type = static_cast<size_t>(i) < output_type_string.size()
+                                  ? output_type_string[i]
+                                  : kUnknownTensorType;
+          if (!description.empty()) {
+            if (type == kUnknownTensorType) {
+              // Special case where we don't understand how the output tensor
+              // type depends on the input tensor types, so we just use the
+              // output arg description.
+              strings::StrAppend(&desc, description);
+            } else {
+              strings::StrAppend(&desc, type, " ", description);
+            }
+          } else {
+            strings::StrAppend(&desc, type);
+          }
+        }
+        printf("%s", Indent(4, 6, desc).c_str());
+      }
+    }
+  }
+}
+
+// Renders "str" as a double-quoted, C-escaped Python string literal.
+string StringToPython(const string& str) {
+  const string escaped = str_util::CEscape(str);
+  return strings::StrCat("\"", escaped, "\"");
+}
+
+// Renders dtype as the Python expression that names it, e.g. "tf.int32".
+string DataTypeToPython(DataType dtype) {
+  const string py_name = PythonDataTypeString(dtype);
+  return strings::StrCat("tf.", py_name);
+}
+
+// Renders "shape" as a Python list literal; named dimensions become
+// (name, size) tuples, unnamed dimensions become bare sizes.
+string ShapeToPython(const TensorShapeProto& shape) {
+  string python = "[";
+  bool first_dim = true;
+  for (const auto& dim : shape.dim()) {
+    if (!first_dim) strings::StrAppend(&python, ", ");
+    first_dim = false;
+    if (dim.name().empty()) {
+      strings::StrAppend(&python, dim.size());
+    } else {
+      strings::StrAppend(&python, "(", StringToPython(dim.name()), ", ",
+                         dim.size(), ")");
+    }
+  }
+  strings::StrAppend(&python, "]");
+  return python;
+}
+
+// Renders the list stored in "value" as a comma-separated Python expression
+// fragment (no surrounding brackets).  At most one of the list's field types
+// is expected to be populated; returns "" when none are.
+string AttrListToPython(const AttrValue& value) {
+  string ret;
+  if (value.list().s_size() > 0) {
+    // Strings.
+    for (int i = 0; i < value.list().s_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, StringToPython(value.list().s(i)));
+    }
+  } else if (value.list().i_size() > 0) {
+    // Ints.
+    for (int i = 0; i < value.list().i_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, value.list().i(i));
+    }
+  } else if (value.list().f_size() > 0) {
+    // Floats.
+    for (int i = 0; i < value.list().f_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, value.list().f(i));
+    }
+  } else if (value.list().b_size() > 0) {
+    // Bools.
+    for (int i = 0; i < value.list().b_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, value.list().b(i) ? "True" : "False");
+    }
+  } else if (value.list().type_size() > 0) {
+    // DataTypes.
+    for (int i = 0; i < value.list().type_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, DataTypeToPython(value.list().type(i)));
+    }
+  } else if (value.list().shape_size() > 0) {
+    // Shapes.
+    for (int i = 0; i < value.list().shape_size(); ++i) {
+      if (i > 0) strings::StrAppend(&ret, ", ");
+      strings::StrAppend(&ret, ShapeToPython(value.list().shape(i)));
+    }
+  }
+  return ret;
+}
+
+// Renders the attr "value" of the given attr "type" (e.g. "int", "shape",
+// "list(string)") as the equivalent Python expression.
+string AttrValueToPython(const string& type, const AttrValue& value) {
+  if (type == "string") return StringToPython(value.s());
+  if (type == "int") return strings::StrCat(value.i());
+  if (type == "float") return strings::StrCat(value.f());
+  if (type == "bool") return value.b() ? "True" : "False";
+  if (type == "type") return DataTypeToPython(value.type());
+  if (type == "shape") return ShapeToPython(value.shape());
+  // Everything else is a list(...) type.
+  return strings::StrCat("[", AttrListToPython(value), "]");
+}
+
+// Prints the Python wrapper function for a single op to stdout: the "def"
+// line (named op_name, prefixed with "_" when is_hidden), a docstring built
+// from the OpDef's summary/descriptions, and a body that forwards to
+// _op_def_lib.apply_op().
+// Requires: ValidateOpDef(op_def).ok()
+void PrintPythonOp(const OpDef& op_def, bool is_hidden, string op_name) {
+  // Map from attr name to the first input arg it is inferred from.
+  std::unordered_map<string, string> inferred_attrs;
+  // This has all the input args followed by those attrs that don't have
+  // defaults.
+  std::vector<string> args_no_default;
+  // The parameters with defaults (these have to be listed after those without).
+  // No input args are included, just attrs and the graph ("g") parameter.
+  std::vector<string> args_with_defaults;
+  for (int i = 0; i < op_def.input_arg_size(); ++i) {
+    const auto& arg(op_def.input_arg(i));
+    args_no_default.push_back(arg.name());
+    if (!arg.type_attr().empty()) {
+      gtl::InsertIfNotPresent(&inferred_attrs, arg.type_attr(), arg.name());
+    } else if (!arg.type_list_attr().empty()) {
+      gtl::InsertIfNotPresent(&inferred_attrs, arg.type_list_attr(),
+                              arg.name());
+    }
+    if (!arg.number_attr().empty()) {
+      gtl::InsertIfNotPresent(&inferred_attrs, arg.number_attr(), arg.name());
+    }
+  }
+  for (int i = 0; i < op_def.attr_size(); ++i) {
+    const auto& attr(op_def.attr(i));
+    // Do not add inferred attrs to the Python function signature.
+    if (inferred_attrs.find(attr.name()) == inferred_attrs.end()) {
+      if (attr.has_default_value()) {
+        args_with_defaults.push_back(attr.name());
+      } else {
+        args_no_default.push_back(attr.name());
+      }
+    }
+  }
+
+  // Save the list of attr parameters (attrs that won't be inferred),
+  // those with defaults go at the end.
+  std::vector<string> attrs;
+  // Get the attrs in the order we want by taking the attrs without defaults
+  // from the end of args_no_default, and adding args_no_default (before
+  // "g" gets added to args_no_default, so it only has attrs).
+  attrs.reserve(args_no_default.size() - op_def.input_arg_size() +
+                args_with_defaults.size());
+  attrs.insert(attrs.end(), args_no_default.begin() + op_def.input_arg_size(),
+               args_no_default.end());
+  attrs.insert(attrs.end(), args_with_defaults.begin(),
+               args_with_defaults.end());
+
+  // Build the parameter list, renaming any parameter that collides with a
+  // Python reserved word; defaulted attrs become "name=None" parameters.
+  std::vector<string> param_names;
+  param_names.reserve(args_no_default.size() + args_with_defaults.size());
+  string parameters;
+  for (const string& name : args_no_default) {
+    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+    const string param = AvoidPythonReserved(name);
+    strings::StrAppend(&parameters, param);
+    param_names.push_back(param);
+  }
+  for (const string& name : args_with_defaults) {
+    if (!parameters.empty()) strings::StrAppend(&parameters, ", ");
+    const string param = AvoidPythonReserved(name);
+    strings::StrAppend(&parameters, param, "=None");
+    param_names.push_back(param);
+  }
+  const bool has_args = args_no_default.size() + args_with_defaults.size() > 0;
+
+  // Print: def Function(parameters):
+  const string lower_op_name = strings::StrCat(is_hidden ? "_" : "", op_name);
+
+  const string def_prefix = strings::StrCat("def ", lower_op_name, "(");
+  const string def_suffix =
+      strings::StrCat(parameters, has_args ? ", " : "", "name=None):");
+
+  printf("%s\n", WordWrap(def_prefix, def_suffix, kRightMargin).c_str());
+
+  // Format the Op's descriptions so that it can be a Python docstring.
+  string comment;
+  if (op_def.summary().empty()) {
+    comment = "TODO: add doc.\n";
+  } else {
+    comment = strings::StrCat(op_def.summary(), "\n");
+    if (!op_def.description().empty()) {
+      strings::StrAppend(&comment, "\n", Indent(2, 2, op_def.description()));
+    }
+  }
+
+  printf(R"(  r"""%s
+  Args:
+)",
+         comment.c_str());
+
+  // Inputs
+  for (int i = 0; i < op_def.input_arg_size(); ++i) {
+    const auto& arg(op_def.input_arg(i));
+    StringPiece description = op_def.input_arg(i).description();
+    string desc;
+    if (ConsumeEquals(&description)) {  // Skip the generated type info.
+      desc = strings::StrCat(param_names[i], ": ");
+    } else {
+      desc = strings::StrCat(param_names[i], ": ",
+                             ArgTypeName(op_def, arg, inferred_attrs, false));
+    }
+    if (!description.empty()) {
+      AppendWithinWidth(&desc, description, kRightMargin - 4 /* indent */);
+    }
+    printf("%s", Indent(4, 6, desc).c_str());
+  }
+
+  // Attrs
+  for (const string& name : attrs) {
+    const auto& attr = *FindAttr(name, op_def);
+    string desc = strings::StrCat(AvoidPythonReserved(name), ": ");
+
+    // Maps attr type strings to the English phrase used in the docstring.
+    static const char* const kAttrTypeName[][2] = {
+        {"string", "`string`"},
+        {"list(string)", "list of `strings`"},
+        {"int", "`int`"},
+        {"list(int)", "list of `ints`"},
+        {"float", "`float`"},
+        {"list(float)", "list of `floats`"},
+        {"bool", "`bool`"},
+        {"list(bool)", "list of `bools`"},
+        {"type", "`tf.DType`"},
+        {"list(type)", "list of `tf.DTypes`"},
+        {"shape", "`tf.TensorShape` or list of `ints`"},
+        {"list(shape)",
+         "list of shapes (each a `tf.TensorShape` or list of `ints`)"},
+    };
+    for (size_t i = 0; i < TF_ARRAYSIZE(kAttrTypeName); ++i) {
+      if (attr.type() == kAttrTypeName[i][0]) {
+        string s;
+        if (attr.has_default_value()) {
+          s = strings::StrCat("optional ", kAttrTypeName[i][1]);
+        } else {
+          s = kAttrTypeName[i][1];
+        }
+        // Use "An" before phrases starting with a vowel sound
+        // ("optional ...", "`int...`", "`o...`"); otherwise "A".
+        if (s[0] == 'o' || (s[0] == '`' && (s[1] == 'i' || s[1] == 'o'))) {
+          strings::StrAppend(&desc, "An ", s);
+        } else {
+          strings::StrAppend(&desc, "A ", s);
+        }
+        break;
+      }
+    }
+
+    if (attr.has_allowed_values()) {
+      strings::StrAppend(&desc, " from: `",
+                         AttrListToPython(attr.allowed_values()), "`");
+    }
+
+    if (attr.has_minimum()) {
+      // For ints the minimum bounds the value; for lists it bounds the length.
+      if (attr.type() == "int") {
+        strings::StrAppend(&desc, " that is `>= ", attr.minimum(), "`");
+      } else if (attr.minimum() > 0) {
+        strings::StrAppend(&desc, " that has length `>= ", attr.minimum(), "`");
+      }
+    }
+
+    strings::StrAppend(&desc, ".");
+
+    if (attr.has_default_value()) {
+      strings::StrAppend(&desc, " Defaults to `",
+                         AttrValueToPython(attr.type(), attr.default_value()),
+                         "`.");
+    }
+
+    if (!attr.description().empty()) {
+      AppendWithinWidth(&desc, attr.description(),
+                        kRightMargin - 4 /* indent */);
+    }
+    printf("%s", Indent(4, 6, desc).c_str());
+  }
+
+  printf("    name: A name for the operation (optional).\n");
+
+  std::vector<string> output_type_string;
+  output_type_string.reserve(op_def.output_arg_size());
+  for (int i = 0; i < op_def.output_arg_size(); ++i) {
+    output_type_string.push_back(
+        ArgTypeName(op_def, op_def.output_arg(i), inferred_attrs, true));
+  }
+  PrintReturns(op_def, output_type_string);
+
+  // Close the docstring and emit the body forwarding to apply_op().
+  string return_prefix = strings::StrCat("  return _op_def_lib.apply_op(");
+  string return_args = strings::StrCat("\"", op_def.name(), "\", ");
+  for (size_t i = 0; i < param_names.size(); ++i) {
+    strings::StrAppend(&return_args, param_names[i], "=", param_names[i], ", ");
+  }
+  strings::StrAppend(&return_args, "name=name)");
+
+  printf(R"(  """
+%s
+)",
+         // Wrap the arguments, and indent to the (.
+         WordWrap(return_prefix, return_args, kRightMargin).c_str());
+
+  printf("\n\n");
+}
+
+// Converts a CamelCase op name to Python snake_case, appending the result
+// to *result (e.g. "BatchMatMul" -> "batch_mat_mul", "HTMLParser" ->
+// "html_parser").  An underscore is emitted at each lower-to-upper
+// transition and before the final capital of a run of capitals that is
+// followed by a lowercase letter.
+void GenerateLowerCaseOpName(const string& str, string* result) {
+  const char joiner = '_';
+  // Cast the size before subtracting so an empty string yields -1 instead
+  // of relying on size_t underflow converting back to int.
+  const int last_index = static_cast<int>(str.size()) - 1;
+  for (int i = 0; i <= last_index; ++i) {
+    // Cast through unsigned char: passing a negative char value to the
+    // <ctype.h> classification functions is undefined behavior.
+    const unsigned char c = static_cast<unsigned char>(str[i]);
+    // Emit a joiner only if a previous-lower-to-now-upper or a
+    // now-upper-to-next-lower transition happens.
+    if (isupper(c) && (i > 0)) {
+      const bool lower_before =
+          islower(static_cast<unsigned char>(str[i - 1])) != 0;
+      const bool lower_after =
+          (i < last_index) && islower(static_cast<unsigned char>(str[i + 1]));
+      if (lower_before || lower_after) {
+        result->push_back(joiner);
+      }
+    }
+    result->push_back(tolower(c));
+  }
+}
+
+} // namespace
+
+// Writes the complete generated Python module for "ops" to stdout: a file
+// header with imports, one wrapper function per op, optional null shape
+// registrations (when !require_shapes), and a footer that embeds the op
+// list (descriptions stripped) in text format and initializes _op_def_lib.
+// "hidden_ops" is a comma-separated list of op names whose wrappers get a
+// leading underscore.
+void PrintPythonOps(const OpList& ops, const string& hidden_ops,
+                    bool require_shapes) {
+  // Header
+  // TODO(josh11b): Mention the library for which wrappers are being generated.
+  printf(R"("""Python wrappers around Brain.
+
+This file is MACHINE GENERATED! Do not edit.
+"""
+
+from google.protobuf import text_format
+
+from tensorflow.core.framework import op_def_pb2
+from tensorflow.python.framework import op_def_registry
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import op_def_library
+
+
+)");
+
+  std::vector<string> hidden_vec = str_util::Split(hidden_ops, ',');
+
+  // We'll make a copy of ops that filters out descriptions.
+  OpList cleaned_ops;
+  auto out = cleaned_ops.mutable_op();
+  out->Reserve(ops.op_size());
+  for (const auto& op_def : ops.op()) {
+    bool is_hidden = false;
+    for (const string& hidden : hidden_vec) {
+      if (op_def.name() == hidden) {
+        is_hidden = true;
+        break;
+      }
+    }
+
+    // PrintPythonOp(op_def, is_hidden, op_def.name());
+    string lower_case_name;
+    GenerateLowerCaseOpName(op_def.name(), &lower_case_name);
+
+    // When users create custom python wrappers, they may link in the
+    // default op registry by accident, and because they can't
+    // enumerate all 'hidden' symbols, this guard is to prevent
+    // instantiating a python reserved word in their wrapper.
+    if (!is_hidden && IsPythonReserved(lower_case_name)) {
+      continue;
+    }
+
+    PrintPythonOp(op_def, is_hidden, lower_case_name);
+
+    if (!require_shapes) {
+      printf("ops.RegisterShape(\"%s\")(None)\n", op_def.name().c_str());
+    }
+
+    // Record a description-free copy of the op for the footer's op_list.
+    auto added = out->Add();
+    *added = op_def;
+    RemoveDescriptionsFromOpDef(added);
+  }
+
+  printf(R"(def _InitOpDefLibrary():
+  op_list = op_def_pb2.OpList()
+  text_format.Merge(_InitOpDefLibrary.op_list_ascii, op_list)
+  op_def_registry.register_op_list(op_list)
+  op_def_lib = op_def_library.OpDefLibrary()
+  op_def_lib.add_op_list(op_list)
+  return op_def_lib
+
+
+_InitOpDefLibrary.op_list_ascii = """%s"""
+
+
+_op_def_lib = _InitOpDefLibrary()
+)",
+         cleaned_ops.DebugString().c_str());
+}
+
+} // namespace tensorflow