diff options
author | Karl Lessard <karl@kubx.ca> | 2018-03-29 05:28:16 +0800 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-05-03 22:40:42 -0400 |
commit | fe9b2637cfe39cf11eb3d0494948a733b7fc1d7d (patch) | |
tree | 280fa51b3d80c9b9f4b0e8e98f55831a384f0851 /tensorflow/java | |
parent | 2dc7575123ffa0e6413fc3d2700968ef25f049de (diff) |
Parse op definition and generate a Java Op class.
Diffstat (limited to 'tensorflow/java')
-rw-r--r-- | tensorflow/java/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/java_defs.h | 76 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_gen_main.cc | 22 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_generator.cc | 406 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_generator.h | 42 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_parser.cc | 417 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_parser.h | 137 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer.cc | 127 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer.h | 55 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer_test.cc | 82 | ||||
-rw-r--r-- | tensorflow/java/src/gen/gen_ops.bzl | 29 | ||||
-rw-r--r-- | tensorflow/java/src/gen/resources/license.snippet.java | 14 |
12 files changed, 1201 insertions, 210 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index ab7d698a45..635a4e807d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -70,6 +70,7 @@ filegroup( tf_java_op_gen_srcjar( name = "java_op_gen_sources", + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -111,11 +112,13 @@ cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", + "src/gen/cc/op_parser.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", + "src/gen/cc/op_parser.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), @@ -124,6 +127,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:op_gen_lib", ], ) diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 59f8beaee7..2065477f58 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -18,12 +18,15 @@ limitations under the License. #include <string> #include <list> +#include <map> +#include <utility> namespace tensorflow { namespace java { // An enumeration of different modifiers commonly used in Java enum Modifier { + PACKAGE = 0, PUBLIC = (1 << 0), PROTECTED = (1 << 1), PRIVATE = (1 << 2), @@ -72,6 +75,12 @@ class Type { // Reflection API does return Type(Type::PRIMITIVE, "void"); } + static Type Generic(const string& name) { + return Type(Type::GENERIC, name); + } + static Type Wildcard() { + return Type(Type::GENERIC, ""); + } static Type Class(const string& name, const string& package = "") { return Type(Type::CLASS, name, package); } @@ -81,9 +90,6 @@ class Type { static Type Enum(const string& name, const string& package = "") { return Type(Type::ENUM, name, package); } - static Type Generic(const string& name = "") { - return Type(Type::GENERIC, name); - } static Type ClassOf(const Type& type) { return Class("Class").add_parameter(type); } @@ -96,11 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string& description() const { return description_; } - Type& description(const string& description) { - description_ = description; - return *this; + const string full_name() const { + return package_.empty() ? name_ : package_ + "." + name_; } + bool unknown() const { return name_.empty(); } // only wildcards has no name const std::list<Type>& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); @@ -120,14 +125,6 @@ class Type { } return *this; } - // Returns true if "type" is of a known collection type (only a few for now) - bool IsCollection() const { - return name_ == "List" || name_ == "Iterable"; - } - // Returns true if this instance is a wildcard (<?>) - bool IsWildcard() const { - return kind_ == GENERIC && name_.empty(); - } protected: Type(Kind kind, const string& name, const string& package = "") @@ -137,7 +134,6 @@ class Type { Kind kind_; string name_; string package_; - string description_; std::list<Type> parameters_; std::list<Annotation> annotations_; std::list<Type> supertypes_; @@ -180,16 +176,11 @@ class Variable { const string& name() const { return name_; } const Type& type() const { return type_; } bool variadic() const { return variadic_; } - const string& description() const { return description_; } - Variable& description(const string& description) { - description_ = description; - return *this; - } + private: string name_; Type type_; bool variadic_; - string description_; Variable(const string& name, const Type& type, bool variadic) : name_(name), type_(type), variadic_(variadic) {} @@ -210,16 +201,6 @@ class Method { bool constructor() const { return constructor_; } const string& name() const { return name_; } const Type& return_type() const { return return_type_; } - const string& description() const { return description_; } - Method& description(const string& description) { - description_ = description; - return *this; - } - const string& return_description() const { return return_description_; } - Method& return_description(const string& description) { - return_description_ = description; - return *this; - } const std::list<Variable>& arguments() const { return arguments_; } Method& add_argument(const Variable& var) { arguments_.push_back(var); @@ -235,8 +216,6 @@ class Method { string name_; Type return_type_; bool constructor_; - string description_; - string return_description_; std::list<Variable> arguments_; std::list<Annotation> annotations_; @@ -244,6 +223,35 @@ class Method { : name_(name), return_type_(return_type), constructor_(constructor) {} }; +// A definition of a documentation bloc for a Java element (JavaDoc) +class Javadoc { + public: + static Javadoc Create(const string& brief = "") { + return Javadoc(brief); + } + const string& brief() const { return brief_; } + const string& details() const { return description_; } + Javadoc& details(const string description) { + description_ = description; + return *this; + } + const std::list<std::pair<string, string>> tags() const { return tags_; } + Javadoc& add_tag(const string& tag, const string& text) { + tags_.push_back(std::make_pair(tag, text)); + return *this; + } + Javadoc& add_param_tag(const string& name, const string& text) { + return add_tag("param", name + " " + text); + } + + private: + string brief_; + string description_; + std::list<std::pair<string, string>> tags_; + + explicit Javadoc(const string& brief) : brief_(brief) {} +}; + } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index bea99f3d7f..015200023f 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -48,8 +48,11 @@ const char kUsageHeader[] = "through\n" "the 'org.tensorflow.op.Ops' API as a group until the generated classes " "are compiled using an appropriate annotation processor.\n\n" - "Finally, the '--base_package' overrides the default parent package " - "under which the generated subpackage and classes are to be located.\n\n"; + "The '--base_package' overrides the default parent package under which " + "the generated subpackage and classes are to be located.\n\n" + "Finally, a list of directories of API proto definitions can be provided " + "to override default values found in the ops definitions, ordered by\n" + "priority (the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow @@ -60,7 +63,7 @@ int main(int argc, char* argv[]) { tensorflow::string base_package = "org.tensorflow.op"; std::vector<tensorflow::Flag> flag_list = { tensorflow::Flag("output_dir", &output_dir, - "Root directory into which output files are generated"), + "Root directory into which output files are generated"), tensorflow::Flag( "lib_name", &lib_name, "A name, in snake_case, used to classify this set of operations"), @@ -72,12 +75,15 @@ int main(int argc, char* argv[]) { bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - - tensorflow::java::OpGenerator generator; + std::vector<tensorflow::string> api_dirs; + if (argc > 1) { + api_dirs = tensorflow::str_util::Split(argv[1], ",", + tensorflow::str_util::SkipEmpty()); + } + tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; - tensorflow::OpRegistry::Global()->Export(true, &ops); - tensorflow::Status status = - generator.Run(ops, lib_name, base_package, output_dir); + tensorflow::OpRegistry::Global()->Export(false, &ops); + tensorflow::Status status = generator.Run(ops, lib_name); TF_QCHECK_OK(status); return 0; diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index def06baf2d..c9b57f5706 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -14,53 +14,409 @@ limitations under the License. ==============================================================================*/ #include <string> +#include <map> +#include <vector> +#include <list> +#include <memory> +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" +#include "tensorflow/java/src/gen/cc/source_writer.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" namespace tensorflow { namespace java { namespace { -string CamelCase(const string& str, char delimiter, bool upper) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == delimiter) { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; +const char* kLicenseSnippet = + "tensorflow/java/src/gen/resources/license.snippet.java"; + +const std::map<string, Type> kPrimitiveAttrTypes = { + { "Boolean", Type::Boolean() }, + { "Byte", Type::Byte() }, + { "Character", Type::Byte() }, + { "Float", Type::Float() }, + { "Integer", Type::Long() }, + { "Long", Type::Long() }, + { "Short", Type::Long() }, + { "Double", Type::Float() }, +}; + +enum RenderMode { + DEFAULT, + SINGLE_OUTPUT, + SINGLE_LIST_OUTPUT +}; + +void CollectOpDependencies(const OpSpec& op, RenderMode mode, + std::list<Type>* out) { + out->push_back(Type::Class("Operation", "org.tensorflow")); + out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); + out->push_back(Type::Class("Scope", "org.tensorflow.op")); + if (mode == SINGLE_OUTPUT) { + out->push_back(Type::Class("Output", "org.tensorflow")); + } else if (mode == SINGLE_LIST_OUTPUT) { + out->push_back(Type::Interface("Iterator", "java.util")); + } + // Don't pay attention to duplicate types in the dependency list, they will + // be filtered out by the SourceWriter. + for (const OpSpec::Operand& input : op.inputs()) { + out->push_back(input.var().type()); + if (input.iterable()) { + out->push_back(Type::Class("Operands", "org.tensorflow.op")); + } + } + for (const OpSpec::Operand& output : op.outputs()) { + out->push_back(output.var().type()); + if (output.iterable()) { + out->push_back(Type::Class("Arrays", "java.util")); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + out->push_back(attribute.var().type()); + if (attribute.var().type().name() == "Class") { + out->push_back(Type::Enum("DataType", "org.tensorflow")); + } + } + for (const OpSpec::Operand& option : op.options()) { + out->push_back(option.var().type()); + } +} + +void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, + SourceWriter* writer) { + string var = optional ? "opts." + attr.var().name() : attr.var().name(); + if (attr.iterable()) { + const Type& type = attr.data_type(); + std::map<string, Type>::const_iterator it = + kPrimitiveAttrTypes.find(type.name()); + if (it != kPrimitiveAttrTypes.end()) { + string array = attr.var().name() + "Array"; + writer->AppendType(it->second) + .Append("[] " + array + " = new ") + .AppendType(it->second) + .Append("[" + var + ".size()];") + .EndLine(); + writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)") + .Append(array + "[i] = " + var + ".get(i);") + .EndLine() + .EndBlock() + .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append(");") + .EndLine(); } else { - result += c; + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + .Append(".toArray(new ") + .AppendType(type) + .Append("[" + var + ".size()]));") + .EndLine(); } + } else { + Type type = attr.var().type(); + writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + if (type.name() == "Class") { + writer->Append("DataType.fromClass(" + attr.var().name() + "));"); + } else { + writer->Append(var + ");"); + } + writer->EndLine(); } - return result; } -} // namespace +void RenderFactoryMethod(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method factory = Method::Create("create", op_class); + Javadoc factory_doc = Javadoc::Create( + "Factory method to create a class to wrap a new " + op_class.name() + + " operation to the graph."); + Variable scope = + Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); + factory.add_argument(scope); + factory_doc.add_param_tag(scope.name(), "Current graph scope"); + for (const OpSpec::Operand& input : op.inputs()) { + factory.add_argument(input.var()); + factory_doc.add_param_tag(input.var().name(), input.description()); + } + for (const OpSpec::Operand& attribute : op.attributes()) { + factory.add_argument(attribute.var()); + factory_doc.add_param_tag(attribute.var().name(), attribute.description()); + } + if (!op.options().empty()) { + factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); + factory_doc.add_param_tag("options", "carries optional attributes values"); + } + factory_doc.add_tag("return", "a new instance of " + op_class.name()); + writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); + writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" + + op.graph_name() + "\", scope.makeOpName(\"" + + op_class.name() + "\"));"); + writer->EndLine(); -OpGenerator::OpGenerator() : env(Env::Default()) {} + for (const OpSpec::Operand& input : op.inputs()) { + if (input.iterable()) { + writer->Append("opBuilder.addInputList(Operands.asOutputs(" + + input.var().name() + "));"); + writer->EndLine(); + } else { + writer->Append("opBuilder.addInput(" + input.var().name() + + ".asOutput());"); + writer->EndLine(); + } + } + for (const OpSpec::Operand& attribute : op.attributes()) { + WriteSetAttrDirective(attribute, false, writer); + } + if (!op.options().empty()) { + writer->BeginBlock("if (options != null)") + .BeginBlock("for (Options opts : options)"); + for (const OpSpec::Operand& option : op.options()) { + writer->BeginBlock("if (opts." + option.var().name() + " != null)"); + WriteSetAttrDirective(option, true, writer); + writer->EndBlock(); + } + writer->EndBlock().EndBlock(); + } + writer->Append("return new ") + .AppendType(op_class) + .Append("(opBuilder.build());") + .EndLine(); + writer->EndMethod(); +} -OpGenerator::~OpGenerator() {} +void RenderConstructor(const OpSpec& op, const Type& op_class, + SourceWriter* writer) { + Method constructor = Method::ConstructorFor(op_class) + .add_argument( + Variable::Create("operation", + Type::Class("Operation", "org.tensorflow"))); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable() && !output.data_type().unknown()) { + constructor.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + break; + } + } + writer->BeginMethod(constructor, PRIVATE) + .Append("super(operation);") + .EndLine(); + if (op.outputs().size() > 0) { + writer->Append("int outputIdx = 0;") + .EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + if (output.iterable()) { + string var_length = output.var().name() + "Length"; + writer->Append("int " + var_length) + .Append(" = operation.outputListLength(\"" + output.graph_name() + + "\");") + .EndLine() + .Append(output.var().name() + " = Arrays.asList("); + if (!output.data_type().unknown()) { + writer->Append("(") + .AppendType(output.var().type().parameters().front()) + .Append("[])"); + } + writer->Append("operation.outputList(outputIdx, " + var_length + "));") + .EndLine() + .Append("outputIdx += " + var_length + ";") + .EndLine(); + } else { + writer->Append(output.var().name() + + " = operation.output(outputIdx++);") + .EndLine(); + } + } + } + writer->EndMethod(); +} -Status OpGenerator::Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir) { - const string package = - base_package + '.' + str_util::StringReplace(lib_name, "_", "", true); - const string package_path = - output_dir + '/' + str_util::StringReplace(package, ".", "/", true); - const string group = CamelCase(lib_name, '_', false); +void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), Type::Class("Options")) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) + .Append("return new Options()." + option.var().name() + "(" + + option.var().name() + ");") + .EndLine() + .EndMethod(); + } + for (const OpSpec::Operand& output : op.outputs()) { + Method getter = Method::Create(output.var().name(), output.var().type()); + Javadoc getter_doc = Javadoc::Create(output.description()); + writer->BeginMethod(getter, PUBLIC, &getter_doc) + .Append("return " + output.var().name() + ";") + .EndLine() + .EndMethod(); + } +} + +void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, + SourceWriter* writer) { + OpSpec::Operand output = op.outputs().front(); + + if (mode == SINGLE_OUTPUT) { + bool cast2obj = output.data_type().unknown(); + Type return_type = Type::Class("Output", "org.tensorflow") + .add_parameter(cast2obj ? Type::Class("Object") : output.data_type()); + Method as_output = Method::Create("asOutput", return_type) + .add_annotation(Annotation::Create("Override")); + if (cast2obj) { + as_output.add_annotation( + Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); + } + writer->BeginMethod(as_output, PUBLIC); + if (cast2obj) { + writer->Append("return (").AppendType(return_type).Append(") "); + } else { + writer->Append("return "); + } + writer->Append(output.var().name() + ";") + .EndLine() + .EndMethod(); + + } else if (mode == SINGLE_LIST_OUTPUT) { + Type operand = Type::Interface("Operand", "org.tensorflow"); + if (output.data_type().unknown()) { + operand.add_parameter(Type::Class("Object")); + } else { + operand.add_parameter(output.data_type()); + } + Type return_type = Type::Interface("Iterator", "java.util") + .add_parameter(operand); + Method iterator = Method::Create("iterator", return_type) + .add_annotation(Annotation::Create("Override")) + .add_annotation(Annotation::Create("SuppressWarnings") + .attributes("{\"rawtypes\", \"unchecked\"}")); + // cast the output list using a raw List + writer->BeginMethod(iterator, PUBLIC) + .Append("return (" + return_type.name() + ") ") + .Append(output.var().name() + ".iterator();") + .EndLine() + .EndMethod(); + } +} + +void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { + Type options_class = Type::Class("Options"); + Javadoc options_doc = Javadoc::Create( + "Class holding optional attributes of this operation"); + writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); + for (const OpSpec::Operand& option : op.options()) { + Method setter = Method::Create(option.var().name(), options_class) + .add_argument(option.var()); + Javadoc setter_doc = Javadoc::Create() + .add_param_tag(option.var().name(), option.description()); + writer->BeginMethod(setter, PUBLIC, &setter_doc) + .Append("this." + option.var().name() + " = " + option.var().name() + + ";") + .EndLine() + .Append("return this;") + .EndLine() + .EndMethod(); + } + writer->EndLine(); + for (const OpSpec::Operand& option : op.options()) { + writer->WriteField(option.var(), PRIVATE); + } + Method constructor = Method::ConstructorFor(options_class); + writer->BeginMethod(constructor, PRIVATE).EndMethod(); + writer->EndType(); +} - if (!env->FileExists(package_path).ok()) { - TF_CHECK_OK(env->RecursivelyCreateDir(package_path)); +void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, + SourceWriter* writer) { + RenderMode mode = DEFAULT; + if (op.outputs().size() == 1) { + mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + } + std::list<Type> dependencies; + CollectOpDependencies(op, mode, &dependencies); + const Type& op_class = endpoint.type(); + writer->WriteFromFile(kLicenseSnippet) + .EndLine() + .Append("// This file is machine generated, DO NOT EDIT!") + .EndLine() + .EndLine() + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); + if (!op.options().empty()) { + RenderOptionsClass(op, writer); } + RenderFactoryMethod(op, op_class, writer); + RenderGettersAndSetters(op, writer); + if (mode != DEFAULT) { + RenderInterfaceImpl(op, mode, writer); + } + writer->EndLine(); + for (const OpSpec::Operand& output : op.outputs()) { + writer->WriteField(output.var(), PRIVATE); + } + RenderConstructor(op, op_class, writer); + writer->EndType(); +} + +} // namespace + +OpGenerator::OpGenerator(const string& base_package, const string& output_dir, + const std::vector<string>& api_dirs, Env* env) + : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs), + env_(env) { +} +Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; - // TODO(karllessard) generate wrappers from list of ops + ApiDefMap api_map(op_list); + if (!api_dirs_.empty()) { + // Only load api files that correspond to the requested "op_list" + for (const auto& op : op_list.op()) { + for (const auto& api_def_dir : api_dirs_) { + const std::string api_def_file_pattern = + io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); + if (env_->FileExists(api_def_file_pattern).ok()) { + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + } + } + } + } + api_map.UpdateDocs(); + for (const auto& op_def : op_list.op()) { + const ApiDef* api_def = api_map.GetApiDef(op_def.name()); + if (api_def->visibility() != ApiDef::SKIP) { + Status status = GenerateOp(op_def, *api_def, lib_name); + if (status != Status::OK()) { + LOG(ERROR) << "Fail to generate Java wrapper for operation \"" + << op_def.name() << "\""; + } + } + } + return Status::OK(); +} + +Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name) { + std::unique_ptr<OpSpec> op; + OpParser op_parser(op_def, api_def, lib_name, base_package_); + op_parser.Parse(&op); + for (const OpSpec::Endpoint& endpoint : op->endpoints()) { + string package_path = io::JoinPath(output_dir_, + str_util::StringReplace(endpoint.type().package(), ".", "/", true)); + if (!env_->FileExists(package_path).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); + } + string file_path = + io::JoinPath(package_path, endpoint.type().name() + ".java"); + std::unique_ptr<tensorflow::WritableFile> file; + TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); + SourceFileWriter writer(file.get()); + RenderEndpoint(*op, endpoint, &writer); + } return Status::OK(); } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 4b55ed3ed9..19d8db95fb 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -17,34 +17,42 @@ limitations under the License. #define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_ #include <string> +#include <vector> -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { namespace java { -/// \brief A generator of Java operation wrappers. -/// -/// Such generator is normally ran only once per executable, outputting -/// wrappers for the all registered operations it has been compiled with. -/// Nonetheless, it is designed to support multiple runs, giving a different -/// list of operations on each cycle. +// A generator of Java operation wrappers. +// +// Such generator is normally ran only once per executable, outputting +// wrappers for the all registered operations it has been compiled with. +// Nonetheless, it is designed to support multiple runs, giving a different +// list of operations on each cycle. class OpGenerator { public: - OpGenerator(); - virtual ~OpGenerator(); + OpGenerator(const string& base_package, const string& output_dir, + const std::vector<string>& api_dirs, Env* env = Env::Default()); + virtual ~OpGenerator() = default; - /// \brief Generates wrappers for the given list of 'ops'. - /// - /// Output files are generated in <output_dir>/<base_package>/<lib_package>, - /// where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& ops, const string& lib_name, - const string& base_package, const string& output_dir); + // Generates wrappers for the given list of 'ops'. + // + // Output files are generated in <output_dir>/<base_package>/<lib_package>, + // where 'lib_package' is derived from 'lib_name'. + Status Run(const OpList& op_list, const string& lib_name); private: - Env* env; + string base_package_; + string output_dir_; + std::vector<string> api_dirs_; + Env* env_; + + Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc new file mode 100644 index 0000000000..0541e343d8 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_parser.cc @@ -0,0 +1,417 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <map> +#include <vector> +#include <utility> +#include <string> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/java/src/gen/cc/op_parser.h" + +namespace tensorflow { +namespace java { +namespace { + +string SnakeToCamelCase(const string& str, bool upper = false) { + string result; + bool cap = upper; + for (string::const_iterator it = str.begin(); it != str.end(); ++it) { + const char c = *it; + if (c == '_') { + cap = true; + } else if (cap) { + result += toupper(c); + cap = false; + } else { + result += c; + } + } + return result; +} + +bool IsRealNumber(DataType type) { + for (DataType dt : RealNumberTypes()) { + if (type == dt) { + return true; + } + } + return false; +} + +bool IsRealNumbers(const AttrValue& values) { + if (values.has_list()) { + for (int i = 0; i < values.list().type_size(); ++i) { + if (!IsRealNumber(values.list().type(i))) { + return false; + } + } + return true; + } + return IsRealNumber(values.type()); +} + +string ParseDocumentation(const string& text) { + std::stringstream javadoc_text; + string::const_iterator c_iter = text.cbegin(); + bool code = false; + bool emphasis = false; + bool list = false; + while (c_iter != text.cend()) { + char c = *c_iter++; + int count = 1; + switch (c) { + case '\n': + if (!code) { + // consumes all subsequent newlines, if there are more than one, + // then there are two choices: + // - if the next line starts with an asterisk, we are enumerating + // a list of items + // - otherwise, we are starting a new paragraph + for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} + if (c_iter != text.cend()) { + if (count > 1) { + if (*c_iter != '*' && list) { + javadoc_text << "</li>\n</ul>\n"; + list = false; + } else if (*c_iter == '*' && !list) { + javadoc_text << "\n<ul>\n<li>"; + list = true; + c_iter++; + } else { + javadoc_text << "\n<p>\n"; + } + } else if (list && *c_iter == '*') { + javadoc_text << "</li>\n<li>"; + c_iter++; + } else { + javadoc_text << '\n'; + } + } + } + break; + case '`': + // consumes all subsequent backquotes, those are use enclose code. + // if there are more than 3, we are dealing with a pre-formatted block, + // otherwise it is a single-line code snippet + for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {} + if (count >= 3) { + javadoc_text << (code ? "\n}</pre>" : "<pre>{@code\n"); + } else { + javadoc_text << (code ? "}" : "{@code "); + } + code = !code; + break; + case '*': + if (!code) { + // consumes all subsequent asterisks, if there are more than one, then + // we put the text in bold, otherwise in italic + for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {} + if (count > 1) { + javadoc_text << (emphasis ? "</b>" : "<b>"); + } else { + javadoc_text << (emphasis ? "</i>" : "<i>"); + } + emphasis = !emphasis; + } else { + javadoc_text << '*'; + } + break; + default: + javadoc_text << c; + break; + } + } + return javadoc_text.str(); +} + +} // namespace + +OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def, + const string& lib_name, const string& base_package) + : op_def_(op_def), op_api_(api_def), lib_name_(lib_name), + base_package_(base_package) { +} + +void OpParser::Parse(std::unique_ptr<OpSpec>* op_ptr) { + visited_attrs_.clear(); + next_generic_ = 'T'; + op_ptr->reset(new OpSpec(op_api_.graph_op_name())); + for (const string& next_input_name : op_api_.arg_order()) { + for (int i = 0; i < op_def_.input_arg().size(); ++i) { + if (op_def_.input_arg(i).name() == next_input_name) { + ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get()); + break; + } + } + } + for (int i = 0; i < op_def_.attr().size(); ++i) { + ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get()); + } + for (int i = 0; i < op_def_.output_arg().size(); ++i) { + ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get()); + } + BuildEndpoints(op_ptr->get()); +} + +void OpParser::BuildEndpoints(OpSpec* op) { + Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary())) + .details(ParseDocumentation(op_api_.description())); + std::vector<Type> op_supertypes; + op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op")); + std::map<string, const Type*> op_generics; + for (const OpSpec::Operand& output : op->outputs()) { + // declare generic output parameters at the Op class level + const Type& data_type = output.data_type(); + if (data_type.kind() == Type::GENERIC && !data_type.unknown() + && op_generics.find(data_type.name()) == op_generics.end()) { + op_generics.insert(std::make_pair(data_type.name(), &data_type)); + op_doc.add_param_tag("<" + data_type.name() + ">", + "data type of output '" + output.var().name() + "'"); + } + // implement the Op as an (iteration of) Operand if it has only one output + if (op->outputs().size() == 1) { + Type operand_inf(Type::Interface("Operand", "org.tensorflow")); + operand_inf.add_parameter(data_type.unknown() ? + Type::Class("Object") : data_type); + op_supertypes.push_back(output.iterable() ? + Type::IterableOf(operand_inf) : operand_inf); + } + } + for (const auto& endpoint_def : op_api_.endpoint()) { + std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), "."); + // if the endpoint specifies a package, use it, otherwise derive it from the + // op library name. + string name; + string package; + if (name_tokens.size() > 1) { + package = str_util::Lowercase(name_tokens.at(0)); + name = name_tokens.at(1); + } else { + package = str_util::StringReplace(lib_name_, "_", "", true); + name = name_tokens.at(0); + } + Type endpoint(Type::Class(name, base_package_ + "." + package)); + Javadoc endpoint_doc(op_doc); + for (const auto& parameter : op_generics) { + endpoint.add_parameter(*parameter.second); + } + for (const Type& supertype : op_supertypes) { + endpoint.add_supertype(supertype); + } + if (endpoint_def.deprecation_version() > 0) { + string explanation; + if (op_api_.endpoint(0).deprecation_version() == 0) { + explanation = ", use {@link " + + op->endpoints().at(0).type().full_name() + + "} instead"; + } else { + explanation = op_def_.deprecation().explanation(); + } + endpoint_doc.add_tag("deprecated", explanation); + endpoint.add_annotation(Annotation::Create("Deprecated")); + } + // only visible ops should be annotated for exposure in the Ops Graph API + if (op_api_.visibility() != ApiDef::HIDDEN) { + string group_name = SnakeToCamelCase(lib_name_); + endpoint.add_annotation( + Annotation::Create("Operator", "org.tensorflow.op.annotation") + .attributes("group = \"" + group_name + "\"")); + } + op->add_endpoint(endpoint, endpoint_doc); + } +} + +void OpParser::ParseInput(const OpDef_ArgDef& input_def, + const ApiDef::Arg& input_api, OpSpec* op) { + bool iterable = false; + Type data_type = DataTypeOf(input_def, &iterable); + Type type = Type::Interface("Operand", "org.tensorflow") + .add_parameter(data_type); + if (iterable) { + type = Type::IterableOf(type); + } + op->add_input(OpSpec::Operand(input_api.name(), + Variable::Create(SnakeToCamelCase(input_api.rename_to()), type), + data_type, + ParseDocumentation(input_api.description()), + iterable)); +} + +void OpParser::ParseOutput(const OpDef_ArgDef& output_def, + const ApiDef::Arg& output_api, OpSpec* op) { + bool iterable = false; + Type data_type = DataTypeOf(output_def, &iterable); + Type type = Type::Class("Output", "org.tensorflow") + .add_parameter(data_type); + if (iterable) { + type = Type::ListOf(type); + } + op->add_output(OpSpec::Operand(output_api.name(), + Variable::Create(SnakeToCamelCase(output_api.rename_to()), type), + data_type, + ParseDocumentation(output_api.description()), + iterable)); +} + +void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def, + const ApiDef::Attr& attr_api, OpSpec* op) { + // do not parse attributes already visited, they have probably been inferred + // before as an input argument type + if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) { + return; + } + bool iterable = false; + Type data_type = DataTypeOf(attr_def, &iterable); + // generic attributes should be passed as an explicit type + bool explicit_type = data_type.kind() == Type::GENERIC && !iterable; + Type type = explicit_type ? + Type::Class("Class").add_parameter(data_type) : data_type; + if (iterable) { + type = Type::ListOf(data_type); + } + OpSpec::Operand attr(attr_api.name(), + Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type), + data_type, + ParseDocumentation(attr_api.description()), + iterable); + // attributes with a default value are optional + if (attr_api.has_default_value() && !explicit_type) { + op->add_option(attr); + } else { + op->add_attribute(attr); + } + visited_attrs_.insert(std::make_pair(attr_api.name(), data_type)); +} + +Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) { + if (!arg.number_attr().empty()) { + visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int())); + *iterable_out = true; + } + if (arg.type() != DataType::DT_INVALID) { + // resolve type from DataType + switch (arg.type()) { + case DataType::DT_BOOL: + return Type::Class("Boolean"); + + case DataType::DT_STRING: + return Type::Class("String"); + + case DataType::DT_FLOAT: + return Type::Class("Float"); + + case DataType::DT_DOUBLE: + return Type::Class("Double"); + + case DataType::DT_UINT8: + return Type::Class("UInt8", "org.tensorflow.types"); + + case DataType::DT_INT32: + return Type::Class("Integer"); + + case DataType::DT_INT64: + return Type::Class("Long"); + + case DataType::DT_RESOURCE: + // TODO(karllessard) create a Resource utility class that could be + // used to store a resource and its type (passed in a second argument). + // For now, we need to force a wildcard and we will unfortunately lose + // track of the resource type. + return Type::Wildcard(); + + default: + break; + } + } else { + // resolve type from type attribute + string attr_name = arg.type_attr(); + if (attr_name.empty()) { + attr_name = arg.type_list_attr(); + if (!attr_name.empty()) { + *iterable_out = true; + Type type = Type::Wildcard(); + visited_attrs_.insert(std::make_pair(attr_name, type)); + return type; + } + } + for (const auto& attr : op_def_.attr()) { + if (attr.name() == attr_name) { + Type type = DataTypeOf(attr, iterable_out); + visited_attrs_.insert(std::make_pair(attr_name, type)); + return type; + } + } + } + LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown"; + return Type::Wildcard(); +} + +Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) { + std::map<string, Type>::const_iterator it = visited_attrs_.find(attr.name()); + if (it != visited_attrs_.cend()) { + return it->second; + } + string attr_type = attr.type(); + if (attr.type().compare(0, 5, "list(") == 0) { + attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5); + *iterable_out = true; + } + if (attr_type == "type") { + if (*iterable_out) { + return Type::Enum("DataType", "org.tensorflow"); + } + return GetNextGenericTensorType(attr.allowed_values()); + } + if (attr_type == "string") { + return Type::Class("String"); + } + if (attr_type == "int") { + return Type::Class("Integer"); + } + if (attr_type == "float") { + return Type::Class("Float"); + } + if (attr_type == "bool") { + return Type::Class("Boolean"); + } + if (attr_type == "shape") { + return Type::Class("Shape", "org.tensorflow"); + } + if (attr_type == "tensor") { + return Type::Class("Tensor", "org.tensorflow") + .add_parameter(Type::Wildcard()); + } + LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown"; + return *iterable_out ? Type::Wildcard() : Type::Class("Object"); +} + +Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values) { + Type generic = Type::Generic(string(1, next_generic_)); + next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1; + + // when only real numbers are allowed, enforce that restriction in the Java by + // extending the generic from java.lang.Number + if (IsRealNumbers(allowed_values)) { + generic.add_supertype(Type::Class("Number")); + } + return generic; +} + +} // namespace java +} // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h new file mode 100644 index 0000000000..42855127cc --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_parser.h @@ -0,0 +1,137 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_ + +#include <map> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" + +namespace tensorflow { +namespace java { + +// Specification of a TensorFlow operation to generate. +// +// This is the result of an operation definition parsing, see OpParser::Parse(). +class OpSpec { + public: + class Endpoint { + public: + Endpoint(const Type& type, const Javadoc& javadoc) + : type_(type), javadoc_(javadoc) {} + const Type& type() const { return type_; } + const Javadoc& javadoc() const { return javadoc_; } + + private: + Type type_; + Javadoc javadoc_; + }; + + class Operand { + public: + Operand(const string& graph_name, const Variable& var, + const Type& data_type, const string& description, bool iterable) + : graph_name_(graph_name), var_(var), data_type_(data_type), + description_(description), iterable_(iterable) {} + const string& graph_name() const { return graph_name_; } + const Variable& var() const { return var_; } + Variable* var_ptr() { return &var_; } + const Type& data_type() const { return data_type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } + + private: + string graph_name_; + Variable var_; + Type data_type_; + string description_; + bool iterable_; + }; + + explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {} + const string& graph_name() const { return graph_name_; } + const std::vector<Endpoint> endpoints() const { return endpoints_; } + void add_endpoint(const Type& type, const Javadoc& javadoc) { + endpoints_.push_back(Endpoint(type, javadoc)); + } + const std::vector<Operand>& inputs() const { return inputs_; } + void add_input(const Operand& input) { + inputs_.push_back(input); + } + const std::vector<Operand>& outputs() const { return outputs_; } + void add_output(const Operand& output) { + outputs_.push_back(output); + } + const std::vector<Operand>& attributes() const { return attributes_; } + void add_attribute(const Operand& attribute) { + attributes_.push_back(attribute); + } + const std::vector<Operand>& options() const { return options_; } + void add_option(const Operand& option) { + options_.push_back(option); + } + + private: + string graph_name_; + std::vector<Endpoint> endpoints_; + std::vector<Operand> inputs_; + std::vector<Operand> outputs_; + std::vector<Operand> attributes_; + std::vector<Operand> options_; +}; + +// A parser of ops proto definitions. +// +// This object parses the definition and the api of an TensorFlow operation to +// produce a specification that can be used for Java source code rendering. +class OpParser { + public: + OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name, + const string& base_package); + virtual ~OpParser() = default; + + // Produces an operation specification from its proto definitions. + void Parse(std::unique_ptr<OpSpec>* op_ptr); + + private: + OpDef op_def_; + ApiDef op_api_; + string lib_name_; + string base_package_; + std::map<string, Type> visited_attrs_; + char next_generic_ = 0; + + void BuildEndpoints(OpSpec* op); + void ParseInput(const OpDef_ArgDef& input_def, + const ApiDef::Arg& input_api, OpSpec* op); + void ParseOutput(const OpDef_ArgDef& output_def, + const ApiDef::Arg& output_api, OpSpec* op); + void ParseAttribute(const OpDef_AttrDef& attr_def, + const ApiDef::Attr& attr_api, OpSpec* op); + Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); + Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out); + Type GetNextGenericTensorType(const AttrValue& allowed_values); +}; + +} // namespace java +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_ diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index a02f75ad6e..b1de5af6ba 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -15,7 +15,7 @@ limitations under the License. #include <string> #include <algorithm> -#include <deque> +#include <list> #include "tensorflow/java/src/gen/cc/source_writer.h" @@ -83,20 +83,20 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { } SourceWriter& SourceWriter::AppendType(const Type& type) { - if (type.kind() == Type::Kind::GENERIC && type.name().empty()) { + if (type.unknown()) { Append("?"); } else { Append(type.name()); - } - if (!type.parameters().empty()) { - Append("<"); - for (const Type& t : type.parameters()) { - if (&t != &type.parameters().front()) { - Append(", "); + if (!type.parameters().empty()) { + Append("<"); + for (const Type& t : type.parameters()) { + if (&t != &type.parameters().front()) { + Append(", "); + } + AppendType(t); } - AppendType(t); + Append(">"); } - Append(">"); } return *this; } @@ -107,7 +107,21 @@ SourceWriter& SourceWriter::EndLine() { return *this; } -SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) { +SourceWriter& SourceWriter::BeginBlock(const string& expression) { + if (!expression.empty()) { + Append(expression + " {"); + } else { + Append(newline_ ? "{" : " {"); + } + return EndLine().Indent(2); +} + +SourceWriter& SourceWriter::EndBlock() { + return Indent(-2).Append("}").EndLine(); +} + +SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc) { GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); if (!method.constructor()) { generic_namespace->Visit(method.return_type()); @@ -116,8 +130,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) { generic_namespace->Visit(v.type()); } EndLine(); - WriteDoc(method.description(), method.return_description(), - &method.arguments()); + if (javadoc != nullptr) { + WriteJavadoc(*javadoc); + } if (!method.annotations().empty()) { WriteAnnotations(method.annotations()); } @@ -145,29 +160,35 @@ SourceWriter& SourceWriter::EndMethod() { return *this; } -SourceWriter& SourceWriter::BeginType(const Type& type, - const std::list<Type>* dependencies, int modifiers) { +SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers, + const std::list<Type>* extra_dependencies, const Javadoc* javadoc) { if (!type.package().empty()) { Append("package ").Append(type.package()).Append(";").EndLine(); } - if (dependencies != nullptr && !dependencies->empty()) { - TypeImporter type_importer(type.package()); - for (const Type& t : *dependencies) { + TypeImporter type_importer(type.package()); + type_importer.Visit(type); + if (extra_dependencies != nullptr) { + for (const Type& t : *extra_dependencies) { type_importer.Visit(t); } + } + if (!type_importer.imports().empty()) { EndLine(); for (const string& s : type_importer.imports()) { Append("import ").Append(s).Append(";").EndLine(); } } - return BeginInnerType(type, modifiers); + return BeginInnerType(type, modifiers, javadoc); } -SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) { +SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc) { GenericNamespace* generic_namespace = PushGenericNamespace(modifiers); generic_namespace->Visit(type); EndLine(); - WriteDoc(type.description()); + if (javadoc != nullptr) { + WriteJavadoc(*javadoc); + } if (!type.annotations().empty()) { WriteAnnotations(type.annotations()); } @@ -200,14 +221,15 @@ SourceWriter& SourceWriter::EndType() { return *this; } -SourceWriter& SourceWriter::WriteFields(const std::list<Variable>& fields, - int modifiers) { - EndLine(); - for (const Variable& v : fields) { - WriteModifiers(modifiers); - AppendType(v.type()).Append(" ").Append(v.name()).Append(";"); - EndLine(); +SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc) { + // If present, write field javadoc only as one brief line + if (javadoc != nullptr && !javadoc->brief().empty()) { + Append("/** ").Append(javadoc->brief()).Append(" */").EndLine(); } + WriteModifiers(modifiers); + AppendType(field.type()).Append(" ").Append(field.name()).Append(";"); + EndLine(); return *this; } @@ -228,39 +250,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) { return *this; } -SourceWriter& SourceWriter::WriteDoc(const string& description, - const string& return_description, const std::list<Variable>* parameters) { - if (description.empty() && return_description.empty() - && (parameters == nullptr || parameters->empty())) { - return *this; // no doc to write - } +SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) { + Append("/**").Prefix(" * ").EndLine(); bool do_line_break = false; - Append("/**").EndLine().Prefix(" * "); - if (!description.empty()) { - Write(description).EndLine(); + if (!javadoc.brief().empty()) { + Write(javadoc.brief()).EndLine(); do_line_break = true; } - if (parameters != nullptr && !parameters->empty()) { + if (!javadoc.details().empty()) { if (do_line_break) { - EndLine(); - do_line_break = false; - } - for (const Variable& v : *parameters) { - Append("@param ").Append(v.name()); - if (!v.description().empty()) { - Append(" ").Write(v.description()); - } - EndLine(); + Append("<p>").EndLine(); } + Write(javadoc.details()).EndLine(); + do_line_break = true; } - if (!return_description.empty()) { + if (!javadoc.tags().empty()) { if (do_line_break) { EndLine(); - do_line_break = false; } - Append("@return ").Write(return_description).EndLine(); + for (const auto& p : javadoc.tags()) { + Append("@" + p.first); + if (!p.second.empty()) { + Append(" ").Write(p.second); + } + EndLine(); + } } - return Prefix("").Append(" **/").EndLine(); + return Prefix("").Append(" */").EndLine(); } SourceWriter& SourceWriter::WriteAnnotations( @@ -311,20 +327,19 @@ void SourceWriter::PopGenericNamespace() { void SourceWriter::TypeVisitor::Visit(const Type& type) { DoVisit(type); for (const Type& t : type.parameters()) { - DoVisit(t); + Visit(t); } for (const Annotation& t : type.annotations()) { DoVisit(t); } for (const Type& t : type.supertypes()) { - DoVisit(t); + Visit(t); } } void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC - && !type.IsWildcard() + if (type.kind() == Type::GENERIC && !type.unknown() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -333,7 +348,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.package() + '.' + type.name()); + imports_.insert(type.full_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index f011acd30a..1f0febe9a3 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -93,25 +93,22 @@ class SourceWriter { // This method appends a new opening brace to the current data and indent the // next lines according to Google Java Style Guide. The block can optionally // be preceded by an expression (e.g. Append("if(true)").BeginBlock();) - SourceWriter& BeginBlock() { - return Append(newline_ ? "{" : " {").EndLine().Indent(2); - } + SourceWriter& BeginBlock(const string& expr = ""); // Ends the current block of source code. // // This method appends a new closing brace to the current data and outdent the // next lines back to the margin used before BeginBlock() was invoked. - SourceWriter& EndBlock() { - return Indent(-2).Append("}").EndLine(); - } + SourceWriter& EndBlock(); // Begins to write a method. // // This method outputs the signature of the Java method from the data passed - // in the 'method' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // method. - SourceWriter& BeginMethod(const Method& method, int modifiers = 0); + // in the 'method' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this method and, optionally, + // a Javadoc. + SourceWriter& BeginMethod(const Method& method, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current method. // @@ -122,22 +119,24 @@ class SourceWriter { // Begins to write the main type of a source file. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the access scope of this type and, optionally, + // a Javadoc. // - // If not null, all types found in the 'dependencies' list will be imported - // before declaring the new type. - SourceWriter& BeginType(const Type& clazz, - const std::list<Type>* dependencies, int modifiers = 0); + // If not null, all types found in the 'extra_dependencies' list will be + // imported before declaring the new type. + SourceWriter& BeginType(const Type& clazz, int modifiers, + const std::list<Type>* extra_dependencies = nullptr, + const Javadoc* javadoc = nullptr); // Begins to write a new inner type. // // This method outputs the declaration of the Java type from the data passed - // in the 'type' parameter and starts a new block. Additionnal modifiers can - // also be passed in parameter to define the accesses and the scope of this - // type. - SourceWriter& BeginInnerType(const Type& type, int modifiers = 0); + // in the 'type' parameter and starts a new block. Modifiers are also passed + // in parameter to define the accesses and the scope of this type and, + // optionally, a Javadoc. + SourceWriter& BeginInnerType(const Type& type, int modifiers, + const Javadoc* javadoc = nullptr); // Ends the current type. // @@ -145,13 +144,13 @@ class SourceWriter { // BeginType() or BeginInnerType() prior to this. SourceWriter& EndType(); - // Writes a list of variables as fields of a type. + // Writes a variable as fields of a type. // // This method must be called within the definition of a type (see BeginType() - // or BeginInnerType()). Additional modifiers can also be passed in parameter - // to define the accesses and the scope of those fields. - SourceWriter& WriteFields(const std::list<Variable>& fields, - int modifiers = 0); + // or BeginInnerType()). Modifiers are also be passed in parameter to define + // the accesses and the scope of this field and, optionally, a Javadoc. + SourceWriter& WriteField(const Variable& field, int modifiers, + const Javadoc* javadoc = nullptr); protected: virtual void DoAppend(const StringPiece& str) = 0; @@ -207,9 +206,7 @@ class SourceWriter { std::stack<GenericNamespace*> generic_namespaces_; SourceWriter& WriteModifiers(int modifiers); - SourceWriter& WriteDoc(const string& description, - const string& return_description = "", - const std::list<Variable>* parameters = nullptr); + SourceWriter& WriteJavadoc(const Javadoc& javadoc); SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations); SourceWriter& WriteGenerics(const std::list<const Type*>& generics); GenericNamespace* PushGenericNamespace(int modifiers); diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 4bce2fea70..8bd42d9d0e 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -250,7 +250,7 @@ TEST(StreamTest, Types) { .AppendType(generic).Append(", ") .AppendType(Type::ListOf(generic)).Append(", ") .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ") - .AppendType(Type::ListOf(Type::Generic())); + .AppendType(Type::ListOf(Type::Wildcard())); const char* expected = "int, String, T, List<T>, List<Iterable<T>>, List<?>"; @@ -282,7 +282,7 @@ TEST(WriteType, SimpleClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -300,7 +300,7 @@ TEST(WriteType, SimpleClassWithDependencies) { deps.push_back(Type::Class("SamePackageType", "org.tensorflow")); deps.push_back(Type::Class("NoPackageType")); - writer.BeginType(clazz, &deps, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, &deps).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -313,18 +313,21 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - clazz.description("This class has a\n<p>\nmultiline description."); + Javadoc clazz_doc; + clazz_doc.brief("Javadoc test") + .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType(); const char* expected = "package org.tensorflow;\n\n" "/**\n" - " * This class has a\n" + " * Javadoc test\n" " * <p>\n" + " * This is a\n" " * multiline description.\n" " **/\n" "@Bean\n" @@ -339,7 +342,7 @@ TEST(WriteType, ParameterizedClass) { clazz.add_parameter(Type::Generic("T")); clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number"))); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -358,7 +361,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) { clazz.add_supertype(Type::Interface("Runnable")); clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t)); - writer.BeginType(clazz, nullptr, PUBLIC).EndType(); + writer.BeginType(clazz, PUBLIC).EndType(); const char* expected = "package org.tensorflow;\n\n" @@ -372,24 +375,24 @@ TEST(WriteType, ParameterizedClassFields) { Type clazz = Type::Class("Test", "org.tensorflow"); Type type_t = Type::Generic("T").add_supertype(Type::Class("Number")); clazz.add_parameter(type_t); - std::list<Variable> static_fields; - static_fields.push_back(Variable::Create("field1", Type::Class("String"))); - std::list<Variable> member_fields; - member_fields.push_back(Variable::Create("field2", Type::Class("String"))); - member_fields.push_back(Variable::Create("field3", type_t)); - - writer.BeginType(clazz, nullptr, PUBLIC) - .WriteFields(static_fields, STATIC | PUBLIC | FINAL) - .WriteFields(member_fields, PRIVATE) + Variable field1 = Variable::Create("field1", Type::Class("String")); + Variable field2 = Variable::Create("field2", Type::Class("String")); + Variable field3 = Variable::Create("field3", type_t); + Javadoc field3_doc; + field3_doc.brief("This variable is documented"); + + writer.BeginType(clazz, PUBLIC) + .WriteField(field1, STATIC | PUBLIC | FINAL) + .WriteField(field2, PRIVATE) + .WriteField(field3, PRIVATE, &field3_doc) .EndType(); const char* expected = "package org.tensorflow;\n\n" "public class Test<T extends Number> {\n" - " \n" " public static final String field1;\n" - " \n" " private String field2;\n" + " /** This variable is documented */\n" " private T field3;\n" "}\n"; ASSERT_STREQ(expected, writer.str().data()); @@ -400,7 +403,7 @@ TEST(WriteType, SimpleInnerClass) { Type clazz = Type::Class("Test", "org.tensorflow"); Type inner_class = Type::Class("InnerTest"); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC) .EndType() .EndType(); @@ -423,7 +426,7 @@ TEST(WriteType, StaticParameterizedInnerClass) { Type inner_class = Type::Class("InnerTest"); inner_class.add_parameter(type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginInnerType(inner_class, PUBLIC | STATIC) .EndType() .EndType(); @@ -443,7 +446,7 @@ TEST(WriteMethod, SimpleMethod) { Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC).EndMethod() .EndType(); @@ -461,13 +464,15 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - method.description("This method has a\n<p>\nmultiline description."); + Javadoc method_doc; + method_doc.brief("Javadoc test") + .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") .attributes("\"rawtypes\"")); - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC).EndMethod() + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc).EndMethod() .EndType(); const char* expected = @@ -475,8 +480,9 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { "public class Test {\n" " \n" " /**\n" - " * This method has a\n" + " * Javadoc test\n" " * <p>\n" + " * This method has a\n" " * multiline description.\n" " **/\n" " @Override\n" @@ -490,16 +496,18 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { TEST(WriteMethod, DocumentedMethodWithArguments) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); + Variable reverse = Variable::Create("reverse", Type::Boolean()); Method method = Method::Create("boolToInt", Type::Int()); - method.description("Converts a boolean to an int"); - method.return_description("int value for this boolean"); method.add_argument(Variable::Create("b", Type::Boolean())); - Variable reverse = Variable::Create("reverse", Type::Boolean()); - reverse.description("if true, value is reversed"); method.add_argument(reverse); - - writer.BeginType(clazz, nullptr, PUBLIC) - .BeginMethod(method, PUBLIC) + Javadoc method_doc; + method_doc.brief("Converts a boolean to an int") + .details("This method will convert\na boolean to an int") + .add_param_tag(reverse.name(), "if true, value is reversed") + .add_tag("return", "int value for this boolean"); + + writer.BeginType(clazz, PUBLIC) + .BeginMethod(method, PUBLIC, &method_doc) .Append("if (b && !reverse)") .BeginBlock() .Append("return 1;").EndLine() @@ -514,8 +522,10 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " \n" " /**\n" " * Converts a boolean to an int\n" + " * <p>\n" + " * This method will convert\n" + " * a boolean to an int\n" " * \n" - " * @param b\n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" " **/\n" @@ -536,7 +546,7 @@ TEST(WriteMethod, ParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC) .Append("return null;").EndLine() .EndMethod() @@ -560,7 +570,7 @@ TEST(WriteMethod, StaticParameterizedMethod) { clazz.add_parameter(type_t); Method method = Method::Create("doNothing", type_t); - writer.BeginType(clazz, nullptr, PUBLIC) + writer.BeginType(clazz, PUBLIC) .BeginMethod(method, PUBLIC | STATIC) .Append("return null;").EndLine() .EndMethod() diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index a6650fc4ea..1e7899cf7a 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -1,9 +1,11 @@ # -*- Python -*- -load("//tensorflow:tensorflow.bzl", - "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts") +load( + "//tensorflow:tensorflow.bzl", + "tf_binary_additional_srcs", + "tf_cc_binary", + "tf_copts", +) # Given a list of "ops_libs" (a list of files in the core/ops directory # without their .cc extensions), generate Java wrapper code for all operations @@ -27,16 +29,31 @@ def tf_java_op_gen_srcjar(name, ops_libs_pkg="//tensorflow/core", out_dir="ops/", out_src_dir="src/main/java/", + api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files + srcs = api_def_srcs[:] # Construct an op generator binary for each ops library. for ops_lib in ops_libs: gen_lib = ops_lib[:ops_lib.rfind("_")] out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + " $$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) + tf_cc_binary( name=out_gen_tool, copts=tf_copts(), @@ -48,7 +65,8 @@ def tf_java_op_gen_srcjar(name, gen_cmds += ["$(location :" + out_gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package] + " --base_package=" + gen_base_package + + " " + api_def_args_str] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" @@ -57,6 +75,7 @@ def tf_java_op_gen_srcjar(name, gen_tools += tf_binary_additional_srcs() native.genrule( name=name, + srcs=srcs, outs=[gen_srcjar], tools=gen_tools, cmd="&&".join(gen_cmds)) diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.snippet.java new file mode 100644 index 0000000000..90285ec669 --- /dev/null +++ b/tensorflow/java/src/gen/resources/license.snippet.java @@ -0,0 +1,14 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ |