From 7e80197f020895fea41eda36b08135b747a9a4f1 Mon Sep 17 00:00:00 2001 From: "karl@kubx.ca" Date: Fri, 6 Apr 2018 08:56:54 -0400 Subject: Improve Javadoc and include first code review --- tensorflow/java/BUILD | 23 +- tensorflow/java/src/gen/cc/java_defs.h | 12 +- tensorflow/java/src/gen/cc/op_gen_main.cc | 48 +-- tensorflow/java/src/gen/cc/op_generator.cc | 224 ++++++----- tensorflow/java/src/gen/cc/op_generator.h | 25 +- tensorflow/java/src/gen/cc/op_parser.cc | 417 --------------------- tensorflow/java/src/gen/cc/op_parser.h | 137 ------- tensorflow/java/src/gen/cc/op_specs.cc | 390 +++++++++++++++++++ tensorflow/java/src/gen/cc/op_specs.h | 152 ++++++++ tensorflow/java/src/gen/cc/source_writer.cc | 2 +- tensorflow/java/src/gen/cc/source_writer.h | 2 +- tensorflow/java/src/gen/cc/source_writer_test.cc | 20 +- tensorflow/java/src/gen/gen_ops.bzl | 68 ++-- .../java/src/gen/resources/license.java.snippet | 14 + .../java/src/gen/resources/license.snippet.java | 14 - 15 files changed, 774 insertions(+), 774 deletions(-) delete mode 100644 tensorflow/java/src/gen/cc/op_parser.cc delete mode 100644 tensorflow/java/src/gen/cc/op_parser.h create mode 100644 tensorflow/java/src/gen/cc/op_specs.cc create mode 100644 tensorflow/java/src/gen/cc/op_specs.h create mode 100644 tensorflow/java/src/gen/resources/license.java.snippet delete mode 100644 tensorflow/java/src/gen/resources/license.snippet.java diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 635a4e807d..17566e1a9c 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -68,9 +68,13 @@ filegroup( ], ) +# Build the gen tool as a library, as it will be linked to a core/ops binary +# files before making it an executable. tf_java_op_gen_srcjar( name = "java_op_gen_sources", - api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], + api_def_srcs = [ + "//tensorflow/core/api_def:base_api_def", + ], gen_base_package = "org.tensorflow.op", gen_tool = "java_op_gen_tool", ops_libs = [ @@ -95,30 +99,17 @@ tf_java_op_gen_srcjar( ], ) -# Build the gen tool as a library, as it will be linked to a core/ops binary -# file before making it an executable. See tf_java_op_gen_srcjar(). -cc_library( - name = "java_op_gen_tool", - srcs = [ - "src/gen/cc/op_gen_main.cc", - ], - copts = tf_copts(), - deps = [ - ":java_op_gen_lib", - ], -) - cc_library( name = "java_op_gen_lib", srcs = [ "src/gen/cc/op_generator.cc", - "src/gen/cc/op_parser.cc", + "src/gen/cc/op_specs.cc", "src/gen/cc/source_writer.cc", ], hdrs = [ "src/gen/cc/java_defs.h", "src/gen/cc/op_generator.h", - "src/gen/cc/op_parser.h", + "src/gen/cc/op_specs.h", "src/gen/cc/source_writer.h", ], copts = tf_copts(), diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 2065477f58..81ac67eb2f 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -230,12 +230,12 @@ class Javadoc { return Javadoc(brief); } const string& brief() const { return brief_; } - const string& details() const { return description_; } - Javadoc& details(const string description) { - description_ = description; + const string& details() const { return details_; } + Javadoc& details(const string& details) { + details_ = details; return *this; } - const std::list> tags() const { return tags_; } + const std::list>& tags() const { return tags_; } Javadoc& add_tag(const string& tag, const string& text) { tags_.push_back(std::make_pair(tag, text)); return *this; @@ -246,7 +246,7 @@ class Javadoc { private: string brief_; - string description_; + string details_; std::list> tags_; explicit Javadoc(const string& brief) : brief_(brief) {} diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index 015200023f..458141b877 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -36,55 +36,41 @@ const char kUsageHeader[] = "Operation wrappers are generated under the path specified by the " "'--output_dir' argument. This path can be absolute or relative to the\n" "current working directory and will be created if it does not exists.\n\n" - "The '--lib_name' argument is used to classify the set of operations. If " - "the chosen name contains more than one word, it must be provided in \n" - "snake_case. This value is declined into other meaningful names, such as " - "the group and package of the generated operations. For example,\n" - "'--lib_name=my_lib' generates the operations under the " - "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n" - "group.\n\n" - "Note that the operator group assigned to the generated wrappers is just " - "an annotation tag at this stage. Operations will not be available " - "through\n" - "the 'org.tensorflow.op.Ops' API as a group until the generated classes " - "are compiled using an appropriate annotation processor.\n\n" + "Note that the operations will not be available through the " + "'org.tensorflow.op.Ops' API until the generated classes are compiled\n" + "using an appropriate annotation processor.\n\n" "The '--base_package' overrides the default parent package under which " "the generated subpackage and classes are to be located.\n\n" - "Finally, a list of directories of API proto definitions can be provided " - "to override default values found in the ops definitions, ordered by\n" - "priority (the last having precedence over the first).\n\n"; + "Finally, the `--api_dirs` argument takes a list of comma-seperated " + "directories of API definitions can be provided to override default\n" + "values found in the ops definitions. Directories are ordered by priority " + "(the last having precedence over the first).\n\n"; } // namespace java } // namespace tensorflow int main(int argc, char* argv[]) { - tensorflow::string lib_name; tensorflow::string output_dir; tensorflow::string base_package = "org.tensorflow.op"; + tensorflow::string api_dirs_str; std::vector flag_list = { tensorflow::Flag("output_dir", &output_dir, "Root directory into which output files are generated"), - tensorflow::Flag( - "lib_name", &lib_name, - "A name, in snake_case, used to classify this set of operations"), - tensorflow::Flag( - "base_package", &base_package, - "Package parent to the generated subpackage and classes")}; + tensorflow::Flag("base_package", &base_package, + "Package parent to the generated subpackage and classes"), + tensorflow::Flag("api_dirs", &api_dirs_str, + "List of directories that contains the ops api definitions")}; tensorflow::string usage = tensorflow::java::kUsageHeader; usage += tensorflow::Flags::Usage(argv[0], flag_list); bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); tensorflow::port::InitMain(usage.c_str(), &argc, &argv); - QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage; - std::vector api_dirs; - if (argc > 1) { - api_dirs = tensorflow::str_util::Split(argv[1], ",", - tensorflow::str_util::SkipEmpty()); - } + QCHECK(parsed_flags_ok && !output_dir.empty()) << usage; + std::vector api_dirs = tensorflow::str_util::Split( + api_dirs_str, ",", tensorflow::str_util::SkipEmpty()); tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs); tensorflow::OpList ops; tensorflow::OpRegistry::Global()->Export(false, &ops); - tensorflow::Status status = generator.Run(ops, lib_name); - TF_QCHECK_OK(status); + TF_CHECK_OK(generator.Run(ops)); return 0; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index c9b57f5706..c32ad3b109 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -27,15 +28,15 @@ limitations under the License. #include "tensorflow/core/framework/op_gen_lib.h" #include "tensorflow/java/src/gen/cc/java_defs.h" #include "tensorflow/java/src/gen/cc/source_writer.h" -#include "tensorflow/java/src/gen/cc/op_parser.h" #include "tensorflow/java/src/gen/cc/op_generator.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { namespace { const char* kLicenseSnippet = - "tensorflow/java/src/gen/resources/license.snippet.java"; + "tensorflow/java/src/gen/resources/license.java.snippet"; const std::map kPrimitiveAttrTypes = { { "Boolean", Type::Boolean() }, @@ -66,34 +67,34 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, } // Don't pay attention to duplicate types in the dependency list, they will // be filtered out by the SourceWriter. - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { out->push_back(input.var().type()); if (input.iterable()) { out->push_back(Type::Class("Operands", "org.tensorflow.op")); } } - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { out->push_back(output.var().type()); if (output.iterable()) { out->push_back(Type::Class("Arrays", "java.util")); } } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); if (attribute.var().type().name() == "Class") { out->push_back(Type::Enum("DataType", "org.tensorflow")); } } - for (const OpSpec::Operand& option : op.options()) { - out->push_back(option.var().type()); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + out->push_back(optional_attribute.var().type()); } } -void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, +void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, SourceWriter* writer) { string var = optional ? "opts." + attr.var().name() : attr.var().name(); if (attr.iterable()) { - const Type& type = attr.data_type(); + const Type& type = attr.type(); std::map::const_iterator it = kPrimitiveAttrTypes.find(type.name()); if (it != kPrimitiveAttrTypes.end()) { @@ -107,11 +108,11 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, .Append(array + "[i] = " + var + ".get(i);") .EndLine() .EndBlock() - .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array) + .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array) .Append(");") .EndLine(); } else { - writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var) + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var) .Append(".toArray(new ") .AppendType(type) .Append("[" + var + ".size()]));") @@ -119,7 +120,7 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional, } } else { Type type = attr.var().type(); - writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", "); + writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", "); if (type.name() == "Class") { writer->Append("DataType.fromClass(" + attr.var().name() + "));"); } else { @@ -139,26 +140,26 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); factory.add_argument(scope); factory_doc.add_param_tag(scope.name(), "Current graph scope"); - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { factory.add_argument(input.var()); factory_doc.add_param_tag(input.var().name(), input.description()); } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { factory.add_argument(attribute.var()); factory_doc.add_param_tag(attribute.var().name(), attribute.description()); } - if (!op.options().empty()) { + if (!op.optional_attributes().empty()) { factory.add_argument(Variable::Varargs("options", Type::Class("Options"))); factory_doc.add_param_tag("options", "carries optional attributes values"); } factory_doc.add_tag("return", "a new instance of " + op_class.name()); writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc); writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" - + op.graph_name() + "\", scope.makeOpName(\"" + + op.graph_op_name() + "\", scope.makeOpName(\"" + op_class.name() + "\"));"); writer->EndLine(); - for (const OpSpec::Operand& input : op.inputs()) { + for (const ArgumentSpec& input : op.inputs()) { if (input.iterable()) { writer->Append("opBuilder.addInputList(Operands.asOutputs(" + input.var().name() + "));"); @@ -169,15 +170,15 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class, writer->EndLine(); } } - for (const OpSpec::Operand& attribute : op.attributes()) { + for (const AttributeSpec& attribute : op.attributes()) { WriteSetAttrDirective(attribute, false, writer); } - if (!op.options().empty()) { + if (!op.optional_attributes().empty()) { writer->BeginBlock("if (options != null)") .BeginBlock("for (Options opts : options)"); - for (const OpSpec::Operand& option : op.options()) { - writer->BeginBlock("if (opts." + option.var().name() + " != null)"); - WriteSetAttrDirective(option, true, writer); + for (const AttributeSpec& attribute : op.optional_attributes()) { + writer->BeginBlock("if (opts." + attribute.var().name() + " != null)"); + WriteSetAttrDirective(attribute, true, writer); writer->EndBlock(); } writer->EndBlock().EndBlock(); @@ -195,8 +196,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, .add_argument( Variable::Create("operation", Type::Class("Operation", "org.tensorflow"))); - for (const OpSpec::Operand& output : op.outputs()) { - if (output.iterable() && !output.data_type().unknown()) { + for (const ArgumentSpec& output : op.outputs()) { + if (output.iterable() && !output.type().unknown()) { constructor.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); break; @@ -208,15 +209,15 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, if (op.outputs().size() > 0) { writer->Append("int outputIdx = 0;") .EndLine(); - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { if (output.iterable()) { string var_length = output.var().name() + "Length"; writer->Append("int " + var_length) - .Append(" = operation.outputListLength(\"" + output.graph_name() + .Append(" = operation.outputListLength(\"" + output.op_def_name() + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); - if (!output.data_type().unknown()) { + if (!output.type().unknown()) { writer->Append("(") .AppendType(output.var().type().parameters().front()) .Append("[])"); @@ -236,18 +237,19 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, } void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { - for (const OpSpec::Operand& option : op.options()) { - Method setter = Method::Create(option.var().name(), Type::Class("Options")) - .add_argument(option.var()); + for (const AttributeSpec& attribute : op.optional_attributes()) { + Method setter = + Method::Create(attribute.var().name(), Type::Class("Options")) + .add_argument(attribute.var()); Javadoc setter_doc = Javadoc::Create() - .add_param_tag(option.var().name(), option.description()); + .add_param_tag(attribute.var().name(), attribute.description()); writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc) - .Append("return new Options()." + option.var().name() + "(" - + option.var().name() + ");") + .Append("return new Options()." + attribute.var().name() + "(" + + attribute.var().name() + ");") .EndLine() .EndMethod(); } - for (const OpSpec::Operand& output : op.outputs()) { + for (const ArgumentSpec& output : op.outputs()) { Method getter = Method::Create(output.var().name(), output.var().type()); Javadoc getter_doc = Javadoc::Create(output.description()); writer->BeginMethod(getter, PUBLIC, &getter_doc) @@ -259,12 +261,12 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) { void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, SourceWriter* writer) { - OpSpec::Operand output = op.outputs().front(); + ArgumentSpec output = op.outputs().front(); if (mode == SINGLE_OUTPUT) { - bool cast2obj = output.data_type().unknown(); + bool cast2obj = output.type().unknown(); Type return_type = Type::Class("Output", "org.tensorflow") - .add_parameter(cast2obj ? Type::Class("Object") : output.data_type()); + .add_parameter(cast2obj ? Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) .add_annotation(Annotation::Create("Override")); if (cast2obj) { @@ -283,10 +285,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, } else if (mode == SINGLE_LIST_OUTPUT) { Type operand = Type::Interface("Operand", "org.tensorflow"); - if (output.data_type().unknown()) { + if (output.type().unknown()) { operand.add_parameter(Type::Class("Object")); } else { - operand.add_parameter(output.data_type()); + operand.add_parameter(output.type()); } Type return_type = Type::Interface("Iterator", "java.util") .add_parameter(operand); @@ -308,57 +310,119 @@ void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) { Javadoc options_doc = Javadoc::Create( "Class holding optional attributes of this operation"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); - for (const OpSpec::Operand& option : op.options()) { - Method setter = Method::Create(option.var().name(), options_class) - .add_argument(option.var()); + for (const AttributeSpec& attribute : op.optional_attributes()) { + Method setter = Method::Create(attribute.var().name(), options_class) + .add_argument(attribute.var()); Javadoc setter_doc = Javadoc::Create() - .add_param_tag(option.var().name(), option.description()); + .add_param_tag(attribute.var().name(), attribute.description()); writer->BeginMethod(setter, PUBLIC, &setter_doc) - .Append("this." + option.var().name() + " = " + option.var().name() - + ";") + .Append("this." + attribute.var().name() + " = " + + attribute.var().name() + ";") .EndLine() .Append("return this;") .EndLine() .EndMethod(); } writer->EndLine(); - for (const OpSpec::Operand& option : op.options()) { - writer->WriteField(option.var(), PRIVATE); + for (const AttributeSpec& optional_attribute : op.optional_attributes()) { + writer->WriteField(optional_attribute.var(), PRIVATE); } Method constructor = Method::ConstructorFor(options_class); writer->BeginMethod(constructor, PRIVATE).EndMethod(); writer->EndType(); } -void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint, - SourceWriter* writer) { +inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { + return Type::Class(endpoint.name(), + base_package + "." + str_util::Lowercase(endpoint.package())); +} + +void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, + const string& base_package, const string& output_dir, Env* env) { + Type op_class(ClassOf(endpoint, base_package) + .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); + Javadoc op_javadoc(endpoint.javadoc()); + + // implement Operand (or Iterable) if the op has only one output RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { - mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT; + const ArgumentSpec& output = op.outputs().front(); + Type operand_type(output.type().unknown() ? + Type::Class("Object") : output.type()); + Type operand_inf(Type::Interface("Operand", "org.tensorflow") + .add_parameter(operand_type)); + if (output.iterable()) { + mode = SINGLE_LIST_OUTPUT; + op_class.add_supertype(Type::IterableOf(operand_inf)); + } else { + mode = SINGLE_OUTPUT; + op_class.add_supertype(operand_inf); + } + } + // declare all outputs generics at the op class level + std::set generics; + for (const ArgumentSpec& output : op.outputs()) { + if (output.type().kind() == Type::GENERIC && !output.type().unknown() + && generics.find(output.type().name()) == generics.end()) { + op_class.add_parameter(output.type()); + op_javadoc.add_param_tag("<" + output.type().name() + ">", + "data type of output {@code " + output.var().name() + "}"); + generics.insert(output.type().name()); + } + } + // handle endpoint deprecation + if (endpoint.deprecated()) { + op_class.add_annotation(Annotation::Create("Deprecated")); + string explanation; + if (!op.endpoints().front().deprecated()) { + explanation = "use {@link " + + ClassOf(op.endpoints().front(), base_package).full_name() + + "} instead"; + } else { + explanation = op.deprecation_explanation(); + } + op_javadoc.add_tag("deprecated", explanation); } + // expose the op in the Ops Graph API only if it is visible + if (!op.hidden()) { + op_class.add_annotation( + Annotation::Create("Operator", "org.tensorflow.op.annotation") + .attributes("group = \"" + endpoint.package() + "\"")); + } + // create op class file + string op_dir = io::JoinPath(output_dir, + str_util::StringReplace(op_class.package(), ".", "/", true)); + if (!env->FileExists(op_dir).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir)); + } + std::unique_ptr op_file; + TF_CHECK_OK(env->NewWritableFile( + io::JoinPath(op_dir, op_class.name() + ".java"), &op_file)); + + // render endpoint source code + SourceFileWriter writer(op_file.get()); std::list dependencies; CollectOpDependencies(op, mode, &dependencies); - const Type& op_class = endpoint.type(); - writer->WriteFromFile(kLicenseSnippet) + writer.WriteFromFile(kLicenseSnippet) .EndLine() .Append("// This file is machine generated, DO NOT EDIT!") .EndLine() .EndLine() - .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc()); - if (!op.options().empty()) { - RenderOptionsClass(op, writer); + .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc); + if (!op.optional_attributes().empty()) { + RenderOptionsClass(op, &writer); } - RenderFactoryMethod(op, op_class, writer); - RenderGettersAndSetters(op, writer); + RenderFactoryMethod(op, op_class, &writer); + RenderGettersAndSetters(op, &writer); if (mode != DEFAULT) { - RenderInterfaceImpl(op, mode, writer); + RenderInterfaceImpl(op, mode, &writer); } - writer->EndLine(); - for (const OpSpec::Operand& output : op.outputs()) { - writer->WriteField(output.var(), PRIVATE); + writer.EndLine(); + for (const ArgumentSpec& output : op.outputs()) { + writer.WriteField(output.var(), PRIVATE); } - RenderConstructor(op, op_class, writer); - writer->EndType(); + RenderConstructor(op, op_class, &writer); + writer.EndType(); } } // namespace @@ -369,8 +433,7 @@ OpGenerator::OpGenerator(const string& base_package, const string& output_dir, env_(env) { } -Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { - LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations"; +Status OpGenerator::Run(const OpList& op_list) { ApiDefMap api_map(op_list); if (!api_dirs_.empty()) { // Only load api files that correspond to the requested "op_list" @@ -388,37 +451,14 @@ Status OpGenerator::Run(const OpList& op_list, const string& lib_name) { for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { - Status status = GenerateOp(op_def, *api_def, lib_name); - if (status != Status::OK()) { - LOG(ERROR) << "Fail to generate Java wrapper for operation \"" - << op_def.name() << "\""; + OpSpec op(OpSpec::Create(op_def, *api_def)); + for (const EndpointSpec& endpoint : op.endpoints()) { + GenerateOp(op, endpoint, base_package_, output_dir_, env_); } } } return Status::OK(); } -Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def, - const string& lib_name) { - std::unique_ptr op; - OpParser op_parser(op_def, api_def, lib_name, base_package_); - op_parser.Parse(&op); - for (const OpSpec::Endpoint& endpoint : op->endpoints()) { - string package_path = io::JoinPath(output_dir_, - str_util::StringReplace(endpoint.type().package(), ".", "/", true)); - if (!env_->FileExists(package_path).ok()) { - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path)); - } - string file_path = - io::JoinPath(package_path, endpoint.type().name() + ".java"); - std::unique_ptr file; - TF_CHECK_OK(env_->NewWritableFile(file_path, &file)); - - SourceFileWriter writer(file.get()); - RenderEndpoint(*op, endpoint, &writer); - } - return Status::OK(); -} - } // namespace java } // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 19d8db95fb..06b08e852a 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -23,36 +23,33 @@ limitations under the License. #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { namespace java { // A generator of Java operation wrappers. // -// Such generator is normally ran only once per executable, outputting -// wrappers for the all registered operations it has been compiled with. -// Nonetheless, it is designed to support multiple runs, giving a different -// list of operations on each cycle. +// This generator takes a list of ops definitions in input and outputs +// a Java Op wrapper for each of them in the provided directory. The same +// generator instance can be invoked multiple times with a different list of +// ops definitions. class OpGenerator { public: OpGenerator(const string& base_package, const string& output_dir, const std::vector& api_dirs, Env* env = Env::Default()); - virtual ~OpGenerator() = default; // Generates wrappers for the given list of 'ops'. // // Output files are generated in //, - // where 'lib_package' is derived from 'lib_name'. - Status Run(const OpList& op_list, const string& lib_name); + // where 'lib_package' is derived from ops endpoints. + Status Run(const OpList& op_list); private: - string base_package_; - string output_dir_; - std::vector api_dirs_; + const string base_package_; + const string output_dir_; + const std::vector api_dirs_; Env* env_; - - Status GenerateOp(const OpDef& op_def, const ApiDef& api_def, - const string& lib_name); }; } // namespace java diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc deleted file mode 100644 index 0541e343d8..0000000000 --- a/tensorflow/java/src/gen/cc/op_parser.cc +++ /dev/null @@ -1,417 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/java/src/gen/cc/op_parser.h" - -namespace tensorflow { -namespace java { -namespace { - -string SnakeToCamelCase(const string& str, bool upper = false) { - string result; - bool cap = upper; - for (string::const_iterator it = str.begin(); it != str.end(); ++it) { - const char c = *it; - if (c == '_') { - cap = true; - } else if (cap) { - result += toupper(c); - cap = false; - } else { - result += c; - } - } - return result; -} - -bool IsRealNumber(DataType type) { - for (DataType dt : RealNumberTypes()) { - if (type == dt) { - return true; - } - } - return false; -} - -bool IsRealNumbers(const AttrValue& values) { - if (values.has_list()) { - for (int i = 0; i < values.list().type_size(); ++i) { - if (!IsRealNumber(values.list().type(i))) { - return false; - } - } - return true; - } - return IsRealNumber(values.type()); -} - -string ParseDocumentation(const string& text) { - std::stringstream javadoc_text; - string::const_iterator c_iter = text.cbegin(); - bool code = false; - bool emphasis = false; - bool list = false; - while (c_iter != text.cend()) { - char c = *c_iter++; - int count = 1; - switch (c) { - case '\n': - if (!code) { - // consumes all subsequent newlines, if there are more than one, - // then there are two choices: - // - if the next line starts with an asterisk, we are enumerating - // a list of items - // - otherwise, we are starting a new paragraph - for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {} - if (c_iter != text.cend()) { - if (count > 1) { - if (*c_iter != '*' && list) { - javadoc_text << "\n\n"; - list = false; - } else if (*c_iter == '*' && !list) { - javadoc_text << "\n
    \n
  • "; - list = true; - c_iter++; - } else { - javadoc_text << "\n

    \n"; - } - } else if (list && *c_iter == '*') { - javadoc_text << "

  • \n
  • "; - c_iter++; - } else { - javadoc_text << '\n'; - } - } - } - break; - case '`': - // consumes all subsequent backquotes, those are use enclose code. - // if there are more than 3, we are dealing with a pre-formatted block, - // otherwise it is a single-line code snippet - for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {} - if (count >= 3) { - javadoc_text << (code ? "\n}" : "
    {@code\n");
    -      } else {
    -        javadoc_text << (code ? "}" : "{@code ");
    -      }
    -      code = !code;
    -      break;
    -    case '*':
    -      if (!code) {
    -        // consumes all subsequent asterisks, if there are more than one, then
    -        // we put the text in bold, otherwise in italic
    -        for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
    -        if (count > 1) {
    -          javadoc_text << (emphasis ? "" : "");
    -        } else {
    -          javadoc_text << (emphasis ? "" : "");
    -        }
    -        emphasis = !emphasis;
    -      } else {
    -        javadoc_text << '*';
    -      }
    -      break;
    -    default:
    -      javadoc_text << c;
    -      break;
    -    }
    -  }
    -  return javadoc_text.str();
    -}
    -
    -}  // namespace
    -
    -OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
    -    const string& lib_name, const string& base_package)
    -  : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
    -    base_package_(base_package) {
    -}
    -
    -void OpParser::Parse(std::unique_ptr* op_ptr) {
    -  visited_attrs_.clear();
    -  next_generic_ = 'T';
    -  op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
    -  for (const string& next_input_name : op_api_.arg_order()) {
    -    for (int i = 0; i < op_def_.input_arg().size(); ++i) {
    -      if (op_def_.input_arg(i).name() == next_input_name) {
    -        ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
    -        break;
    -      }
    -    }
    -  }
    -  for (int i = 0; i < op_def_.attr().size(); ++i) {
    -    ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
    -  }
    -  for (int i = 0; i < op_def_.output_arg().size(); ++i) {
    -    ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
    -  }
    -  BuildEndpoints(op_ptr->get());
    -}
    -
    -void OpParser::BuildEndpoints(OpSpec* op) {
    -  Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
    -    .details(ParseDocumentation(op_api_.description()));
    -  std::vector op_supertypes;
    -  op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
    -  std::map op_generics;
    -  for (const OpSpec::Operand& output : op->outputs()) {
    -    // declare generic output parameters at the Op class level
    -    const Type& data_type = output.data_type();
    -    if (data_type.kind() == Type::GENERIC && !data_type.unknown()
    -        && op_generics.find(data_type.name()) == op_generics.end()) {
    -      op_generics.insert(std::make_pair(data_type.name(), &data_type));
    -      op_doc.add_param_tag("<" + data_type.name() + ">",
    -          "data type of output '" + output.var().name() + "'");
    -    }
    -    // implement the Op as an (iteration of) Operand if it has only one output
    -    if (op->outputs().size() == 1) {
    -      Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
    -      operand_inf.add_parameter(data_type.unknown() ?
    -          Type::Class("Object") : data_type);
    -      op_supertypes.push_back(output.iterable() ?
    -          Type::IterableOf(operand_inf) : operand_inf);
    -    }
    -  }
    -  for (const auto& endpoint_def : op_api_.endpoint()) {
    -    std::vector name_tokens = str_util::Split(endpoint_def.name(), ".");
    -    // if the endpoint specifies a package, use it, otherwise derive it from the
    -    // op library name.
    -    string name;
    -    string package;
    -    if (name_tokens.size() > 1) {
    -      package = str_util::Lowercase(name_tokens.at(0));
    -      name = name_tokens.at(1);
    -    } else {
    -      package = str_util::StringReplace(lib_name_, "_", "", true);
    -      name = name_tokens.at(0);
    -    }
    -    Type endpoint(Type::Class(name, base_package_ + "." + package));
    -    Javadoc endpoint_doc(op_doc);
    -    for (const auto& parameter : op_generics) {
    -      endpoint.add_parameter(*parameter.second);
    -    }
    -    for (const Type& supertype : op_supertypes) {
    -      endpoint.add_supertype(supertype);
    -    }
    -    if (endpoint_def.deprecation_version() > 0) {
    -      string explanation;
    -      if (op_api_.endpoint(0).deprecation_version() == 0) {
    -        explanation = ", use {@link "
    -            + op->endpoints().at(0).type().full_name()
    -            + "} instead";
    -      } else {
    -        explanation = op_def_.deprecation().explanation();
    -      }
    -      endpoint_doc.add_tag("deprecated", explanation);
    -      endpoint.add_annotation(Annotation::Create("Deprecated"));
    -    }
    -    // only visible ops should be annotated for exposure in the Ops Graph API
    -    if (op_api_.visibility() != ApiDef::HIDDEN) {
    -      string group_name = SnakeToCamelCase(lib_name_);
    -      endpoint.add_annotation(
    -          Annotation::Create("Operator", "org.tensorflow.op.annotation")
    -            .attributes("group = \"" + group_name + "\""));
    -    }
    -    op->add_endpoint(endpoint, endpoint_doc);
    -  }
    -}
    -
    -void OpParser::ParseInput(const OpDef_ArgDef& input_def,
    -    const ApiDef::Arg& input_api, OpSpec* op) {
    -  bool iterable = false;
    -  Type data_type = DataTypeOf(input_def, &iterable);
    -  Type type = Type::Interface("Operand", "org.tensorflow")
    -    .add_parameter(data_type);
    -  if (iterable) {
    -    type = Type::IterableOf(type);
    -  }
    -  op->add_input(OpSpec::Operand(input_api.name(),
    -      Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
    -      data_type,
    -      ParseDocumentation(input_api.description()),
    -      iterable));
    -}
    -
    -void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
    -    const ApiDef::Arg& output_api, OpSpec* op) {
    -  bool iterable = false;
    -  Type data_type = DataTypeOf(output_def, &iterable);
    -  Type type = Type::Class("Output", "org.tensorflow")
    -    .add_parameter(data_type);
    -  if (iterable) {
    -    type = Type::ListOf(type);
    -  }
    -  op->add_output(OpSpec::Operand(output_api.name(),
    -      Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
    -      data_type,
    -      ParseDocumentation(output_api.description()),
    -      iterable));
    -}
    -
    -void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
    -    const ApiDef::Attr& attr_api, OpSpec* op) {
    -  // do not parse attributes already visited, they have probably been inferred
    -  // before as an input argument type
    -  if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
    -    return;
    -  }
    -  bool iterable = false;
    -  Type data_type = DataTypeOf(attr_def, &iterable);
    -  // generic attributes should be passed as an explicit type
    -  bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
    -  Type type = explicit_type ?
    -      Type::Class("Class").add_parameter(data_type) : data_type;
    -  if (iterable) {
    -    type = Type::ListOf(data_type);
    -  }
    -  OpSpec::Operand attr(attr_api.name(),
    -      Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
    -      data_type,
    -      ParseDocumentation(attr_api.description()),
    -      iterable);
    -  // attributes with a default value are optional
    -  if (attr_api.has_default_value() && !explicit_type) {
    -    op->add_option(attr);
    -  } else {
    -    op->add_attribute(attr);
    -  }
    -  visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
    -}
    -
    -Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
    -  if (!arg.number_attr().empty()) {
    -    visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
    -    *iterable_out = true;
    -  }
    -  if (arg.type() != DataType::DT_INVALID) {
    -    // resolve type from DataType
    -    switch (arg.type()) {
    -      case DataType::DT_BOOL:
    -        return Type::Class("Boolean");
    -
    -      case DataType::DT_STRING:
    -        return Type::Class("String");
    -
    -      case DataType::DT_FLOAT:
    -        return Type::Class("Float");
    -
    -      case DataType::DT_DOUBLE:
    -        return Type::Class("Double");
    -
    -      case DataType::DT_UINT8:
    -        return Type::Class("UInt8", "org.tensorflow.types");
    -
    -      case DataType::DT_INT32:
    -        return Type::Class("Integer");
    -
    -      case DataType::DT_INT64:
    -        return Type::Class("Long");
    -
    -      case DataType::DT_RESOURCE:
    -        // TODO(karllessard) create a Resource utility class that could be
    -        // used to store a resource and its type (passed in a second argument).
    -        // For now, we need to force a wildcard and we will unfortunately lose
    -        // track of the resource type.
    -        return Type::Wildcard();
    -
    -      default:
    -        break;
    -    }
    -  } else {
    -    // resolve type from type attribute
    -    string attr_name = arg.type_attr();
    -    if (attr_name.empty()) {
    -      attr_name = arg.type_list_attr();
    -      if (!attr_name.empty()) {
    -        *iterable_out = true;
    -        Type type = Type::Wildcard();
    -        visited_attrs_.insert(std::make_pair(attr_name, type));
    -        return type;
    -      }
    -    }
    -    for (const auto& attr : op_def_.attr()) {
    -      if (attr.name() == attr_name) {
    -        Type type = DataTypeOf(attr, iterable_out);
    -        visited_attrs_.insert(std::make_pair(attr_name, type));
    -        return type;
    -      }
    -    }
    -  }
    -  LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
    -  return Type::Wildcard();
    -}
    -
    -Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
    -  std::map::const_iterator it = visited_attrs_.find(attr.name());
    -  if (it != visited_attrs_.cend()) {
    -    return it->second;
    -  }
    -  string attr_type = attr.type();
    -  if (attr.type().compare(0, 5, "list(") == 0) {
    -    attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
    -    *iterable_out = true;
    -  }
    -  if (attr_type == "type") {
    -    if (*iterable_out) {
    -      return Type::Enum("DataType", "org.tensorflow");
    -    }
    -    return GetNextGenericTensorType(attr.allowed_values());
    -  }
    -  if (attr_type == "string") {
    -    return Type::Class("String");
    -  }
    -  if (attr_type == "int") {
    -    return Type::Class("Integer");
    -  }
    -  if (attr_type == "float") {
    -    return Type::Class("Float");
    -  }
    -  if (attr_type == "bool") {
    -    return Type::Class("Boolean");
    -  }
    -  if (attr_type == "shape") {
    -    return Type::Class("Shape", "org.tensorflow");
    -  }
    -  if (attr_type == "tensor") {
    -    return Type::Class("Tensor", "org.tensorflow")
    -      .add_parameter(Type::Wildcard());
    -  }
    -  LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
    -  return *iterable_out ? Type::Wildcard() : Type::Class("Object");
    -}
    -
    -Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values)  {
    -  Type generic = Type::Generic(string(1, next_generic_));
    -  next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
    -
    -  // when only real numbers are allowed, enforce that restriction in the Java by
    -  // extending the generic from java.lang.Number
    -  if (IsRealNumbers(allowed_values)) {
    -    generic.add_supertype(Type::Class("Number"));
    -  }
    -  return generic;
    -}
    -
    -}  // namespace java
    -}  // namespace tensorflow
    diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
    deleted file mode 100644
    index 42855127cc..0000000000
    --- a/tensorflow/java/src/gen/cc/op_parser.h
    +++ /dev/null
    @@ -1,137 +0,0 @@
    -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
    -
    -Licensed under the Apache License, Version 2.0 (the "License");
    -you may not use this file except in compliance with the License.
    -You may obtain a copy of the License at
    -
    -    http://www.apache.org/licenses/LICENSE-2.0
    -
    -Unless required by applicable law or agreed to in writing, software
    -distributed under the License is distributed on an "AS IS" BASIS,
    -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    -See the License for the specific language governing permissions and
    -limitations under the License.
    -==============================================================================*/
    -
    -#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    -#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    -
    -#include 
    -#include 
    -#include 
    -#include 
    -
    -#include "tensorflow/core/framework/op_def.pb.h"
    -#include "tensorflow/core/framework/api_def.pb.h"
    -#include "tensorflow/java/src/gen/cc/java_defs.h"
    -
    -namespace tensorflow {
    -namespace java {
    -
    -// Specification of a TensorFlow operation to generate.
    -//
    -// This is the result of an operation definition parsing, see OpParser::Parse().
    -class OpSpec {
    - public:
    -  class Endpoint {
    -   public:
    -    Endpoint(const Type& type, const Javadoc& javadoc)
    -      : type_(type), javadoc_(javadoc) {}
    -    const Type& type() const { return type_; }
    -    const Javadoc& javadoc() const { return javadoc_; }
    -
    -   private:
    -    Type type_;
    -    Javadoc javadoc_;
    -  };
    -
    -  class Operand {
    -   public:
    -    Operand(const string& graph_name, const Variable& var,
    -        const Type& data_type, const string& description, bool iterable)
    -     : graph_name_(graph_name), var_(var), data_type_(data_type),
    -       description_(description), iterable_(iterable) {}
    -    const string& graph_name() const { return graph_name_; }
    -    const Variable& var() const { return var_; }
    -    Variable* var_ptr() { return &var_; }
    -    const Type& data_type() const { return data_type_; }
    -    const string& description() const { return description_; }
    -    bool iterable() const { return iterable_; }
    -
    -   private:
    -    string graph_name_;
    -    Variable var_;
    -    Type data_type_;
    -    string description_;
    -    bool iterable_;
    -  };
    -
    -  explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
    -  const string& graph_name() const { return graph_name_; }
    -  const std::vector endpoints() const { return endpoints_; }
    -  void add_endpoint(const Type& type, const Javadoc& javadoc) {
    -    endpoints_.push_back(Endpoint(type, javadoc));
    -  }
    -  const std::vector& inputs() const { return inputs_; }
    -  void add_input(const Operand& input) {
    -    inputs_.push_back(input);
    -  }
    -  const std::vector& outputs() const { return outputs_; }
    -  void add_output(const Operand& output) {
    -    outputs_.push_back(output);
    -  }
    -  const std::vector& attributes() const { return attributes_; }
    -  void add_attribute(const Operand& attribute) {
    -    attributes_.push_back(attribute);
    -  }
    -  const std::vector& options() const { return options_; }
    -  void add_option(const Operand& option) {
    -    options_.push_back(option);
    -  }
    -
    - private:
    -  string graph_name_;
    -  std::vector endpoints_;
    -  std::vector inputs_;
    -  std::vector outputs_;
    -  std::vector attributes_;
    -  std::vector options_;
    -};
    -
    -// A parser of ops proto definitions.
    -//
    -// This object parses the definition and the api of an TensorFlow operation to
    -// produce a specification that can be used for Java source code rendering.
    -class OpParser {
    - public:
    -  OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
    -      const string& base_package);
    -  virtual ~OpParser() = default;
    -
    -  // Produces an operation specification from its proto definitions.
    -  void Parse(std::unique_ptr* op_ptr);
    -
    - private:
    -  OpDef op_def_;
    -  ApiDef op_api_;
    -  string lib_name_;
    -  string base_package_;
    -  std::map visited_attrs_;
    -  char next_generic_ = 0;
    -
    -  void BuildEndpoints(OpSpec* op);
    -  void ParseInput(const OpDef_ArgDef& input_def,
    -      const ApiDef::Arg& input_api, OpSpec* op);
    -  void ParseOutput(const OpDef_ArgDef& output_def,
    -      const ApiDef::Arg& output_api, OpSpec* op);
    -  void ParseAttribute(const OpDef_AttrDef& attr_def,
    -      const ApiDef::Attr& attr_api, OpSpec* op);
    -  Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    -  Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    -  Type GetNextGenericTensorType(const AttrValue& allowed_values);
    -};
    -
    -}  // namespace java
    -}  // namespace tensorflow
    -
    -#endif  // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
    diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
    new file mode 100644
    index 0000000000..a727f7ae90
    --- /dev/null
    +++ b/tensorflow/java/src/gen/cc/op_specs.cc
    @@ -0,0 +1,390 @@
    +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
    +
    +Licensed under the Apache License, Version 2.0 (the "License");
    +you may not use this file except in compliance with the License.
    +You may obtain a copy of the License at
    +
    +    http://www.apache.org/licenses/LICENSE-2.0
    +
    +Unless required by applicable law or agreed to in writing, software
    +distributed under the License is distributed on an "AS IS" BASIS,
    +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    +See the License for the specific language governing permissions and
    +limitations under the License.
    +==============================================================================*/
    +
    +#include 
    +#include 
    +#include 
    +#include 
    +
    +#include "re2/re2.h"
    +#include "tensorflow/core/framework/op.h"
    +#include "tensorflow/core/framework/types.h"
    +#include "tensorflow/core/lib/strings/str_util.h"
    +#include "tensorflow/core/platform/logging.h"
    +#include "tensorflow/java/src/gen/cc/op_specs.h"
    +
    +namespace tensorflow {
    +namespace java {
    +namespace {
    +
    +inline bool IsRealNumbers(const AttrValue& values) {
    +  if (!values.has_list()) {
    +    return RealNumberTypes().Contains(values.type());
    +  }
    +  for (int i = 0; i < values.list().type_size(); ++i) {
    +    if (!RealNumberTypes().Contains(values.list().type(i))) {
    +      return false;
    +    }
    +  }
    +  return true;
    +}
    +
    +class TypeResolver {
    + public:
    +  explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
    +
    +  Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
    +  Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
    +  bool IsAttributeVisited(const string& attr_name) {
    +    return visited_attrs_.find(attr_name) != visited_attrs_.cend();
    +  }
    + private:
    +  const OpDef op_def_;
    +  std::map visited_attrs_;
    +  char next_generic_ = 'T';
    +};
    +
    +Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
    +    bool* iterable_out) {
    +  *iterable_out = false;
    +  if (!arg_def.number_attr().empty()) {
    +    // when number_attr is set, argument has to be a list of tensors
    +    *iterable_out = true;
    +    visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
    +  }
    +  Type type = Type::Wildcard();
    +  if (arg_def.type() != DataType::DT_INVALID) {
    +    // resolve type from DataType
    +    switch (arg_def.type()) {
    +      case DataType::DT_BOOL:
    +        type = Type::Class("Boolean");
    +        break;
    +      case DataType::DT_STRING:
    +        type = Type::Class("String");
    +        break;
    +      case DataType::DT_FLOAT:
    +        type = Type::Class("Float");
    +        break;
    +      case DataType::DT_DOUBLE:
    +        type = Type::Class("Double");
    +        break;
    +      case DataType::DT_UINT8:
    +        type = Type::Class("UInt8", "org.tensorflow.types");
    +        break;
    +      case DataType::DT_INT32:
    +        type = Type::Class("Integer");
    +        break;
    +      case DataType::DT_INT64:
    +        type = Type::Class("Long");
    +        break;
    +      case DataType::DT_RESOURCE:
    +        // TODO(karllessard) create a Resource utility class that could be
    +        // used to store a resource and its type (passed in a second argument).
    +        // For now, we need to force a wildcard and we will unfortunately lose
    +        // track of the resource type.
    +        break;
    +      default:
    +        // Any other datatypes does not have a equivalent in Java and must
    +        // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
    +        break;
    +    }
    +  } else if (!arg_def.type_attr().empty()) {
    +    // resolve type from attribute (if already visited, retrieve its type)
    +    if (IsAttributeVisited(arg_def.type_attr())) {
    +      type = visited_attrs_.at(arg_def.type_attr());
    +    } else {
    +      for (const auto& attr_def : op_def_.attr()) {
    +        if (attr_def.name() == arg_def.type_attr()) {
    +          type = TypeOf(attr_def, iterable_out);
    +          break;
    +        }
    +      }
    +    }
    +  } else if (!arg_def.type_list_attr().empty()) {
    +    // type is a list of tensors that can be of different data types, so leave
    +    // it as a list of wildcards
    +    *iterable_out = true;
    +    visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
    +
    +  } else {
    +    LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
    +        << "\" in operation \"" << op_def_.name() << "\"";
    +  }
    +  return type;
    +}
    +
    +Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def,
    +    bool* iterable_out) {
    +  *iterable_out = false;
    +  StringPiece attr_type = attr_def.type();
    +  if (str_util::ConsumePrefix(&attr_type, "list(")) {
    +    attr_type.remove_suffix(1);  // remove closing brace
    +    *iterable_out = true;
    +  }
    +  Type type = *iterable_out ? Type::Wildcard() : Type::Class("Object");
    +  if (attr_type == "type") {
    +    if (*iterable_out) {
    +      type = Type::Enum("DataType", "org.tensorflow");
    +    } else {
    +      type = Type::Generic(string(1, next_generic_));
    +      next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
    +      if (IsRealNumbers(attr_def.allowed_values())) {
    +        // enforce real numbers datasets by extending java.lang.Number
    +        type.add_supertype(Type::Class("Number"));
    +      }
    +    }
    +  } else if (attr_type == "string") {
    +    type = Type::Class("String");
    +
    +  } else if (attr_type == "int") {
    +    type = Type::Class("Integer");
    +
    +  } else if (attr_type == "float") {
    +    type = Type::Class("Float");
    +
    +  } else if (attr_type == "bool") {
    +    type = Type::Class("Boolean");
    +
    +  } else if (attr_type == "shape") {
    +    type = Type::Class("Shape", "org.tensorflow");
    +
    +  } else if (attr_type == "tensor") {
    +    type = Type::Class("Tensor", "org.tensorflow")
    +        .add_parameter(Type::Wildcard());
    +
    +  } else {
    +    LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
    +        << "\" in operation \"" << op_def_.name() << "\"";
    +  }
    +  visited_attrs_.insert(std::make_pair(attr_def.name(), type));
    +  return type;
    +}
    +
    +string SnakeToCamelCase(const string& str, bool upper = false) {
    +  string result;
    +  bool cap = upper;
    +  for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
    +    const char c = *it;
    +    if (c == '_') {
    +      cap = true;
    +    } else if (cap) {
    +      result += toupper(c);
    +      cap = false;
    +    } else {
    +      result += c;
    +    }
    +  }
    +  return result;
    +}
    +
    +bool FindAndCut(re2::StringPiece* input, const RE2& expr,
    +    re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
    +  re2::StringPiece match;
    +  bool matches =
    +      expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1);
    +  if (matches) {
    +    before_match->set(input->data(), match.begin() - input->begin());
    +    input->remove_prefix(match.end() - before_match->begin());
    +    if (ret_match != nullptr) {
    +      *ret_match = match;
    +    }
    +  } else {
    +    *before_match = *input;
    +    if (ret_match != nullptr) {
    +      ret_match->set(nullptr, 0);
    +    }
    +  }
    +  return matches;
    +}
    +
    +string ParseDocumentation(const string& mdtext) {
    +  std::stringstream javadoc_text;
    +  re2::StringPiece input(mdtext);
    +  re2::StringPiece text;
    +  bool in_list = false;
    +  do {
    +    re2::StringPiece markup;
    +    FindAndCut(&input,
    +        "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\[",
    +        &text, &markup);
    +    javadoc_text << text;
    +    if (markup.empty()) {
    +      break;  // we are done parsing
    +    }
    +    if (markup.starts_with("\n")) {
    +      javadoc_text << "\n";
    +      if (markup.contains("* ")) {
    +        javadoc_text << (in_list ? "
  • \n" : "
      \n") << "
    • \n"; + in_list = true; + } else if (markup.starts_with("\n\n")) { + if (in_list) { + javadoc_text << "
    • \n
    \n"; + in_list = false; + } else if (!input.starts_with("```")) { + javadoc_text << "

    \n"; + } + } + } else if (markup.starts_with("```") && text.empty()) { + re2::StringPiece language; + RE2::Consume(&input, "[\\w\\+]+", &language); + if (FindAndCut(&input, markup.ToString() + "\n*", &text)) { + javadoc_text << "

    \n{@code" << text << "}\n
    \n"; + } else { + javadoc_text << markup << language; + } + } else if (markup.starts_with("`")) { + if (FindAndCut(&input, markup, &text)) { + javadoc_text << "{@code " << text << "}"; + } else { + javadoc_text << markup; + } + } else if (markup == "**") { + if (FindAndCut(&input, "\\b\\*{2}", &text)) { + javadoc_text << "" << text << ""; + } else { + javadoc_text << markup; + } + } else if (markup == "*") { + if (FindAndCut(&input, "\\b\\*{1}", &text)) { + javadoc_text << "" << text << ""; + } else { + javadoc_text << markup; + } + } else if (markup == "[") { + string label; + string link; + if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { + javadoc_text << "" << label << ""; + } else { + javadoc_text << markup; + } + } + } while (!input.empty()); + + return javadoc_text.str(); +} + +ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, + const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(input_def, &iterable); + Type var_type = Type::Interface("Operand", "org.tensorflow") + .add_parameter(type); + if (iterable) { + var_type = Type::IterableOf(var_type); + } + return ArgumentSpec(input_api_def.name(), + Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type), + type, + ParseDocumentation(input_api_def.description()), + iterable); +} + +AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, + const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(attr_def, &iterable); + // type attributes must be passed explicitly in methods as a Class<> parameter + bool is_explicit = type.kind() == Type::GENERIC && !iterable; + Type var_type = is_explicit ? Type::Class("Class").add_parameter(type) : type; + if (iterable) { + var_type = Type::ListOf(type); + } + return AttributeSpec(attr_api_def.name(), + Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), + type, + ParseDocumentation(attr_api_def.description()), + iterable, + attr_api_def.has_default_value() && !is_explicit); +} + +ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, + const ApiDef::Arg& output_api, TypeResolver* type_resolver) { + bool iterable = false; + Type type = type_resolver->TypeOf(output_def, &iterable); + Type var_type = Type::Class("Output", "org.tensorflow") + .add_parameter(type); + if (iterable) { + var_type = Type::ListOf(var_type); + } + return ArgumentSpec(output_api.name(), + Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type), + type, + ParseDocumentation(output_api.description()), + iterable); +} + +EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, + const ApiDef_Endpoint& endpoint_def) { + + std::vector name_tokens = str_util::Split(endpoint_def.name(), "."); + string package; + string name; + if (name_tokens.size() > 1) { + package = name_tokens.at(0); + name = name_tokens.at(1); + } else { + package = "core"; // generate unclassified ops in the 'core' package + name = name_tokens.at(0); + } + return EndpointSpec(package, + name, + Javadoc::Create(ParseDocumentation(api_def.summary())) + .details(ParseDocumentation(api_def.description())), + endpoint_def.deprecation_version() > 0); +} + +} // namespace + +OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { + OpSpec op(api_def.graph_op_name(), + api_def.visibility() == ApiDef::HIDDEN, + op_def.deprecation().explanation()); + TypeResolver type_resolver(op_def); + for (const string& next_input_name : api_def.arg_order()) { + for (int i = 0; i < op_def.input_arg().size(); ++i) { + if (op_def.input_arg(i).name() == next_input_name) { + op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i), + &type_resolver)); + break; + } + } + } + for (int i = 0; i < op_def.attr().size(); ++i) { + // do not parse attributes already visited, they have probably been inferred + // before as an input argument type + if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) { + AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), + &type_resolver); + // attributes with a default value are optional + if (attr.optional()) { + op.optional_attributes_.push_back(attr); + } else { + op.attributes_.push_back(attr); + } + } + } + for (int i = 0; i < op_def.output_arg().size(); ++i) { + op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i), + &type_resolver)); + } + for (const auto& endpoint_def : api_def.endpoint()) { + op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def)); + } + return op; +} + +} // namespace java +} // namespace tensorflow diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h new file mode 100644 index 0000000000..55c2c3f307 --- /dev/null +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ +#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ + +#include +#include + +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/java/src/gen/cc/java_defs.h" + +namespace tensorflow { +namespace java { + +class EndpointSpec { + public: + // A specification for an operation endpoint + // + // package: package of this endpoint (from which also derives its package) + // name: name of this endpoint class + // javadoc: the endpoint class documentation + // deprecated: true if this endpoint is now deprecated + EndpointSpec(const string& package, const string& name, + const Javadoc& javadoc, bool deprecated) + : package_(package), name_(name), javadoc_(javadoc), + deprecated_(deprecated) {} + + const string& package() const { return package_; } + const string& name() const { return name_; } + const Javadoc& javadoc() const { return javadoc_; } + bool deprecated() const { return deprecated_; } + + private: + const string package_; + const string name_; + const Javadoc javadoc_; + const bool deprecated_; +}; + +class ArgumentSpec { + public: + // A specification for an operation argument + // + // op_def_name: argument name, as known by TensorFlow core + // var: a variable to represent this argument in Java + // type: the tensor type of this argument + // description: a description of this argument, in javadoc + // iterable: true if this argument is a list + ArgumentSpec(const string& op_def_name, const Variable& var, + const Type& type, const string& description, bool iterable) + : op_def_name_(op_def_name), var_(var), type_(type), + description_(description), iterable_(iterable) {} + virtual ~ArgumentSpec() = default; + + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } + + private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; +}; + +class AttributeSpec : public ArgumentSpec { + public: + // A specification for an operation attribute + // + // op_def_name: attribute name, as known by TensorFlow core + // var: a variable to represent this attribute in Java + // type: the type of this attribute + // description: a description of this attribute, in javadoc + // iterable: true if this attribute is a list + // optional: true if this attribute does not require to be set explicitly + AttributeSpec(const string& op_def_name, const Variable& var, + const Type& type, const string& description, bool iterable, + bool optional) + : ArgumentSpec(op_def_name, var, type, description, iterable), + optional_(optional) {} + virtual ~AttributeSpec() = default; + + bool optional() const { return optional_; } + + private: + const bool optional_; +}; + +class OpSpec { + public: + // Parses an op definition and its API to produce a specification used for + // rendering its Java wrapper + // + // op_def: Op definition + // api_def: Op API definition + static OpSpec Create(const OpDef& op_def, const ApiDef& api_def); + + const string& graph_op_name() const { return graph_op_name_; } + bool hidden() const { return hidden_; } + const string& deprecation_explanation() const { + return deprecation_explanation_; + } + const std::vector endpoints() const { return endpoints_; } + const std::vector& inputs() const { return inputs_; } + const std::vector& outputs() const { return outputs_; } + const std::vector& attributes() const { return attributes_; } + const std::vector& optional_attributes() const { + return optional_attributes_; + } + + private: + // A specification for an operation + // + // graph_op_name: name of this op, as known by TensorFlow core engine + // hidden: true if this op should not be visible through the Graph Ops API + // deprecation_explanation: message to show if all endpoints are deprecated + explicit OpSpec(const string& graph_op_name, bool hidden, + const string& deprecation_explanation) + : graph_op_name_(graph_op_name), hidden_(hidden), + deprecation_explanation_(deprecation_explanation) {} + + const string graph_op_name_; + const bool hidden_; + const string deprecation_explanation_; + std::vector endpoints_; + std::vector inputs_; + std::vector outputs_; + std::vector attributes_; + std::vector optional_attributes_; +}; + +} // namespace java +} // namespace tensorflow + +#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_ diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index b1de5af6ba..7e427787f9 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index 1f0febe9a3..bcae33ccce 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 8bd42d9d0e..875ad99ae2 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -313,8 +313,7 @@ TEST(WriteType, SimpleClassWithDependencies) { TEST(WriteType, AnnotatedAndDocumentedClass) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); - Javadoc clazz_doc; - clazz_doc.brief("Javadoc test") + Javadoc clazz_doc = Javadoc::Create("Javadoc test") .details("This is a\nmultiline description."); clazz.add_annotation(Annotation::Create("Bean")); clazz.add_annotation(Annotation::Create("SuppressWarnings") @@ -329,7 +328,7 @@ TEST(WriteType, AnnotatedAndDocumentedClass) { " *

    \n" " * This is a\n" " * multiline description.\n" - " **/\n" + " */\n" "@Bean\n" "@SuppressWarnings(\"rawtypes\")\n" "public class Test {\n}\n"; @@ -378,8 +377,7 @@ TEST(WriteType, ParameterizedClassFields) { Variable field1 = Variable::Create("field1", Type::Class("String")); Variable field2 = Variable::Create("field2", Type::Class("String")); Variable field3 = Variable::Create("field3", type_t); - Javadoc field3_doc; - field3_doc.brief("This variable is documented"); + Javadoc field3_doc = Javadoc::Create("This variable is documented"); writer.BeginType(clazz, PUBLIC) .WriteField(field1, STATIC | PUBLIC | FINAL) @@ -464,8 +462,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { SourceBufferWriter writer; Type clazz = Type::Class("Test", "org.tensorflow"); Method method = Method::Create("doNothing", Type::Void()); - Javadoc method_doc; - method_doc.brief("Javadoc test") + Javadoc method_doc = Javadoc::Create("Javadoc test") .details("This method has a\nmultiline description."); method.add_annotation(Annotation::Create("Override")); method.add_annotation(Annotation::Create("SuppressWarnings") @@ -484,7 +481,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) { " *

    \n" " * This method has a\n" " * multiline description.\n" - " **/\n" + " */\n" " @Override\n" " @SuppressWarnings(\"rawtypes\")\n" " public void doNothing() {\n" @@ -500,8 +497,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { Method method = Method::Create("boolToInt", Type::Int()); method.add_argument(Variable::Create("b", Type::Boolean())); method.add_argument(reverse); - Javadoc method_doc; - method_doc.brief("Converts a boolean to an int") + Javadoc method_doc = Javadoc::Create("Converts a boolean to an int") .details("This method will convert\na boolean to an int") .add_param_tag(reverse.name(), "if true, value is reversed") .add_tag("return", "int value for this boolean"); @@ -528,7 +524,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) { " * \n" " * @param reverse if true, value is reversed\n" " * @return int value for this boolean\n" - " **/\n" + " */\n" " public int boolToInt(boolean b, boolean reverse) {\n" " if (b && !reverse) {\n" " return 1;\n" diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index 1e7899cf7a..7017b52649 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -32,50 +32,52 @@ def tf_java_op_gen_srcjar(name, api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): - gen_tools = [] gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files srcs = api_def_srcs[:] - # Construct an op generator binary for each ops library. - for ops_lib in ops_libs: - gen_lib = ops_lib[:ops_lib.rfind("_")] - out_gen_tool = out_dir + ops_lib + "_gen_tool" + if not api_def_srcs: + api_def_args_str = "," + else: + api_def_args = [] + for api_def_src in api_def_srcs: + # Add directory of the first ApiDef source to args. + # We are assuming all ApiDefs in a single api_def_src are in the + # same directory. + api_def_args.append( + "$$(dirname $$(echo $(locations " + api_def_src + + ") | cut -d\" \" -f1))") + api_def_args_str = ",".join(api_def_args) - if not api_def_srcs: - api_def_args_str = "," - else: - api_def_args = [] - for api_def_src in api_def_srcs: - # Add directory of the first ApiDef source to args. - # We are assuming all ApiDefs in a single api_def_src are in the - # same directory. - api_def_args.append( - " $$(dirname $$(echo $(locations " + api_def_src + - ") | cut -d\" \" -f1))") - api_def_args_str = ",".join(api_def_args) + gen_tool_deps = [":java_op_gen_lib"] + for ops_lib in ops_libs: + gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib") - tf_cc_binary( - name=out_gen_tool, - copts=tf_copts(), - linkopts=["-lm"], - linkstatic=1, # Faster to link this one-time-use binary dynamically - deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"]) + tf_cc_binary( + name=gen_tool, + srcs=[ + "src/gen/cc/op_gen_main.cc", + ], + copts=tf_copts(), + linkopts=["-lm"], + linkstatic=1, # Faster to link this one-time-use binary dynamically + deps = gen_tool_deps) - gen_tools += [":" + out_gen_tool] - gen_cmds += ["$(location :" + out_gen_tool + ")" + - " --output_dir=$(@D)/" + out_src_dir + - " --lib_name=" + gen_lib + - " --base_package=" + gen_base_package + - " " + api_def_args_str] + gen_cmds += ["$(location :" + gen_tool + ")" + + " --output_dir=$(@D)/" + out_src_dir + + " --base_package=" + gen_base_package + + " --api_dirs=" + api_def_args_str] # Generate a source archive containing generated code for these ops. gen_srcjar = out_dir + name + ".srcjar" gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"] - gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"] - gen_tools += tf_binary_additional_srcs() + native.genrule( name=name, srcs=srcs, outs=[gen_srcjar], - tools=gen_tools, - cmd="&&".join(gen_cmds)) + tools=[ + "@local_jdk//:jar", + "@local_jdk//:jdk", + gen_tool + ] + tf_binary_additional_srcs(), + cmd=" && ".join(gen_cmds)) diff --git a/tensorflow/java/src/gen/resources/license.java.snippet b/tensorflow/java/src/gen/resources/license.java.snippet new file mode 100644 index 0000000000..90285ec669 --- /dev/null +++ b/tensorflow/java/src/gen/resources/license.java.snippet @@ -0,0 +1,14 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.snippet.java deleted file mode 100644 index 90285ec669..0000000000 --- a/tensorflow/java/src/gen/resources/license.snippet.java +++ /dev/null @@ -1,14 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -- cgit v1.2.3