From 8ecf1ebc5d83e66b29a07113b53c49ef8264703c Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Tue, 29 May 2018 00:50:34 -0700 Subject: Re-apply #18192. PiperOrigin-RevId: 198358055 --- tensorflow/java/src/gen/cc/java_defs.h | 80 ++-- tensorflow/java/src/gen/cc/op_gen_main.cc | 46 +-- tensorflow/java/src/gen/cc/op_generator.cc | 464 +++++++++++++++++++++-- tensorflow/java/src/gen/cc/op_generator.h | 37 +- tensorflow/java/src/gen/cc/op_specs.cc | 423 +++++++++++++++++++++ tensorflow/java/src/gen/cc/op_specs.h | 173 +++++++++ tensorflow/java/src/gen/cc/source_writer.cc | 142 ++++--- tensorflow/java/src/gen/cc/source_writer.h | 55 ++- tensorflow/java/src/gen/cc/source_writer_test.cc | 155 ++++---- tensorflow/java/src/gen/gen_ops.bzl | 74 ++-- 10 files changed, 1348 insertions(+), 301 deletions(-) create mode 100644 tensorflow/java/src/gen/cc/op_specs.cc create mode 100644 tensorflow/java/src/gen/cc/op_specs.h (limited to 'tensorflow/java/src') diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 59f8beaee7..d9d6f8adc8 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -16,19 +16,22 @@ limitations under the License. #ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ #define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ -#include #include +#include +#include +#include namespace tensorflow { namespace java { // An enumeration of different modifiers commonly used in Java enum Modifier { - PUBLIC = (1 << 0), + PACKAGE = 0, + PUBLIC = (1 << 0), PROTECTED = (1 << 1), - PRIVATE = (1 << 2), - STATIC = (1 << 3), - FINAL = (1 << 4), + PRIVATE = (1 << 2), + STATIC = (1 << 3), + FINAL = (1 << 4), }; class Annotation; @@ -72,6 +75,8 @@ class Type { // Reflection API does return Type(Type::PRIMITIVE, "void"); } + static Type Generic(const string& name) { return Type(Type::GENERIC, name); } + static Type Wildcard() { return Type(Type::GENERIC, ""); } static Type Class(const string& name, const string& package = "") { return Type(Type::CLASS, name, package); } @@ -81,9 +86,6 @@ class Type { static Type Enum(const string& name, const string& package = "") { return Type(Type::ENUM, name, package); } - static Type Generic(const string& name = "") { - return Type(Type::GENERIC, name); - } static Type ClassOf(const Type& type) { return Class("Class").add_parameter(type); } @@ -96,11 +98,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string& description() const { return description_; } - Type& description(const string& description) { - description_ = description; - return *this; + const string canonical_name() const { + return package_.empty() ? name_ : package_ + "." + name_; } + bool wildcard() const { return name_.empty(); } // only wildcards has no name const std::list& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); @@ -120,14 +121,6 @@ class Type { } return *this; } - // Returns true if "type" is of a known collection type (only a few for now) - bool IsCollection() const { - return name_ == "List" || name_ == "Iterable"; - } - // Returns true if this instance is a wildcard () - bool IsWildcard() const { - return kind_ == GENERIC && name_.empty(); - } protected: Type(Kind kind, const string& name, const string& package = "") @@ -137,7 +130,6 @@ class Type { Kind kind_; string name_; string package_; - string description_; std::list parameters_; std::list annotations_; std::list supertypes_; @@ -180,16 +172,11 @@ class Variable { const string& name() const { return name_; } const Type& type() const { return type_; } bool variadic() const { return variadic_; } - const string& description() const { return description_; } - Variable& description(const string& description) { - description_ = description; - return *this; - } + private: string name_; Type type_; bool variadic_; - string description_; Variable(const string& name, const Type& type, bool variadic) : name_(name), type_(type), variadic_(variadic) {} @@ -210,16 +197,6 @@ class Method { bool constructor() const { return constructor_; } const string& name() const { return name_; } const Type& return_type() const { return return_type_; } - const string& description() const { return description_; } - Method& description(const string& description) { - description_ = description; - return *this; - } - const string& return_description() const { return return_description_; } - Method& return_description(const string& description) { - return_description_ = description; - return *this; - } const std::list& arguments() const { return arguments_; } Method& add_argument(const Variable& var) { arguments_.push_back(var); @@ -235,8 +212,6 @@ class Method { string name_; Type return_type_; bool constructor_; - string description_; - string return_description_; std::list arguments_; std::list annotations_; @@ -244,6 +219,33 @@ class Method { : name_(name), return_type_(return_type), constructor_(constructor) {} }; +// A definition of a documentation bloc for a Java element (JavaDoc) +class Javadoc { + public: + static Javadoc Create(const string& brief = "") { return Javadoc(brief); } + const string& brief() const { return brief_; } + const string& details() const { return details_; } + Javadoc& details(const string& details) { + details_ = details; + return *this; + } + const std::list>& tags() const { return tags_; } + Javadoc& add_tag(const string& tag, const string& text) { + tags_.push_back(std::make_pair(tag, text)); + return *this; + } + Javadoc& add_param_tag(const string& name, const string& text) { + return add_tag("param", name + " " + text); + } + + private: + string brief_; + string details_; + std::list> tags_; + + explicit Javadoc(const string& brief) : brief_(brief) {} +}; + } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index bea99f3d7f..0d9e0883af 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -36,49 +36,43 @@ const char kUsageHeader[] = "Operation wrappers are generated under the path specified by the " "'--output_dir' argument. This path can be absolute or relative to the\n" "current working directory and will be created if it does not exists.\n\n" - "The '--lib_name' argument is used to classify the set of operations. If " - "the chosen name contains more than one word, it must be provided in \n" - "snake_case. This value is declined into other meaningful names, such as " - "the group and package of the generated operations. For example,\n" - "'--lib_name=my_lib' generates the operations under the " - "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n" - "group.\n\n" - "Note that the operator group assigned to the generated wrappers is just " - "an annotation tag at this stage. Operations will not be available " - "through\n" - "the 'org.tensorflow.op.Ops' API as a group until the generated classes " - "are compiled using an appropriate annotation processor.\n\n" - "Finally, the '--base_package' overrides the default parent package " - "under which the generated subpackage and classes are to be located.\n\n"; + "Note that the operations will not be available through the " + "'org.tensorflow.op.Ops' API until the generated classes are compiled\n" + "using an appropriate annotation processor.\n\n" + "The '--base_package' overrides the default parent package under which " + "the generated subpackage and classes are to be located.\n\n" + "Finally, the `--api_dirs` argument takes a list of comma-separated " + "directories of API definitions can be provided to override default\n" + "values found in the ops definitions. Directories are ordered by priority " + "(the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow int main(int argc, char* argv[]) { - tensorflow::string lib_name; tensorflow::string output_dir; tensorflow::string base_package = "org.tensorflow.op"; + tensorflow::string api_dirs_str; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, "Root directory into which output files are generated"), - tensorflow::Flag( - "lib_name", &lib_name, - "A name, in snake_case, used to classify this set of operations"), tensorflow::Flag( "base_package", &base_package, - "Package parent to the generated subpackage and classes")}; + "Package parent to the generated subpackage and classes"), + tensorflow::Flag( + "api_dirs", &api_dirs_str, + "List of directories that contains the ops api definitions")}; tensorflow::string usage = tensorflow::java::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - - tensorflow::java::OpGenerator generator; + QCHECK(parsed_flags_ok && !output_dir.empty()) << usage; + std::vector api_dirs = tensorflow::str_util::Split( + api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); + tensorflow::java::OpGenerator generator(api_dirs); tensorflow::OpList ops; - tensorflow::OpRegistry::Global()->Export(true, &ops); - tensorflow::Status status = - generator.Run(ops, lib_name, base_package, output_dir); - TF_QCHECK_OK(status); + tensorflow::OpRegistry::Global()->Export(false, &ops); + TF_CHECK_OK(generator.Run(ops, base_package, output_dir)); return 0; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index def06baf2d..debd95fc62 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -13,54 +13,466 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include +#include #include +#include +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" #include "tensorflow/java/src/gen/cc/op_generator.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" namespace tensorflow { namespace java { namespace { -string CamelCase(const string& str, char delimiter, bool upper) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == delimiter) { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; +const char* kLicense = + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + "\n" + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + "you may not use this file except in compliance with the License.\n" + "You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + "Unless required by applicable law or agreed to in writing, software\n" + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "See the License for the specific language governing permissions and\n" + "limitations under the License.\n" + "=======================================================================*/" + "\n"; + +// There is three different modes to render an op class, depending on the +// number and type of outputs it has: +// +// DEFAULT: This mode does not provide any specialization for the op class, it +// is applied when the operation does not comply with any other mode +// +// OPERAND: The op class implements the Operand interface, allowing an +// instance to be passed directly in input to another operation +// +// LIST_OPERAND: The op class implements the Iterable> interface, +// allowing an instance to be passed directly as a list input to +// another operation +// +enum RenderMode { DEFAULT, OPERAND, LIST_OPERAND }; + +void AddArgument(const Variable& var, const string& description, + Method* method_out, Javadoc* javadoc_out) { + method_out->add_argument(var); + javadoc_out->add_param_tag(var.name(), description); +} + +void CollectOpDependencies(const OpSpec& op, RenderMode mode, + std::list* out) { + out->push_back(Type::Class("Operation", "org.tensorflow")); + out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); + out->push_back(Type::Class("Scope", "org.tensorflow.op")); + if (mode == OPERAND) { + out->push_back(Type::Class("Output", "org.tensorflow")); + } else if (mode == LIST_OPERAND) { + out->push_back(Type::Interface("Iterator", "java.util")); + } + // Don't pay attention to duplicate types in the dependency list, they will + // be filtered out by the SourceWriter. + for (const ArgumentSpec& input : op.inputs()) { + out->push_back(input.var().type()); + if (input.iterable()) { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } + } + for (const ArgumentSpec& output : op.outputs()) { + out->push_back(output.var().type()); + if (output.iterable()) { + out->push_back(Type::Class("Arrays", "java.util")); + } + } + for (const AttributeSpec& attribute : op.attributes()) { + out->push_back(attribute.var().type()); + out->push_back(attribute.jni_type()); + } + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + out->push_back(optional_attribute.var().type()); + } +} + +void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, + SourceWriter* writer) { + string var_name = optional ? "opts." + attr.var().name() : attr.var().name(); + if (attr.iterable()) { + string array_name = attr.var().name() + "Array"; + writer->AppendType(attr.jni_type()) + .Append("[] " + array_name + " = new ") + .AppendType(attr.jni_type()) + .Append("[" + var_name + ".size()];") + .EndLine() + .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)") + .Append(array_name + "[i] = "); + if (attr.type().kind() == Type::GENERIC) { + writer->Append("DataType.fromClass(" + var_name + ".get(i));"); + } else { + writer->Append(var_name + ".get(i);"); + } + writer->EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ") + .Append(array_name + ");") + .EndLine(); + } else { + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", "); + if (attr.var().type().name() == "Class") { + writer->Append("DataType.fromClass(" + var_name + "));"); } else { - result += c; + writer->Append(var_name + ");"); } + writer->EndLine(); } - return result; } -} // namespace +void RenderFactoryMethods(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method factory = Method::Create("create", op_class); + Javadoc factory_doc = + Javadoc::Create("Factory method to create a class to wrap a new " + + op_class.name() + " operation to the graph."); + Variable scope = + Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); + AddArgument(scope, "current graph scope", &factory, &factory_doc); + for (const ArgumentSpec& input : op.inputs()) { + AddArgument(input.var(), input.description(), &factory, &factory_doc); + } + for (const AttributeSpec& attr : op.attributes()) { + AddArgument(attr.var(), attr.description(), &factory, &factory_doc); + } + if (!op.optional_attributes().empty()) { + AddArgument(Variable::Varargs("options", Type::Class("Options")), + "carries optional attributes values", &factory, &factory_doc); + } + factory_doc.add_tag("return", "a new instance of " + op_class.name()); -OpGenerator::OpGenerator() : env(Env::Default()) {} + writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc); + writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + op.graph_op_name() + "\", scope.makeOpName(\"" + + op_class.name() + "\"));"); + writer->EndLine(); + for (const ArgumentSpec& input : op.inputs()) { + if (input.iterable()) { + writer->Append("opBuilder.addInputList(Operands.asOutputs(" + + input.var().name() + "));"); + writer->EndLine(); + } else { + writer->Append("opBuilder.addInput(" + input.var().name() + + ".asOutput());"); + writer->EndLine(); + } + } + for (const AttributeSpec& attribute : op.attributes()) { + WriteSetAttrDirective(attribute, false, writer); + } + if (!op.optional_attributes().empty()) { + writer->BeginBlock("if (options != null)") + .BeginBlock("for (Options opts : options)"); + for (const AttributeSpec& attribute : op.optional_attributes()) { + writer->BeginBlock("if (opts." + attribute.var().name() + " != null)"); + WriteSetAttrDirective(attribute, true, writer); + writer->EndBlock(); + } + writer->EndBlock().EndBlock(); + } + writer->Append("return new ") + .AppendType(op_class) + .Append("(opBuilder.build());") + .EndLine(); + writer->EndMethod(); +} -OpGenerator::~OpGenerator() {} +void RenderConstructor(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Variable operation = + Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); + Method constructor = Method::ConstructorFor(op_class).add_argument(operation); + for (const ArgumentSpec& output : op.outputs()) { + if (output.iterable() && !output.type().wildcard()) { + constructor.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + break; + } + } + writer->BeginMethod(constructor, PRIVATE) + .Append("super(operation);") + .EndLine(); + if (!op.outputs().empty()) { + writer->Append("int outputIdx = 0;").EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + if (output.iterable()) { + string var_length = output.var().name() + "Length"; + writer->Append("int " + var_length) + .Append(" = operation.outputListLength(\"" + output.op_def_name() + + "\");") + .EndLine() + .Append(output.var().name() + " = Arrays.asList("); + if (!output.type().wildcard()) { + writer->Append("(") + .AppendType(output.var().type().parameters().front()) + .Append("[])"); + } + writer->Append("operation.outputList(outputIdx, " + var_length + "));") + .EndLine() + .Append("outputIdx += " + var_length + ";") + .EndLine(); + } else { + writer + ->Append(output.var().name() + " = operation.output(outputIdx++);") + .EndLine(); + } + } + } + writer->EndMethod(); +} -Status OpGenerator::Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir) { - const string package = - base_package + '.' + str_util::StringReplace(lib_name, "_", "", true); - const string package_path = - output_dir + '/' + str_util::StringReplace(package, ".", "/", true); - const string group = CamelCase(lib_name, '_', false); +void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { + for (const AttributeSpec& attr : op.optional_attributes()) { + Method setter = Method::Create(attr.var().name(), Type::Class("Options")); + Javadoc setter_doc = Javadoc::Create(); + AddArgument(attr.var(), attr.description(), &setter, &setter_doc); + writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc) + .Append("return new Options()." + attr.var().name() + "(" + + attr.var().name() + ");") + .EndLine() + .EndMethod(); + } + for (const ArgumentSpec& output : op.outputs()) { + Method getter = Method::Create(output.var().name(), output.var().type()); + Javadoc getter_doc = Javadoc::Create(output.description()); + writer->BeginMethod(getter, PUBLIC, &getter_doc) + .Append("return " + output.var().name() + ";") + .EndLine() + .EndMethod(); + } +} - if (!env->FileExists(package_path).ok()) { - TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); +void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, + SourceWriter* writer) { + ArgumentSpec output = op.outputs().front(); + + if (mode == OPERAND) { + bool cast2obj = output.type().wildcard(); + Type return_type = + Type::Class("Output", "org.tensorflow") + .add_parameter(cast2obj ? Type::Class("Object") : output.type()); + Method as_output = Method::Create("asOutput", return_type) + .add_annotation(Annotation::Create("Override")); + if (cast2obj) { + as_output.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + } + writer->BeginMethod(as_output, PUBLIC); + if (cast2obj) { + writer->Append("return (").AppendType(return_type).Append(") "); + } else { + writer->Append("return "); + } + writer->Append(output.var().name() + ";").EndLine().EndMethod(); + + } else if (mode == LIST_OPERAND) { + Type operand = Type::Interface("Operand", "org.tensorflow"); + if (output.type().wildcard()) { + operand.add_parameter(Type::Class("Object")); + } else { + operand.add_parameter(output.type()); + } + Type return_type = + Type::Interface("Iterator", "java.util").add_parameter(operand); + Method iterator = + Method::Create("iterator", return_type) + .add_annotation(Annotation::Create("Override")) + .add_annotation(Annotation::Create("SuppressWarnings") + .attributes("{\"rawtypes\", \"unchecked\"}")); + // cast the output list using a raw List + writer->BeginMethod(iterator, PUBLIC) + .Append("return (" + return_type.name() + ") ") + .Append(output.var().name() + ".iterator();") + .EndLine() + .EndMethod(); + } +} + +void RenderOptionsClass(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Type options_class = Type::Class("Options"); + Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " + + op_class.canonical_name() + "}"); + writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); + for (const AttributeSpec& attr : op.optional_attributes()) { + Method setter = Method::Create(attr.var().name(), options_class); + Javadoc setter_doc = Javadoc::Create(); + AddArgument(attr.var(), attr.description(), &setter, &setter_doc); + writer->BeginMethod(setter, PUBLIC, &setter_doc) + .Append("this." + attr.var().name() + " = " + attr.var().name() + ";") + .EndLine() + .Append("return this;") + .EndLine() + .EndMethod(); + } + writer->EndLine(); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + writer->WriteField(optional_attribute.var(), PRIVATE); } + Method constructor = Method::ConstructorFor(options_class); + writer->BeginMethod(constructor, PRIVATE).EndMethod(); + writer->EndType(); +} + +inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { + return Type::Class( + endpoint.name(), + base_package + "." + str_util::Lowercase(endpoint.package())); +} - LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; - // TODO(karllessard) generate wrappers from list of ops +void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, + const string& base_package, const string& output_dir, + Env* env) { + Type op_class( + ClassOf(endpoint, base_package) + .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); + Javadoc op_javadoc(endpoint.javadoc()); + // op interfaces + RenderMode mode = DEFAULT; + if (op.outputs().size() == 1) { + const ArgumentSpec& output = op.outputs().front(); + Type operand_type(output.type().wildcard() ? Type::Class("Object") + : output.type()); + Type operand_inf(Type::Interface("Operand", "org.tensorflow") + .add_parameter(operand_type)); + if (output.iterable()) { + mode = LIST_OPERAND; + op_class.add_supertype(Type::IterableOf(operand_inf)); + } else { + mode = OPERAND; + op_class.add_supertype(operand_inf); + } + } + // op generic parameters + std::set generics; + for (const ArgumentSpec& output : op.outputs()) { + if (output.type().kind() == Type::GENERIC && !output.type().wildcard() && + generics.find(output.type().name()) == generics.end()) { + op_class.add_parameter(output.type()); + op_javadoc.add_param_tag( + "<" + output.type().name() + ">", + "data type for {@code " + output.var().name() + "()} output"); + generics.insert(output.type().name()); + } + } + // op annotations + op_class.add_annotation( + Annotation::Create("Generated", "javax.annotation") + .attributes("value = \"TensorFlow Java Op Generator\"")); + if (endpoint.deprecated()) { + op_class.add_annotation(Annotation::Create("Deprecated")); + string explanation; + if (!op.endpoints().front().deprecated()) { + explanation = + "use {@link " + + ClassOf(op.endpoints().front(), base_package).canonical_name() + + "} instead"; + } else { + explanation = op.deprecation_explanation(); + } + op_javadoc.add_tag("deprecated", explanation); + } + if (!op.hidden()) { + // expose the op in the Ops Graph API only if it is visible + op_class.add_annotation( + Annotation::Create("Operator", "org.tensorflow.op.annotation") + .attributes("group = \"" + endpoint.package() + "\"")); + } + // create op class file + const string op_dir_name = io::JoinPath( + output_dir, str_util::StringReplace(op_class.package(), ".", "/", true)); + if (!env->FileExists(op_dir_name).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name)) + << op_dir_name; + } + const string op_file_name = op_class.name() + ".java"; + std::unique_ptr op_file; + TF_CHECK_OK( + env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file)) + << op_file_name; + + // render endpoint source code + SourceFileWriter writer(op_file.get()); + std::list dependencies; + CollectOpDependencies(op, mode, &dependencies); + writer.Write(kLicense).EndLine().BeginType(op_class, PUBLIC | FINAL, + &dependencies, &op_javadoc); + if (!op.optional_attributes().empty()) { + RenderOptionsClass(op, op_class, &writer); + } + RenderFactoryMethods(op, op_class, &writer); + RenderGettersAndSetters(op, &writer); + if (mode != DEFAULT) { + RenderInterfaceImpl(op, mode, &writer); + } + writer.EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + writer.WriteField(output.var(), PRIVATE); + } + RenderConstructor(op, op_class, &writer); + writer.EndType(); +} + +bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) { + if (api_def.visibility() == ApiDef::SKIP) { + return false; + } + for (const auto& attr : op_def.attr()) { + if (attr.type() == "func") { + return false; // TODO(karllessard) add support for function attributes + } + } + return true; +} + +} // namespace + +Status OpGenerator::Run(const OpList& op_list, const string& base_package, + const string& output_dir) { + ApiDefMap api_map(op_list); + if (!api_dirs_.empty()) { + // Only load api files that correspond to the requested "op_list" + for (const auto& op : op_list.op()) { + for (const auto& api_def_dir : api_dirs_) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env_->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)) + << api_def_file_pattern; + } + } + } + } + api_map.UpdateDocs(); + for (const auto& op_def : op_list.op()) { + const ApiDef* api_def = api_map.GetApiDef(op_def.name()); + if (CanGenerateOp(op_def, *api_def)) { + OpSpec op(OpSpec::Create(op_def, *api_def)); + for (const EndpointSpec& endpoint : op.endpoints()) { + GenerateOp(op, endpoint, base_package, output_dir, env_); + } + } + } return Status::OK(); } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 4b55ed3ed9..05decd6b54 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -17,34 +17,39 @@ limitations under the License. #define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ #include +#include -#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { -/// \brief A generator of Java operation wrappers. -/// -/// Such generator is normally ran only once per executable, outputting -/// wrappers for the all registered operations it has been compiled with. -/// Nonetheless, it is designed to support multiple runs, giving a different -/// list of operations on each cycle. +// A generator of Java operation wrappers. +// +// This generator takes a list of ops definitions in input and outputs +// a Java Op wrapper for each of them in the provided directory. The same +// generator instance can be invoked multiple times with a different list of +// ops definitions. class OpGenerator { public: - OpGenerator(); - virtual ~OpGenerator(); + explicit OpGenerator(const std::vector& api_dirs, + Env* env = Env::Default()) + : api_dirs_(api_dirs), env_(env) {} - /// \brief Generates wrappers for the given list of 'ops'. - /// - /// Output files are generated in //, - /// where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir); + // Generates wrappers for the given list of 'ops'. + // + // Output files are generated in //, + // where 'op_package' is derived from ops endpoints. + Status Run(const OpList& op_list, const string& base_package, + const string& output_dir); private: - Env* env; + const std::vector api_dirs_; + Env* env_; }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc new file mode 100644 index 0000000000..181fd4c5e3 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -0,0 +1,423 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "re2/re2.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" + +namespace tensorflow { +namespace java { +namespace { + +inline bool IsRealNumbers(const AttrValue& values) { + if (!values.has_list()) { + return RealNumberTypes().Contains(values.type()); + } + for (int i = 0; i < values.list().type_size(); ++i) { + if (!RealNumberTypes().Contains(values.list().type(i))) { + return false; + } + } + return true; +} + +class TypeResolver { + public: + explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {} + + // Returns the class type of an input/output argument + // + // For example, if the argument's datatype is DT_STRING, this method will + // return "java.lang.String", so the argument can become "Operand" + // in the Ops API + Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out); + + // Returns types of an input attribute + // + // The first element of the pair is the class type of this attribute while + // the second is its JNI/primitive type equivalent, required for explicit + // unboxing. + // + // For example, if the attribute is of type "float", this method will return + // , so the attribute can be used as a "Float" object + // in the Ops API and casted to a "float" when passing through the JNI layer. + std::pair TypesOf(const OpDef_AttrDef& attr_def, + bool* iterable_out); + + // Returns true if the type of this attribute has already been resolved + bool IsAttributeVisited(const string& attr_name) { + return visited_attrs_.find(attr_name) != visited_attrs_.cend(); + } + + private: + const OpDef op_def_; + std::map visited_attrs_; + char next_generic_letter_ = 'T'; + + std::pair MakeTypePair(const Type& type, const Type& jni_type) { + return std::make_pair(type, jni_type); + } + std::pair MakeTypePair(const Type& type) { + return std::make_pair(type, type); + } + Type NextGeneric() { + char generic_letter = next_generic_letter_++; + if (next_generic_letter_ > 'Z') { + next_generic_letter_ = 'A'; + } + return Type::Generic(string(1, generic_letter)); + } +}; + +Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) { + *iterable_out = false; + if (!arg_def.number_attr().empty()) { + // when number_attr is set, argument has to be a list of tensors + *iterable_out = true; + visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int())); + } + Type type = Type::Wildcard(); + if (arg_def.type() != DataType::DT_INVALID) { + // resolve type from DataType + switch (arg_def.type()) { + case DataType::DT_BOOL: + type = Type::Class("Boolean"); + break; + case DataType::DT_STRING: + type = Type::Class("String"); + break; + case DataType::DT_FLOAT: + type = Type::Class("Float"); + break; + case DataType::DT_DOUBLE: + type = Type::Class("Double"); + break; + case DataType::DT_UINT8: + type = Type::Class("UInt8", "org.tensorflow.types"); + break; + case DataType::DT_INT32: + type = Type::Class("Integer"); + break; + case DataType::DT_INT64: + type = Type::Class("Long"); + break; + case DataType::DT_RESOURCE: + // TODO(karllessard) create a Resource utility class that could be + // used to store a resource and its type (passed in a second argument). + // For now, we need to force a wildcard and we will unfortunately lose + // track of the resource type. + break; + default: + // Any other datatypes does not have a equivalent in Java and must + // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...) + break; + } + } else if (!arg_def.type_attr().empty()) { + // resolve type from attribute (if already visited, retrieve its type) + if (IsAttributeVisited(arg_def.type_attr())) { + type = visited_attrs_.at(arg_def.type_attr()); + } else { + for (const auto& attr_def : op_def_.attr()) { + if (attr_def.name() == arg_def.type_attr()) { + type = TypesOf(attr_def, iterable_out).first; + break; + } + } + } + } else if (!arg_def.type_list_attr().empty()) { + // type is a list of tensors that can be of different data types, so leave + // it as a list of wildcards + *iterable_out = true; + visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type)); + + } else { + LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name() + << "\" in operation \"" << op_def_.name() << "\""; + } + return type; +} + +std::pair TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, + bool* iterable_out) { + std::pair types = MakeTypePair(Type::Wildcard()); + *iterable_out = false; + StringPiece attr_type = attr_def.type(); + if (str_util::ConsumePrefix(&attr_type, "list(")) { + attr_type.remove_suffix(1); // remove closing brace + *iterable_out = true; + } + if (attr_type == "string") { + types = MakeTypePair(Type::Class("String")); + + } else if (attr_type == "int") { + types = MakeTypePair(Type::Class("Long"), Type::Long()); + + } else if (attr_type == "float") { + types = MakeTypePair(Type::Class("Float"), Type::Float()); + + } else if (attr_type == "bool") { + types = MakeTypePair(Type::Class("Boolean"), Type::Boolean()); + + } else if (attr_type == "shape") { + types = MakeTypePair(Type::Class("Shape", "org.tensorflow")); + + } else if (attr_type == "tensor") { + types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") + .add_parameter(Type::Wildcard())); + + } else if (attr_type == "type") { + Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); + if (IsRealNumbers(attr_def.allowed_values())) { + type.add_supertype(Type::Class("Number")); + } + types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow")); + + } else { + LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type + << "\" in operation \"" << op_def_.name() << "\""; + } + visited_attrs_.insert(std::make_pair(attr_def.name(), types.first)); + return types; +} + +string SnakeToCamelCase(const string& str, bool upper = false) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == '_') { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +bool FindAndCut(string* input, const RE2& expr, string* before_match, + string* ret_match = nullptr) { + string match; + if (!RE2::PartialMatch(*input, expr, &match)) return false; + *before_match = input->substr(0, input->find(match)); + *input = input->substr(before_match->size() + match.size()); + if (ret_match != nullptr) *ret_match = match; + return true; +} + +string ParseDocumentation(const string& inp) { + std::stringstream javadoc_text; + + // TODO(karllessard) This is a very minimalist utility method for converting + // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check + // for alternatives to increase the level of support for markups. + std::vector markups_subexpr; + markups_subexpr.push_back("\n+\\*\\s+"); // lists + markups_subexpr.push_back("\n{2,}"); // paragraphs + markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks + markups_subexpr.push_back("`+"); // inlined code and code blocks + markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis + markups_subexpr.push_back("\\["); // hyperlinks + const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")"); + + bool in_list = false; + string input = inp; + while (true) { + string text, markup; + if (!FindAndCut(&input, markup_expr, &text, &markup)) { + javadoc_text << input; + break; // end of loop + } + javadoc_text << text; + if (str_util::StartsWith(markup, "\n")) { + javadoc_text << "\n"; + if (str_util::StrContains(markup, "*")) { + // new list item + javadoc_text << (in_list ? "\n" : "
    \n") << "
  • \n"; + in_list = true; + } else if (in_list) { + // end of list + javadoc_text << "
  • \n
\n"; + in_list = false; + } else if (!str_util::StartsWith(input, "```")) { + // new paragraph (not required if a
 block follows)
+        javadoc_text << "

\n"; + } + } else if (str_util::StartsWith(markup, "```")) { + // code blocks + if (FindAndCut(&input, "(```\\s*\n*)", &text)) { + javadoc_text << "

{@code\n" << text << "}
\n"; + } else { + javadoc_text << markup; + } + } else if (str_util::StartsWith("(" + markup + ")", "`")) { + // inlined code + if (FindAndCut(&input, markup, &text)) { + javadoc_text << "{@code " << text << "}"; + } else { + javadoc_text << markup; + } + } else if (markup == "**") { + // text emphasis (strong) + if (FindAndCut(&input, "(\\b\\*{2})", &text)) { + javadoc_text << "" << ParseDocumentation(text) << ""; + } else { + javadoc_text << markup; + } + } else if (markup == "*") { + // text emphasis (normal) + if (FindAndCut(&input, "(\\b\\*{1})", &text)) { + javadoc_text << "" << ParseDocumentation(text) << ""; + } else { + javadoc_text << markup; + } + } else if (str_util::StartsWith(markup, "[")) { + // hyperlinks + string label; + string link; + if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label, + &link) && + str_util::StartsWith(input, label + link)) { + input = input.substr(label.size() + link.size()); + javadoc_text << "" + << ParseDocumentation(label) << ""; + } else { + javadoc_text << markup; + } + } else { + // safe fallback + javadoc_text << markup; + } + } + return javadoc_text.str(); +} + +ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, + const ApiDef::Arg& input_api_def, + TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(input_def, &iterable); + Type var_type = + Type::Interface("Operand", "org.tensorflow").add_parameter(type); + if (iterable) { + var_type = Type::IterableOf(var_type); + } + return ArgumentSpec( + input_api_def.name(), + Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type), + type, ParseDocumentation(input_api_def.description()), iterable); +} + +AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, + const ApiDef::Attr& attr_api_def, + TypeResolver* type_resolver) { + bool iterable = false; + std::pair types = type_resolver->TypesOf(attr_def, &iterable); + Type var_type = types.first.kind() == Type::GENERIC + ? Type::Class("Class").add_parameter(types.first) + : types.first; + if (iterable) { + var_type = Type::ListOf(var_type); + } + return AttributeSpec( + attr_api_def.name(), + Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), + types.first, types.second, ParseDocumentation(attr_api_def.description()), + iterable, attr_api_def.has_default_value()); +} + +ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, + const ApiDef::Arg& output_api, + TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(output_def, &iterable); + Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type); + if (iterable) { + var_type = Type::ListOf(var_type); + } + return ArgumentSpec( + output_api.name(), + Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type), + type, ParseDocumentation(output_api.description()), iterable); +} + +EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, + const ApiDef_Endpoint& endpoint_def) { + std::vector name_tokens = str_util::Split(endpoint_def.name(), "."); + string package; + string name; + if (name_tokens.size() > 1) { + package = name_tokens.at(0); + name = name_tokens.at(1); + } else { + package = "core"; // generate unclassified ops in the 'core' package + name = name_tokens.at(0); + } + return EndpointSpec(package, name, + Javadoc::Create(ParseDocumentation(api_def.summary())) + .details(ParseDocumentation(api_def.description()))); +} + +} // namespace + +OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { + OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN, + op_def.deprecation().explanation()); + TypeResolver type_resolver(op_def); + for (const string& next_input_name : api_def.arg_order()) { + for (int i = 0; i < op_def.input_arg().size(); ++i) { + if (op_def.input_arg(i).name() == next_input_name) { + op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i), + &type_resolver)); + break; + } + } + } + for (int i = 0; i < op_def.attr().size(); ++i) { + // do not parse attributes already visited, they have probably been inferred + // before as an input argument type + if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) { + AttributeSpec attr = + CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver); + // attributes with a default value are optional + if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) { + op.optional_attributes_.push_back(attr); + } else { + op.attributes_.push_back(attr); + } + } + } + for (int i = 0; i < op_def.output_arg().size(); ++i) { + op.outputs_.push_back( + CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver)); + } + for (const auto& endpoint_def : api_def.endpoint()) { + op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def)); + } + return op; +} + +} // namespace java +} // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h new file mode 100644 index 0000000000..ca0ba16745 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -0,0 +1,173 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ + +#include +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" + +namespace tensorflow { +namespace java { + +class EndpointSpec { + public: + // A specification for an operation endpoint + // + // package: package of this endpoint (from which also derives its package) + // name: name of this endpoint class + // javadoc: the endpoint class documentation + // TODO(annarev): hardcode depcreated to false until deprecated is possible + EndpointSpec(const string& package, const string& name, + const Javadoc& javadoc) + : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {} + + const string& package() const { return package_; } + const string& name() const { return name_; } + const Javadoc& javadoc() const { return javadoc_; } + bool deprecated() const { return deprecated_; } + + private: + const string package_; + const string name_; + const Javadoc javadoc_; + const bool deprecated_; +}; + +class ArgumentSpec { + public: + // A specification for an operation argument + // + // op_def_name: argument name, as known by TensorFlow core + // var: a variable to represent this argument in Java + // type: the tensor type of this argument + // description: a description of this argument, in javadoc + // iterable: true if this argument is a list + ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type, + const string& description, bool iterable) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable) {} + + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } + + private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; +}; + +class AttributeSpec { + public: + // A specification for an operation attribute + // + // op_def_name: attribute name, as known by TensorFlow core + // var: a variable to represent this attribute in Java + // type: the type of this attribute + // jni_type: the type of this attribute in JNI layer (see OperationBuilder) + // description: a description of this attribute, in javadoc + // iterable: true if this attribute is a list + // has_default_value: true if this attribute has a default value if not set + AttributeSpec(const string& op_def_name, const Variable& var, + const Type& type, const Type& jni_type, + const string& description, bool iterable, + bool has_default_value) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable), + jni_type_(jni_type), + has_default_value_(has_default_value) {} + + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } + const Type& jni_type() const { return jni_type_; } + bool has_default_value() const { return has_default_value_; } + + private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; + const Type jni_type_; + const bool has_default_value_; +}; + +class OpSpec { + public: + // Parses an op definition and its API to produce a specification used for + // rendering its Java wrapper + // + // op_def: Op definition + // api_def: Op API definition + static OpSpec Create(const OpDef& op_def, const ApiDef& api_def); + + const string& graph_op_name() const { return graph_op_name_; } + bool hidden() const { return hidden_; } + const string& deprecation_explanation() const { + return deprecation_explanation_; + } + const std::vector endpoints() const { return endpoints_; } + const std::vector& inputs() const { return inputs_; } + const std::vector& outputs() const { return outputs_; } + const std::vector& attributes() const { return attributes_; } + const std::vector& optional_attributes() const { + return optional_attributes_; + } + + private: + // A specification for an operation + // + // graph_op_name: name of this op, as known by TensorFlow core engine + // hidden: true if this op should not be visible through the Graph Ops API + // deprecation_explanation: message to show if all endpoints are deprecated + explicit OpSpec(const string& graph_op_name, bool hidden, + const string& deprecation_explanation) + : graph_op_name_(graph_op_name), + hidden_(hidden), + deprecation_explanation_(deprecation_explanation) {} + + const string graph_op_name_; + const bool hidden_; + const string deprecation_explanation_; + std::vector endpoints_; + std::vector inputs_; + std::vector outputs_; + std::vector attributes_; + std::vector optional_attributes_; +}; + +} // namespace java +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index a02f75ad6e..66401bdba7 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include -#include +#include +#include #include "tensorflow/java/src/gen/cc/source_writer.h" @@ -83,20 +83,22 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { } SourceWriter& SourceWriter::AppendType(const Type& type) { - if (type.kind() == Type::Kind::GENERIC && type.name().empty()) { + if (type.wildcard()) { Append("?"); } else { Append(type.name()); - } - if (!type.parameters().empty()) { - Append("<"); - for (const Type& t : type.parameters()) { - if (&t != &type.parameters().front()) { - Append(", "); + if (!type.parameters().empty()) { + Append("<"); + bool first = true; + for (const Type& t : type.parameters()) { + if (!first) { + Append(", "); + } + AppendType(t); + first = false; } - AppendType(t); + Append(">"); } - Append(">"); } return *this; } @@ -107,7 +109,21 @@ SourceWriter& SourceWriter::EndLine() { return *this; } -SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) { +SourceWriter& SourceWriter::BeginBlock(const string& expression) { + if (!expression.empty()) { + Append(expression + " {"); + } else { + Append(newline_ ? "{" : " {"); + } + return EndLine().Indent(2); +} + +SourceWriter& SourceWriter::EndBlock() { + return Indent(-2).Append("}").EndLine(); +} + +SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc) { GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); if (!method.constructor()) { generic_namespace->Visit(method.return_type()); @@ -116,8 +132,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) { generic_namespace->Visit(v.type()); } EndLine(); - WriteDoc(method.description(), method.return_description(), - &method.arguments()); + if (javadoc != nullptr) { + WriteJavadoc(*javadoc); + } if (!method.annotations().empty()) { WriteAnnotations(method.annotations()); } @@ -130,11 +147,13 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) { AppendType(method.return_type()).Append(" "); } Append(method.name()).Append("("); + bool first = true; for (const Variable& v : method.arguments()) { - if (&v != &method.arguments().front()) { + if (!first) { Append(", "); } AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name()); + first = false; } return Append(")").BeginBlock(); } @@ -145,29 +164,36 @@ SourceWriter& SourceWriter::EndMethod() { return *this; } -SourceWriter& SourceWriter::BeginType(const Type& type, - const std::list* dependencies, int modifiers) { +SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers, + const std::list* extra_dependencies, + const Javadoc* javadoc) { if (!type.package().empty()) { Append("package ").Append(type.package()).Append(";").EndLine(); } - if (dependencies != nullptr && !dependencies->empty()) { - TypeImporter type_importer(type.package()); - for (const Type& t : *dependencies) { + TypeImporter type_importer(type.package()); + type_importer.Visit(type); + if (extra_dependencies != nullptr) { + for (const Type& t : *extra_dependencies) { type_importer.Visit(t); } + } + if (!type_importer.imports().empty()) { EndLine(); for (const string& s : type_importer.imports()) { Append("import ").Append(s).Append(";").EndLine(); } } - return BeginInnerType(type, modifiers); + return BeginInnerType(type, modifiers, javadoc); } -SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) { +SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc) { GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); generic_namespace->Visit(type); EndLine(); - WriteDoc(type.description()); + if (javadoc != nullptr) { + WriteJavadoc(*javadoc); + } if (!type.annotations().empty()) { WriteAnnotations(type.annotations()); } @@ -200,14 +226,15 @@ SourceWriter& SourceWriter::EndType() { return *this; } -SourceWriter& SourceWriter::WriteFields(const std::list& fields, - int modifiers) { - EndLine(); - for (const Variable& v : fields) { - WriteModifiers(modifiers); - AppendType(v.type()).Append(" ").Append(v.name()).Append(";"); - EndLine(); +SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc) { + // If present, write field javadoc only as one brief line + if (javadoc != nullptr && !javadoc->brief().empty()) { + Append("/** ").Append(javadoc->brief()).Append(" */").EndLine(); } + WriteModifiers(modifiers); + AppendType(field.type()).Append(" ").Append(field.name()).Append(";"); + EndLine(); return *this; } @@ -228,39 +255,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) { return *this; } -SourceWriter& SourceWriter::WriteDoc(const string& description, - const string& return_description, const std::list* parameters) { - if (description.empty() && return_description.empty() - && (parameters == nullptr || parameters->empty())) { - return *this; // no doc to write - } +SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) { + Append("/**").Prefix(" * ").EndLine(); bool do_line_break = false; - Append("/**").EndLine().Prefix(" * "); - if (!description.empty()) { - Write(description).EndLine(); + if (!javadoc.brief().empty()) { + Write(javadoc.brief()).EndLine(); do_line_break = true; } - if (parameters != nullptr && !parameters->empty()) { + if (!javadoc.details().empty()) { if (do_line_break) { - EndLine(); - do_line_break = false; - } - for (const Variable& v : *parameters) { - Append("@param ").Append(v.name()); - if (!v.description().empty()) { - Append(" ").Write(v.description()); - } - EndLine(); + Append("

").EndLine(); } + Write(javadoc.details()).EndLine(); + do_line_break = true; } - if (!return_description.empty()) { + if (!javadoc.tags().empty()) { if (do_line_break) { EndLine(); - do_line_break = false; } - Append("@return ").Write(return_description).EndLine(); + for (const auto& p : javadoc.tags()) { + Append("@" + p.first); + if (!p.second.empty()) { + Append(" ").Write(p.second); + } + EndLine(); + } } - return Prefix("").Append(" **/").EndLine(); + return Prefix("").Append(" */").EndLine(); } SourceWriter& SourceWriter::WriteAnnotations( @@ -278,14 +299,16 @@ SourceWriter& SourceWriter::WriteAnnotations( SourceWriter& SourceWriter::WriteGenerics( const std::list& generics) { Append("<"); + bool first = true; for (const Type* pt : generics) { - if (pt != generics.front()) { + if (!first) { Append(", "); } Append(pt->name()); if (!pt->supertypes().empty()) { Append(" extends ").AppendType(pt->supertypes().front()); } + first = false; } return Append(">"); } @@ -311,21 +334,20 @@ void SourceWriter::PopGenericNamespace() { void SourceWriter::TypeVisitor::Visit(const Type& type) { DoVisit(type); for (const Type& t : type.parameters()) { - DoVisit(t); + Visit(t); } for (const Annotation& t : type.annotations()) { DoVisit(t); } for (const Type& t : type.supertypes()) { - DoVisit(t); + Visit(t); } } void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC - && !type.IsWildcard() - && generic_names_.find(type.name()) == generic_names_.end()) { + if (type.kind() == Type::GENERIC && !type.wildcard() && + generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); } @@ -333,7 +355,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.package() + '.' + type.name()); + imports_.insert(type.canonical_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index f011acd30a..de0113bd5b 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -93,25 +93,22 @@ class SourceWriter { // This method appends a new opening brace to the current data and indent the // next lines according to Google Java Style Guide. The block can optionally // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) - SourceWriter& BeginBlock() { - return Append(newline_ ? "{" : " {").EndLine().Indent(2); - } + SourceWriter& BeginBlock(const string& expression = ""); // Ends the current block of source code. // // This method appends a new closing brace to the current data and outdent the // next lines back to the margin used before BeginBlock() was invoked. - SourceWriter& EndBlock() { - return Indent(-2).Append("}").EndLine(); - } + SourceWriter& EndBlock(); // Begins to write a method. // // This method outputs the signature of the Java method from the data passed - // in the 'method' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // method. - SourceWriter& BeginMethod(const Method& method, int modifiers = 0); + // in the 'method' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this method and, optionally, + // a Javadoc. + SourceWriter& BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current method. // @@ -122,22 +119,24 @@ class SourceWriter { // Begins to write the main type of a source file. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this type and, optionally, + // a Javadoc. // - // If not null, all types found in the 'dependencies' list will be imported - // before declaring the new type. - SourceWriter& BeginType(const Type& clazz, - const std::list* dependencies, int modifiers = 0); + // If not null, all types found in the 'extra_dependencies' list will be + // imported before declaring the new type. + SourceWriter& BeginType(const Type& type, int modifiers, + const std::list* extra_dependencies = nullptr, + const Javadoc* javadoc = nullptr); // Begins to write a new inner type. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. - SourceWriter& BeginInnerType(const Type& type, int modifiers = 0); + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the accesses and the scope of this type and, + // optionally, a Javadoc. + SourceWriter& BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current type. // @@ -145,13 +144,13 @@ class SourceWriter { // BeginType() or BeginInnerType() prior to this. SourceWriter& EndType(); - // Writes a list of variables as fields of a type. + // Writes a variable as fields of a type. // // This method must be called within the definition of a type (see BeginType() - // or BeginInnerType()). Additional modifiers can also be passed in parameter - // to define the accesses and the scope of those fields. - SourceWriter& WriteFields(const std::list& fields, - int modifiers = 0); + // or BeginInnerType()). Modifiers are also be passed in parameter to define + // the accesses and the scope of this field and, optionally, a Javadoc. + SourceWriter& WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc = nullptr); protected: virtual void DoAppend(const StringPiece& str) = 0; @@ -207,9 +206,7 @@ class SourceWriter { std::stack generic_namespaces_; SourceWriter& WriteModifiers(int modifiers); - SourceWriter& WriteDoc(const string& description, - const string& return_description = "", - const std::list* parameters = nullptr); + SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list& annotations); SourceWriter& WriteGenerics(const std::list& generics); GenericNamespace* PushGenericNamespace(int modifiers); diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 4bce2fea70..fb8fc64dff 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -245,12 +245,17 @@ TEST(StreamTest, Types) { SourceBufferWriter writer; Type generic = Type::Generic("T").add_supertype(Type::Class("Number")); - writer.AppendType(Type::Int()).Append(", ") - .AppendType(Type::Class("String")).Append(", ") - .AppendType(generic).Append(", ") - .AppendType(Type::ListOf(generic)).Append(", ") - .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") - .AppendType(Type::ListOf(Type::Generic())); + writer.AppendType(Type::Int()) + .Append(", ") + .AppendType(Type::Class("String")) + .Append(", ") + .AppendType(generic) + .Append(", ") + .AppendType(Type::ListOf(generic)) + .Append(", ") + .AppendType(Type::ListOf(Type::IterableOf(generic))) + .Append(", ") + .AppendType(Type::ListOf(Type::Wildcard())); const char* expected = "int, String, T, List, List>, List"; @@ -282,7 +287,7 @@ TEST(WriteType, SimpleClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -300,7 +305,7 @@ TEST(WriteType, SimpleClassWithDependencies) { deps.push_back(Type::Class("SamePackageType", "org.tensorflow")); deps.push_back(Type::Class("NoPackageType")); - writer.BeginType(clazz, &deps, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, &deps).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -313,20 +318,22 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - clazz.description("This class has a\n

\nmultiline description."); + Javadoc clazz_doc = Javadoc::Create("Javadoc test") + .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType(); const char* expected = "package org.tensorflow;\n\n" "/**\n" - " * This class has a\n" + " * Javadoc test\n" " *

\n" + " * This is a\n" " * multiline description.\n" - " **/\n" + " */\n" "@Bean\n" "@SuppressWarnings(\"rawtypes\")\n" "public class Test {\n}\n"; @@ -339,7 +346,7 @@ TEST(WriteType, ParameterizedClass) { clazz.add_parameter(Type::Generic("T")); clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number"))); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -358,7 +365,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) { clazz.add_supertype(Type::Interface("Runnable")); clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t)); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -372,24 +379,23 @@ TEST(WriteType, ParameterizedClassFields) { Type clazz = Type::Class("Test", "org.tensorflow"); Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); clazz.add_parameter(type_t); - std::list static_fields; - static_fields.push_back(Variable::Create("field1", Type::Class("String"))); - std::list member_fields; - member_fields.push_back(Variable::Create("field2", Type::Class("String"))); - member_fields.push_back(Variable::Create("field3", type_t)); + Variable field1 = Variable::Create("field1", Type::Class("String")); + Variable field2 = Variable::Create("field2", Type::Class("String")); + Variable field3 = Variable::Create("field3", type_t); + Javadoc field3_doc = Javadoc::Create("This variable is documented"); - writer.BeginType(clazz, nullptr, PUBLIC) - .WriteFields(static_fields, STATIC | PUBLIC | FINAL) - .WriteFields(member_fields, PRIVATE) - .EndType(); + writer.BeginType(clazz, PUBLIC) + .WriteField(field1, STATIC | PUBLIC | FINAL) + .WriteField(field2, PRIVATE) + .WriteField(field3, PRIVATE, &field3_doc) + .EndType(); const char* expected = "package org.tensorflow;\n\n" "public class Test {\n" - " \n" " public static final String field1;\n" - " \n" " private String field2;\n" + " /** This variable is documented */\n" " private T field3;\n" "}\n"; ASSERT_STREQ(expected, writer.str().data()); @@ -400,10 +406,10 @@ TEST(WriteType, SimpleInnerClass) { Type clazz = Type::Class("Test", "org.tensorflow"); Type inner_class = Type::Class("InnerTest"); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginInnerType(inner_class, PUBLIC) - .EndType() - .EndType(); + writer.BeginType(clazz, PUBLIC) + .BeginInnerType(inner_class, PUBLIC) + .EndType() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -423,10 +429,10 @@ TEST(WriteType, StaticParameterizedInnerClass) { Type inner_class = Type::Class("InnerTest"); inner_class.add_parameter(type_t); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginInnerType(inner_class, PUBLIC | STATIC) - .EndType() - .EndType(); + writer.BeginType(clazz, PUBLIC) + .BeginInnerType(inner_class, PUBLIC | STATIC) + .EndType() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -443,9 +449,10 @@ TEST(WriteMethod, SimpleMethod) { Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() - .EndType(); + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC) + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -461,24 +468,28 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - method.description("This method has a\n

\nmultiline description."); + Javadoc method_doc = + Javadoc::Create("Javadoc test") + .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() - .EndType(); + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc) + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" "public class Test {\n" " \n" " /**\n" - " * This method has a\n" + " * Javadoc test\n" " *

\n" + " * This method has a\n" " * multiline description.\n" - " **/\n" + " */\n" " @Override\n" " @SuppressWarnings(\"rawtypes\")\n" " public void doNothing() {\n" @@ -490,23 +501,27 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { TEST(WriteMethod, DocumentedMethodWithArguments) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); + Variable reverse = Variable::Create("reverse", Type::Boolean()); Method method = Method::Create("boolToInt", Type::Int()); - method.description("Converts a boolean to an int"); - method.return_description("int value for this boolean"); method.add_argument(Variable::Create("b", Type::Boolean())); - Variable reverse = Variable::Create("reverse", Type::Boolean()); - reverse.description("if true, value is reversed"); method.add_argument(reverse); - - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC) - .Append("if (b && !reverse)") - .BeginBlock() - .Append("return 1;").EndLine() - .EndBlock() - .Append("return 0;").EndLine() - .EndMethod() - .EndType(); + Javadoc method_doc = + Javadoc::Create("Converts a boolean to an int") + .details("This method will convert\na boolean to an int") + .add_param_tag(reverse.name(), "if true, value is reversed") + .add_tag("return", "int value for this boolean"); + + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc) + .Append("if (b && !reverse)") + .BeginBlock() + .Append("return 1;") + .EndLine() + .EndBlock() + .Append("return 0;") + .EndLine() + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -514,11 +529,13 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " \n" " /**\n" " * Converts a boolean to an int\n" + " *

\n" + " * This method will convert\n" + " * a boolean to an int\n" " * \n" - " * @param b\n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" - " **/\n" + " */\n" " public int boolToInt(boolean b, boolean reverse) {\n" " if (b && !reverse) {\n" " return 1;\n" @@ -536,11 +553,12 @@ TEST(WriteMethod, ParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC) - .Append("return null;").EndLine() - .EndMethod() - .EndType(); + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC) + .Append("return null;") + .EndLine() + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -560,11 +578,12 @@ TEST(WriteMethod, StaticParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC | STATIC) - .Append("return null;").EndLine() - .EndMethod() - .EndType(); + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC | STATIC) + .Append("return null;") + .EndLine() + .EndMethod() + .EndType(); const char* expected = "package org.tensorflow;\n\n" diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index a6650fc4ea..f4ff34ea03 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -1,62 +1,62 @@ # -*- Python -*- -load("//tensorflow:tensorflow.bzl", - "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_binary_additional_srcs", +) -# Given a list of "ops_libs" (a list of files in the core/ops directory -# without their .cc extensions), generate Java wrapper code for all operations -# found in the ops files. -# Then, combine all those source files into a single archive (.srcjar). +# Generate Java wrapper classes for all registered core operations and package +# them into a single source archive (.srcjar). # # For example: -# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ]) +# tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package") # -# will create a genrule named "gen_sources" that first generate source files: -# ops/src/main/java/my/package/array/*.java -# ops/src/main/java/my/package/math/*.java +# will create a genrule named "gen_sources" that generates source files under +# ops/src/main/java/my/package/**/*.java # -# and then archive those source files in: +# and then archive those source files into # ops/gen_sources.srcjar # def tf_java_op_gen_srcjar(name, gen_tool, - gen_base_package, - ops_libs=[], - ops_libs_pkg="//tensorflow/core", + base_package, + api_def_srcs=[], out_dir="ops/", out_src_dir="src/main/java/", visibility=["//tensorflow/java:__pkg__"]): - gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + srcs = api_def_srcs[:] - # Construct an op generator binary for each ops library. - for ops_lib in ops_libs: - gen_lib = ops_lib[:ops_lib.rfind("_")] - out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + "$$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) - tf_cc_binary( - name=out_gen_tool, - copts=tf_copts(), - linkopts=["-lm"], - linkstatic=1, # Faster to link this one-time-use binary dynamically - deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) - - gen_tools += [":" + out_gen_tool] - gen_cmds += ["$(location :" + out_gen_tool + ")" + - " --output_dir=$(@D)/" + out_src_dir + - " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package] + gen_cmds += ["$(location " + gen_tool + ")" + + " --output_dir=$(@D)/" + out_src_dir + + " --base_package=" + base_package + + " --api_dirs=" + api_def_args_str] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"] - gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"] - gen_tools += tf_binary_additional_srcs() + native.genrule( name=name, + srcs=srcs, outs=[gen_srcjar], - tools=gen_tools, - cmd="&&".join(gen_cmds)) + tools=[ + "@local_jdk//:jar", + "@local_jdk//:jdk", + gen_tool + ] + tf_binary_additional_srcs(), + cmd=" && ".join(gen_cmds)) -- cgit v1.2.3