aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-04-06 08:56:54 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-05-03 22:40:42 -0400
commit7e80197f020895fea41eda36b08135b747a9a4f1 (patch)
treee9ddf69e5b0cddd490472441ee5051c739c6f8bd /tensorflow/java
parentfe9b2637cfe39cf11eb3d0494948a733b7fc1d7d (diff)
Improve Javadoc and include first code review
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD23
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h12
-rw-r--r--tensorflow/java/src/gen/cc/op_gen_main.cc48
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc224
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.h25
-rw-r--r--tensorflow/java/src/gen/cc/op_parser.cc417
-rw-r--r--tensorflow/java/src/gen/cc/op_parser.h137
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc390
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h152
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc2
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.h2
-rw-r--r--tensorflow/java/src/gen/cc/source_writer_test.cc20
-rw-r--r--tensorflow/java/src/gen/gen_ops.bzl68
-rw-r--r--tensorflow/java/src/gen/resources/license.java.snippet (renamed from tensorflow/java/src/gen/resources/license.snippet.java)0
14 files changed, 760 insertions, 760 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 635a4e807d..17566e1a9c 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -68,9 +68,13 @@ filegroup(
],
)
+# Build the gen tool as a library, as it will be linked to a core/ops binary
+# files before making it an executable.
tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
- api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
+ api_def_srcs = [
+ "//tensorflow/core/api_def:base_api_def",
+ ],
gen_base_package = "org.tensorflow.op",
gen_tool = "java_op_gen_tool",
ops_libs = [
@@ -95,30 +99,17 @@ tf_java_op_gen_srcjar(
],
)
-# Build the gen tool as a library, as it will be linked to a core/ops binary
-# file before making it an executable. See tf_java_op_gen_srcjar().
-cc_library(
- name = "java_op_gen_tool",
- srcs = [
- "src/gen/cc/op_gen_main.cc",
- ],
- copts = tf_copts(),
- deps = [
- ":java_op_gen_lib",
- ],
-)
-
cc_library(
name = "java_op_gen_lib",
srcs = [
"src/gen/cc/op_generator.cc",
- "src/gen/cc/op_parser.cc",
+ "src/gen/cc/op_specs.cc",
"src/gen/cc/source_writer.cc",
],
hdrs = [
"src/gen/cc/java_defs.h",
"src/gen/cc/op_generator.h",
- "src/gen/cc/op_parser.h",
+ "src/gen/cc/op_specs.h",
"src/gen/cc/source_writer.h",
],
copts = tf_copts(),
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index 2065477f58..81ac67eb2f 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -230,12 +230,12 @@ class Javadoc {
return Javadoc(brief);
}
const string& brief() const { return brief_; }
- const string& details() const { return description_; }
- Javadoc& details(const string description) {
- description_ = description;
+ const string& details() const { return details_; }
+ Javadoc& details(const string& details) {
+ details_ = details;
return *this;
}
- const std::list<std::pair<string, string>> tags() const { return tags_; }
+ const std::list<std::pair<string, string>>& tags() const { return tags_; }
Javadoc& add_tag(const string& tag, const string& text) {
tags_.push_back(std::make_pair(tag, text));
return *this;
@@ -246,7 +246,7 @@ class Javadoc {
private:
string brief_;
- string description_;
+ string details_;
std::list<std::pair<string, string>> tags_;
explicit Javadoc(const string& brief) : brief_(brief) {}
diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc
index 015200023f..458141b877 100644
--- a/tensorflow/java/src/gen/cc/op_gen_main.cc
+++ b/tensorflow/java/src/gen/cc/op_gen_main.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -36,55 +36,41 @@ const char kUsageHeader[] =
"Operation wrappers are generated under the path specified by the "
"'--output_dir' argument. This path can be absolute or relative to the\n"
"current working directory and will be created if it does not exists.\n\n"
- "The '--lib_name' argument is used to classify the set of operations. If "
- "the chosen name contains more than one word, it must be provided in \n"
- "snake_case. This value is declined into other meaningful names, such as "
- "the group and package of the generated operations. For example,\n"
- "'--lib_name=my_lib' generates the operations under the "
- "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
- "group.\n\n"
- "Note that the operator group assigned to the generated wrappers is just "
- "an annotation tag at this stage. Operations will not be available "
- "through\n"
- "the 'org.tensorflow.op.Ops' API as a group until the generated classes "
- "are compiled using an appropriate annotation processor.\n\n"
+ "Note that the operations will not be available through the "
+ "'org.tensorflow.op.Ops' API until the generated classes are compiled\n"
+ "using an appropriate annotation processor.\n\n"
"The '--base_package' overrides the default parent package under which "
"the generated subpackage and classes are to be located.\n\n"
- "Finally, a list of directories of API proto definitions can be provided "
- "to override default values found in the ops definitions, ordered by\n"
- "priority (the last having precedence over the first).\n\n";
+ "Finally, the `--api_dirs` argument takes a list of comma-seperated "
+ "directories of API definitions can be provided to override default\n"
+ "values found in the ops definitions. Directories are ordered by priority "
+ "(the last having precedence over the first).\n\n";
} // namespace java
} // namespace tensorflow
int main(int argc, char* argv[]) {
- tensorflow::string lib_name;
tensorflow::string output_dir;
tensorflow::string base_package = "org.tensorflow.op";
+ tensorflow::string api_dirs_str;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("output_dir", &output_dir,
"Root directory into which output files are generated"),
- tensorflow::Flag(
- "lib_name", &lib_name,
- "A name, in snake_case, used to classify this set of operations"),
- tensorflow::Flag(
- "base_package", &base_package,
- "Package parent to the generated subpackage and classes")};
+ tensorflow::Flag("base_package", &base_package,
+ "Package parent to the generated subpackage and classes"),
+ tensorflow::Flag("api_dirs", &api_dirs_str,
+ "List of directories that contains the ops api definitions")};
tensorflow::string usage = tensorflow::java::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
- QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
- std::vector<tensorflow::string> api_dirs;
- if (argc > 1) {
- api_dirs = tensorflow::str_util::Split(argv[1], ",",
- tensorflow::str_util::SkipEmpty());
- }
+ QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
+ std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
+ api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs);
tensorflow::OpList ops;
tensorflow::OpRegistry::Global()->Export(false, &ops);
- tensorflow::Status status = generator.Run(ops, lib_name);
- TF_QCHECK_OK(status);
+ TF_CHECK_OK(generator.Run(ops));
return 0;
}
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index c9b57f5706..c32ad3b109 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include <list>
#include <memory>
+#include <set>
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -27,15 +28,15 @@ limitations under the License.
#include "tensorflow/core/framework/op_gen_lib.h"
#include "tensorflow/java/src/gen/cc/java_defs.h"
#include "tensorflow/java/src/gen/cc/source_writer.h"
-#include "tensorflow/java/src/gen/cc/op_parser.h"
#include "tensorflow/java/src/gen/cc/op_generator.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
namespace tensorflow {
namespace java {
namespace {
const char* kLicenseSnippet =
- "tensorflow/java/src/gen/resources/license.snippet.java";
+ "tensorflow/java/src/gen/resources/license.java.snippet";
const std::map<string, Type> kPrimitiveAttrTypes = {
{ "Boolean", Type::Boolean() },
@@ -66,34 +67,34 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
}
// Don't pay attention to duplicate types in the dependency list, they will
// be filtered out by the SourceWriter.
- for (const OpSpec::Operand& input : op.inputs()) {
+ for (const ArgumentSpec& input : op.inputs()) {
out->push_back(input.var().type());
if (input.iterable()) {
out->push_back(Type::Class("Operands", "org.tensorflow.op"));
}
}
- for (const OpSpec::Operand& output : op.outputs()) {
+ for (const ArgumentSpec& output : op.outputs()) {
out->push_back(output.var().type());
if (output.iterable()) {
out->push_back(Type::Class("Arrays", "java.util"));
}
}
- for (const OpSpec::Operand& attribute : op.attributes()) {
+ for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
if (attribute.var().type().name() == "Class") {
out->push_back(Type::Enum("DataType", "org.tensorflow"));
}
}
- for (const OpSpec::Operand& option : op.options()) {
- out->push_back(option.var().type());
+ for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+ out->push_back(optional_attribute.var().type());
}
}
-void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
+void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
SourceWriter* writer) {
string var = optional ? "opts." + attr.var().name() : attr.var().name();
if (attr.iterable()) {
- const Type& type = attr.data_type();
+ const Type& type = attr.type();
std::map<string, Type>::const_iterator it =
kPrimitiveAttrTypes.find(type.name());
if (it != kPrimitiveAttrTypes.end()) {
@@ -107,11 +108,11 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
.Append(array + "[i] = " + var + ".get(i);")
.EndLine()
.EndBlock()
- .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array)
+ .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + array)
.Append(");")
.EndLine();
} else {
- writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var)
+ writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", " + var)
.Append(".toArray(new ")
.AppendType(type)
.Append("[" + var + ".size()]));")
@@ -119,7 +120,7 @@ void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
}
} else {
Type type = attr.var().type();
- writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", ");
+ writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
if (type.name() == "Class") {
writer->Append("DataType.fromClass(" + attr.var().name() + "));");
} else {
@@ -139,26 +140,26 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
factory.add_argument(scope);
factory_doc.add_param_tag(scope.name(), "Current graph scope");
- for (const OpSpec::Operand& input : op.inputs()) {
+ for (const ArgumentSpec& input : op.inputs()) {
factory.add_argument(input.var());
factory_doc.add_param_tag(input.var().name(), input.description());
}
- for (const OpSpec::Operand& attribute : op.attributes()) {
+ for (const AttributeSpec& attribute : op.attributes()) {
factory.add_argument(attribute.var());
factory_doc.add_param_tag(attribute.var().name(), attribute.description());
}
- if (!op.options().empty()) {
+ if (!op.optional_attributes().empty()) {
factory.add_argument(Variable::Varargs("options", Type::Class("Options")));
factory_doc.add_param_tag("options", "carries optional attributes values");
}
factory_doc.add_tag("return", "a new instance of " + op_class.name());
writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc);
writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\""
- + op.graph_name() + "\", scope.makeOpName(\""
+ + op.graph_op_name() + "\", scope.makeOpName(\""
+ op_class.name() + "\"));");
writer->EndLine();
- for (const OpSpec::Operand& input : op.inputs()) {
+ for (const ArgumentSpec& input : op.inputs()) {
if (input.iterable()) {
writer->Append("opBuilder.addInputList(Operands.asOutputs("
+ input.var().name() + "));");
@@ -169,15 +170,15 @@ void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
writer->EndLine();
}
}
- for (const OpSpec::Operand& attribute : op.attributes()) {
+ for (const AttributeSpec& attribute : op.attributes()) {
WriteSetAttrDirective(attribute, false, writer);
}
- if (!op.options().empty()) {
+ if (!op.optional_attributes().empty()) {
writer->BeginBlock("if (options != null)")
.BeginBlock("for (Options opts : options)");
- for (const OpSpec::Operand& option : op.options()) {
- writer->BeginBlock("if (opts." + option.var().name() + " != null)");
- WriteSetAttrDirective(option, true, writer);
+ for (const AttributeSpec& attribute : op.optional_attributes()) {
+ writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
+ WriteSetAttrDirective(attribute, true, writer);
writer->EndBlock();
}
writer->EndBlock().EndBlock();
@@ -195,8 +196,8 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
.add_argument(
Variable::Create("operation",
Type::Class("Operation", "org.tensorflow")));
- for (const OpSpec::Operand& output : op.outputs()) {
- if (output.iterable() && !output.data_type().unknown()) {
+ for (const ArgumentSpec& output : op.outputs()) {
+ if (output.iterable() && !output.type().unknown()) {
constructor.add_annotation(
Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
break;
@@ -208,15 +209,15 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
if (op.outputs().size() > 0) {
writer->Append("int outputIdx = 0;")
.EndLine();
- for (const OpSpec::Operand& output : op.outputs()) {
+ for (const ArgumentSpec& output : op.outputs()) {
if (output.iterable()) {
string var_length = output.var().name() + "Length";
writer->Append("int " + var_length)
- .Append(" = operation.outputListLength(\"" + output.graph_name()
+ .Append(" = operation.outputListLength(\"" + output.op_def_name()
+ "\");")
.EndLine()
.Append(output.var().name() + " = Arrays.asList(");
- if (!output.data_type().unknown()) {
+ if (!output.type().unknown()) {
writer->Append("(")
.AppendType(output.var().type().parameters().front())
.Append("[])");
@@ -236,18 +237,19 @@ void RenderConstructor(const OpSpec& op, const Type& op_class,
}
void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
- for (const OpSpec::Operand& option : op.options()) {
- Method setter = Method::Create(option.var().name(), Type::Class("Options"))
- .add_argument(option.var());
+ for (const AttributeSpec& attribute : op.optional_attributes()) {
+ Method setter =
+ Method::Create(attribute.var().name(), Type::Class("Options"))
+ .add_argument(attribute.var());
Javadoc setter_doc = Javadoc::Create()
- .add_param_tag(option.var().name(), option.description());
+ .add_param_tag(attribute.var().name(), attribute.description());
writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc)
- .Append("return new Options()." + option.var().name() + "("
- + option.var().name() + ");")
+ .Append("return new Options()." + attribute.var().name() + "("
+ + attribute.var().name() + ");")
.EndLine()
.EndMethod();
}
- for (const OpSpec::Operand& output : op.outputs()) {
+ for (const ArgumentSpec& output : op.outputs()) {
Method getter = Method::Create(output.var().name(), output.var().type());
Javadoc getter_doc = Javadoc::Create(output.description());
writer->BeginMethod(getter, PUBLIC, &getter_doc)
@@ -259,12 +261,12 @@ void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
SourceWriter* writer) {
- OpSpec::Operand output = op.outputs().front();
+ ArgumentSpec output = op.outputs().front();
if (mode == SINGLE_OUTPUT) {
- bool cast2obj = output.data_type().unknown();
+ bool cast2obj = output.type().unknown();
Type return_type = Type::Class("Output", "org.tensorflow")
- .add_parameter(cast2obj ? Type::Class("Object") : output.data_type());
+ .add_parameter(cast2obj ? Type::Class("Object") : output.type());
Method as_output = Method::Create("asOutput", return_type)
.add_annotation(Annotation::Create("Override"));
if (cast2obj) {
@@ -283,10 +285,10 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
} else if (mode == SINGLE_LIST_OUTPUT) {
Type operand = Type::Interface("Operand", "org.tensorflow");
- if (output.data_type().unknown()) {
+ if (output.type().unknown()) {
operand.add_parameter(Type::Class("Object"));
} else {
- operand.add_parameter(output.data_type());
+ operand.add_parameter(output.type());
}
Type return_type = Type::Interface("Iterator", "java.util")
.add_parameter(operand);
@@ -308,57 +310,119 @@ void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) {
Javadoc options_doc = Javadoc::Create(
"Class holding optional attributes of this operation");
writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
- for (const OpSpec::Operand& option : op.options()) {
- Method setter = Method::Create(option.var().name(), options_class)
- .add_argument(option.var());
+ for (const AttributeSpec& attribute : op.optional_attributes()) {
+ Method setter = Method::Create(attribute.var().name(), options_class)
+ .add_argument(attribute.var());
Javadoc setter_doc = Javadoc::Create()
- .add_param_tag(option.var().name(), option.description());
+ .add_param_tag(attribute.var().name(), attribute.description());
writer->BeginMethod(setter, PUBLIC, &setter_doc)
- .Append("this." + option.var().name() + " = " + option.var().name()
- + ";")
+ .Append("this." + attribute.var().name() + " = "
+ + attribute.var().name() + ";")
.EndLine()
.Append("return this;")
.EndLine()
.EndMethod();
}
writer->EndLine();
- for (const OpSpec::Operand& option : op.options()) {
- writer->WriteField(option.var(), PRIVATE);
+ for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+ writer->WriteField(optional_attribute.var(), PRIVATE);
}
Method constructor = Method::ConstructorFor(options_class);
writer->BeginMethod(constructor, PRIVATE).EndMethod();
writer->EndType();
}
-void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint,
- SourceWriter* writer) {
+inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
+ return Type::Class(endpoint.name(),
+ base_package + "." + str_util::Lowercase(endpoint.package()));
+}
+
+void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
+ const string& base_package, const string& output_dir, Env* env) {
+ Type op_class(ClassOf(endpoint, base_package)
+ .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
+ Javadoc op_javadoc(endpoint.javadoc());
+
+ // implement Operand (or Iterable<Operand>) if the op has only one output
RenderMode mode = DEFAULT;
if (op.outputs().size() == 1) {
- mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT;
+ const ArgumentSpec& output = op.outputs().front();
+ Type operand_type(output.type().unknown() ?
+ Type::Class("Object") : output.type());
+ Type operand_inf(Type::Interface("Operand", "org.tensorflow")
+ .add_parameter(operand_type));
+ if (output.iterable()) {
+ mode = SINGLE_LIST_OUTPUT;
+ op_class.add_supertype(Type::IterableOf(operand_inf));
+ } else {
+ mode = SINGLE_OUTPUT;
+ op_class.add_supertype(operand_inf);
+ }
+ }
+ // declare all outputs generics at the op class level
+ std::set<string> generics;
+ for (const ArgumentSpec& output : op.outputs()) {
+ if (output.type().kind() == Type::GENERIC && !output.type().unknown()
+ && generics.find(output.type().name()) == generics.end()) {
+ op_class.add_parameter(output.type());
+ op_javadoc.add_param_tag("<" + output.type().name() + ">",
+ "data type of output {@code " + output.var().name() + "}");
+ generics.insert(output.type().name());
+ }
+ }
+ // handle endpoint deprecation
+ if (endpoint.deprecated()) {
+ op_class.add_annotation(Annotation::Create("Deprecated"));
+ string explanation;
+ if (!op.endpoints().front().deprecated()) {
+ explanation = "use {@link " +
+ ClassOf(op.endpoints().front(), base_package).full_name()
+ + "} instead";
+ } else {
+ explanation = op.deprecation_explanation();
+ }
+ op_javadoc.add_tag("deprecated", explanation);
}
+ // expose the op in the Ops Graph API only if it is visible
+ if (!op.hidden()) {
+ op_class.add_annotation(
+ Annotation::Create("Operator", "org.tensorflow.op.annotation")
+ .attributes("group = \"" + endpoint.package() + "\""));
+ }
+ // create op class file
+ string op_dir = io::JoinPath(output_dir,
+ str_util::StringReplace(op_class.package(), ".", "/", true));
+ if (!env->FileExists(op_dir).ok()) {
+ TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir));
+ }
+ std::unique_ptr<tensorflow::WritableFile> op_file;
+ TF_CHECK_OK(env->NewWritableFile(
+ io::JoinPath(op_dir, op_class.name() + ".java"), &op_file));
+
+ // render endpoint source code
+ SourceFileWriter writer(op_file.get());
std::list<Type> dependencies;
CollectOpDependencies(op, mode, &dependencies);
- const Type& op_class = endpoint.type();
- writer->WriteFromFile(kLicenseSnippet)
+ writer.WriteFromFile(kLicenseSnippet)
.EndLine()
.Append("// This file is machine generated, DO NOT EDIT!")
.EndLine()
.EndLine()
- .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc());
- if (!op.options().empty()) {
- RenderOptionsClass(op, writer);
+ .BeginType(op_class, PUBLIC|FINAL, &dependencies, &op_javadoc);
+ if (!op.optional_attributes().empty()) {
+ RenderOptionsClass(op, &writer);
}
- RenderFactoryMethod(op, op_class, writer);
- RenderGettersAndSetters(op, writer);
+ RenderFactoryMethod(op, op_class, &writer);
+ RenderGettersAndSetters(op, &writer);
if (mode != DEFAULT) {
- RenderInterfaceImpl(op, mode, writer);
+ RenderInterfaceImpl(op, mode, &writer);
}
- writer->EndLine();
- for (const OpSpec::Operand& output : op.outputs()) {
- writer->WriteField(output.var(), PRIVATE);
+ writer.EndLine();
+ for (const ArgumentSpec& output : op.outputs()) {
+ writer.WriteField(output.var(), PRIVATE);
}
- RenderConstructor(op, op_class, writer);
- writer->EndType();
+ RenderConstructor(op, op_class, &writer);
+ writer.EndType();
}
} // namespace
@@ -369,8 +433,7 @@ OpGenerator::OpGenerator(const string& base_package, const string& output_dir,
env_(env) {
}
-Status OpGenerator::Run(const OpList& op_list, const string& lib_name) {
- LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
+Status OpGenerator::Run(const OpList& op_list) {
ApiDefMap api_map(op_list);
if (!api_dirs_.empty()) {
// Only load api files that correspond to the requested "op_list"
@@ -388,37 +451,14 @@ Status OpGenerator::Run(const OpList& op_list, const string& lib_name) {
for (const auto& op_def : op_list.op()) {
const ApiDef* api_def = api_map.GetApiDef(op_def.name());
if (api_def->visibility() != ApiDef::SKIP) {
- Status status = GenerateOp(op_def, *api_def, lib_name);
- if (status != Status::OK()) {
- LOG(ERROR) << "Fail to generate Java wrapper for operation \""
- << op_def.name() << "\"";
+ OpSpec op(OpSpec::Create(op_def, *api_def));
+ for (const EndpointSpec& endpoint : op.endpoints()) {
+ GenerateOp(op, endpoint, base_package_, output_dir_, env_);
}
}
}
return Status::OK();
}
-Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def,
- const string& lib_name) {
- std::unique_ptr<OpSpec> op;
- OpParser op_parser(op_def, api_def, lib_name, base_package_);
- op_parser.Parse(&op);
- for (const OpSpec::Endpoint& endpoint : op->endpoints()) {
- string package_path = io::JoinPath(output_dir_,
- str_util::StringReplace(endpoint.type().package(), ".", "/", true));
- if (!env_->FileExists(package_path).ok()) {
- TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path));
- }
- string file_path =
- io::JoinPath(package_path, endpoint.type().name() + ".java");
- std::unique_ptr<tensorflow::WritableFile> file;
- TF_CHECK_OK(env_->NewWritableFile(file_path, &file));
-
- SourceFileWriter writer(file.get());
- RenderEndpoint(*op, endpoint, &writer);
- }
- return Status::OK();
-}
-
} // namespace java
} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index 19d8db95fb..06b08e852a 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -23,36 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
namespace tensorflow {
namespace java {
// A generator of Java operation wrappers.
//
-// Such generator is normally ran only once per executable, outputting
-// wrappers for the all registered operations it has been compiled with.
-// Nonetheless, it is designed to support multiple runs, giving a different
-// list of operations on each cycle.
+// This generator takes a list of ops definitions in input and outputs
+// a Java Op wrapper for each of them in the provided directory. The same
+// generator instance can be invoked multiple times with a different list of
+// ops definitions.
class OpGenerator {
public:
OpGenerator(const string& base_package, const string& output_dir,
const std::vector<string>& api_dirs, Env* env = Env::Default());
- virtual ~OpGenerator() = default;
// Generates wrappers for the given list of 'ops'.
//
// Output files are generated in <output_dir>/<base_package>/<lib_package>,
- // where 'lib_package' is derived from 'lib_name'.
- Status Run(const OpList& op_list, const string& lib_name);
+ // where 'lib_package' is derived from ops endpoints.
+ Status Run(const OpList& op_list);
private:
- string base_package_;
- string output_dir_;
- std::vector<string> api_dirs_;
+ const string base_package_;
+ const string output_dir_;
+ const std::vector<string> api_dirs_;
Env* env_;
-
- Status GenerateOp(const OpDef& op_def, const ApiDef& api_def,
- const string& lib_name);
};
} // namespace java
diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc
deleted file mode 100644
index 0541e343d8..0000000000
--- a/tensorflow/java/src/gen/cc/op_parser.cc
+++ /dev/null
@@ -1,417 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include <map>
-#include <vector>
-#include <utility>
-#include <string>
-
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/java/src/gen/cc/op_parser.h"
-
-namespace tensorflow {
-namespace java {
-namespace {
-
-string SnakeToCamelCase(const string& str, bool upper = false) {
- string result;
- bool cap = upper;
- for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
- const char c = *it;
- if (c == '_') {
- cap = true;
- } else if (cap) {
- result += toupper(c);
- cap = false;
- } else {
- result += c;
- }
- }
- return result;
-}
-
-bool IsRealNumber(DataType type) {
- for (DataType dt : RealNumberTypes()) {
- if (type == dt) {
- return true;
- }
- }
- return false;
-}
-
-bool IsRealNumbers(const AttrValue& values) {
- if (values.has_list()) {
- for (int i = 0; i < values.list().type_size(); ++i) {
- if (!IsRealNumber(values.list().type(i))) {
- return false;
- }
- }
- return true;
- }
- return IsRealNumber(values.type());
-}
-
-string ParseDocumentation(const string& text) {
- std::stringstream javadoc_text;
- string::const_iterator c_iter = text.cbegin();
- bool code = false;
- bool emphasis = false;
- bool list = false;
- while (c_iter != text.cend()) {
- char c = *c_iter++;
- int count = 1;
- switch (c) {
- case '\n':
- if (!code) {
- // consumes all subsequent newlines, if there are more than one,
- // then there are two choices:
- // - if the next line starts with an asterisk, we are enumerating
- // a list of items
- // - otherwise, we are starting a new paragraph
- for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {}
- if (c_iter != text.cend()) {
- if (count > 1) {
- if (*c_iter != '*' && list) {
- javadoc_text << "</li>\n</ul>\n";
- list = false;
- } else if (*c_iter == '*' && !list) {
- javadoc_text << "\n<ul>\n<li>";
- list = true;
- c_iter++;
- } else {
- javadoc_text << "\n<p>\n";
- }
- } else if (list && *c_iter == '*') {
- javadoc_text << "</li>\n<li>";
- c_iter++;
- } else {
- javadoc_text << '\n';
- }
- }
- }
- break;
- case '`':
- // consumes all subsequent backquotes, those are use enclose code.
- // if there are more than 3, we are dealing with a pre-formatted block,
- // otherwise it is a single-line code snippet
- for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {}
- if (count >= 3) {
- javadoc_text << (code ? "\n}</pre>" : "<pre>{@code\n");
- } else {
- javadoc_text << (code ? "}" : "{@code ");
- }
- code = !code;
- break;
- case '*':
- if (!code) {
- // consumes all subsequent asterisks, if there are more than one, then
- // we put the text in bold, otherwise in italic
- for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
- if (count > 1) {
- javadoc_text << (emphasis ? "</b>" : "<b>");
- } else {
- javadoc_text << (emphasis ? "</i>" : "<i>");
- }
- emphasis = !emphasis;
- } else {
- javadoc_text << '*';
- }
- break;
- default:
- javadoc_text << c;
- break;
- }
- }
- return javadoc_text.str();
-}
-
-} // namespace
-
-OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
- const string& lib_name, const string& base_package)
- : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
- base_package_(base_package) {
-}
-
-void OpParser::Parse(std::unique_ptr<OpSpec>* op_ptr) {
- visited_attrs_.clear();
- next_generic_ = 'T';
- op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
- for (const string& next_input_name : op_api_.arg_order()) {
- for (int i = 0; i < op_def_.input_arg().size(); ++i) {
- if (op_def_.input_arg(i).name() == next_input_name) {
- ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
- break;
- }
- }
- }
- for (int i = 0; i < op_def_.attr().size(); ++i) {
- ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
- }
- for (int i = 0; i < op_def_.output_arg().size(); ++i) {
- ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
- }
- BuildEndpoints(op_ptr->get());
-}
-
-void OpParser::BuildEndpoints(OpSpec* op) {
- Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
- .details(ParseDocumentation(op_api_.description()));
- std::vector<Type> op_supertypes;
- op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
- std::map<string, const Type*> op_generics;
- for (const OpSpec::Operand& output : op->outputs()) {
- // declare generic output parameters at the Op class level
- const Type& data_type = output.data_type();
- if (data_type.kind() == Type::GENERIC && !data_type.unknown()
- && op_generics.find(data_type.name()) == op_generics.end()) {
- op_generics.insert(std::make_pair(data_type.name(), &data_type));
- op_doc.add_param_tag("<" + data_type.name() + ">",
- "data type of output '" + output.var().name() + "'");
- }
- // implement the Op as an (iteration of) Operand if it has only one output
- if (op->outputs().size() == 1) {
- Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
- operand_inf.add_parameter(data_type.unknown() ?
- Type::Class("Object") : data_type);
- op_supertypes.push_back(output.iterable() ?
- Type::IterableOf(operand_inf) : operand_inf);
- }
- }
- for (const auto& endpoint_def : op_api_.endpoint()) {
- std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
- // if the endpoint specifies a package, use it, otherwise derive it from the
- // op library name.
- string name;
- string package;
- if (name_tokens.size() > 1) {
- package = str_util::Lowercase(name_tokens.at(0));
- name = name_tokens.at(1);
- } else {
- package = str_util::StringReplace(lib_name_, "_", "", true);
- name = name_tokens.at(0);
- }
- Type endpoint(Type::Class(name, base_package_ + "." + package));
- Javadoc endpoint_doc(op_doc);
- for (const auto& parameter : op_generics) {
- endpoint.add_parameter(*parameter.second);
- }
- for (const Type& supertype : op_supertypes) {
- endpoint.add_supertype(supertype);
- }
- if (endpoint_def.deprecation_version() > 0) {
- string explanation;
- if (op_api_.endpoint(0).deprecation_version() == 0) {
- explanation = ", use {@link "
- + op->endpoints().at(0).type().full_name()
- + "} instead";
- } else {
- explanation = op_def_.deprecation().explanation();
- }
- endpoint_doc.add_tag("deprecated", explanation);
- endpoint.add_annotation(Annotation::Create("Deprecated"));
- }
- // only visible ops should be annotated for exposure in the Ops Graph API
- if (op_api_.visibility() != ApiDef::HIDDEN) {
- string group_name = SnakeToCamelCase(lib_name_);
- endpoint.add_annotation(
- Annotation::Create("Operator", "org.tensorflow.op.annotation")
- .attributes("group = \"" + group_name + "\""));
- }
- op->add_endpoint(endpoint, endpoint_doc);
- }
-}
-
-void OpParser::ParseInput(const OpDef_ArgDef& input_def,
- const ApiDef::Arg& input_api, OpSpec* op) {
- bool iterable = false;
- Type data_type = DataTypeOf(input_def, &iterable);
- Type type = Type::Interface("Operand", "org.tensorflow")
- .add_parameter(data_type);
- if (iterable) {
- type = Type::IterableOf(type);
- }
- op->add_input(OpSpec::Operand(input_api.name(),
- Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
- data_type,
- ParseDocumentation(input_api.description()),
- iterable));
-}
-
-void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
- const ApiDef::Arg& output_api, OpSpec* op) {
- bool iterable = false;
- Type data_type = DataTypeOf(output_def, &iterable);
- Type type = Type::Class("Output", "org.tensorflow")
- .add_parameter(data_type);
- if (iterable) {
- type = Type::ListOf(type);
- }
- op->add_output(OpSpec::Operand(output_api.name(),
- Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
- data_type,
- ParseDocumentation(output_api.description()),
- iterable));
-}
-
-void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
- const ApiDef::Attr& attr_api, OpSpec* op) {
- // do not parse attributes already visited, they have probably been inferred
- // before as an input argument type
- if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
- return;
- }
- bool iterable = false;
- Type data_type = DataTypeOf(attr_def, &iterable);
- // generic attributes should be passed as an explicit type
- bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
- Type type = explicit_type ?
- Type::Class("Class").add_parameter(data_type) : data_type;
- if (iterable) {
- type = Type::ListOf(data_type);
- }
- OpSpec::Operand attr(attr_api.name(),
- Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
- data_type,
- ParseDocumentation(attr_api.description()),
- iterable);
- // attributes with a default value are optional
- if (attr_api.has_default_value() && !explicit_type) {
- op->add_option(attr);
- } else {
- op->add_attribute(attr);
- }
- visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
-}
-
-Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
- if (!arg.number_attr().empty()) {
- visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
- *iterable_out = true;
- }
- if (arg.type() != DataType::DT_INVALID) {
- // resolve type from DataType
- switch (arg.type()) {
- case DataType::DT_BOOL:
- return Type::Class("Boolean");
-
- case DataType::DT_STRING:
- return Type::Class("String");
-
- case DataType::DT_FLOAT:
- return Type::Class("Float");
-
- case DataType::DT_DOUBLE:
- return Type::Class("Double");
-
- case DataType::DT_UINT8:
- return Type::Class("UInt8", "org.tensorflow.types");
-
- case DataType::DT_INT32:
- return Type::Class("Integer");
-
- case DataType::DT_INT64:
- return Type::Class("Long");
-
- case DataType::DT_RESOURCE:
- // TODO(karllessard) create a Resource utility class that could be
- // used to store a resource and its type (passed in a second argument).
- // For now, we need to force a wildcard and we will unfortunately lose
- // track of the resource type.
- return Type::Wildcard();
-
- default:
- break;
- }
- } else {
- // resolve type from type attribute
- string attr_name = arg.type_attr();
- if (attr_name.empty()) {
- attr_name = arg.type_list_attr();
- if (!attr_name.empty()) {
- *iterable_out = true;
- Type type = Type::Wildcard();
- visited_attrs_.insert(std::make_pair(attr_name, type));
- return type;
- }
- }
- for (const auto& attr : op_def_.attr()) {
- if (attr.name() == attr_name) {
- Type type = DataTypeOf(attr, iterable_out);
- visited_attrs_.insert(std::make_pair(attr_name, type));
- return type;
- }
- }
- }
- LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
- return Type::Wildcard();
-}
-
-Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
- std::map<string, Type>::const_iterator it = visited_attrs_.find(attr.name());
- if (it != visited_attrs_.cend()) {
- return it->second;
- }
- string attr_type = attr.type();
- if (attr.type().compare(0, 5, "list(") == 0) {
- attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
- *iterable_out = true;
- }
- if (attr_type == "type") {
- if (*iterable_out) {
- return Type::Enum("DataType", "org.tensorflow");
- }
- return GetNextGenericTensorType(attr.allowed_values());
- }
- if (attr_type == "string") {
- return Type::Class("String");
- }
- if (attr_type == "int") {
- return Type::Class("Integer");
- }
- if (attr_type == "float") {
- return Type::Class("Float");
- }
- if (attr_type == "bool") {
- return Type::Class("Boolean");
- }
- if (attr_type == "shape") {
- return Type::Class("Shape", "org.tensorflow");
- }
- if (attr_type == "tensor") {
- return Type::Class("Tensor", "org.tensorflow")
- .add_parameter(Type::Wildcard());
- }
- LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
- return *iterable_out ? Type::Wildcard() : Type::Class("Object");
-}
-
-Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values) {
- Type generic = Type::Generic(string(1, next_generic_));
- next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
-
- // when only real numbers are allowed, enforce that restriction in the Java by
- // extending the generic from java.lang.Number
- if (IsRealNumbers(allowed_values)) {
- generic.add_supertype(Type::Class("Number"));
- }
- return generic;
-}
-
-} // namespace java
-} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
deleted file mode 100644
index 42855127cc..0000000000
--- a/tensorflow/java/src/gen/cc/op_parser.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
-#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
-
-#include <map>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/framework/api_def.pb.h"
-#include "tensorflow/java/src/gen/cc/java_defs.h"
-
-namespace tensorflow {
-namespace java {
-
-// Specification of a TensorFlow operation to generate.
-//
-// This is the result of an operation definition parsing, see OpParser::Parse().
-class OpSpec {
- public:
- class Endpoint {
- public:
- Endpoint(const Type& type, const Javadoc& javadoc)
- : type_(type), javadoc_(javadoc) {}
- const Type& type() const { return type_; }
- const Javadoc& javadoc() const { return javadoc_; }
-
- private:
- Type type_;
- Javadoc javadoc_;
- };
-
- class Operand {
- public:
- Operand(const string& graph_name, const Variable& var,
- const Type& data_type, const string& description, bool iterable)
- : graph_name_(graph_name), var_(var), data_type_(data_type),
- description_(description), iterable_(iterable) {}
- const string& graph_name() const { return graph_name_; }
- const Variable& var() const { return var_; }
- Variable* var_ptr() { return &var_; }
- const Type& data_type() const { return data_type_; }
- const string& description() const { return description_; }
- bool iterable() const { return iterable_; }
-
- private:
- string graph_name_;
- Variable var_;
- Type data_type_;
- string description_;
- bool iterable_;
- };
-
- explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
- const string& graph_name() const { return graph_name_; }
- const std::vector<Endpoint> endpoints() const { return endpoints_; }
- void add_endpoint(const Type& type, const Javadoc& javadoc) {
- endpoints_.push_back(Endpoint(type, javadoc));
- }
- const std::vector<Operand>& inputs() const { return inputs_; }
- void add_input(const Operand& input) {
- inputs_.push_back(input);
- }
- const std::vector<Operand>& outputs() const { return outputs_; }
- void add_output(const Operand& output) {
- outputs_.push_back(output);
- }
- const std::vector<Operand>& attributes() const { return attributes_; }
- void add_attribute(const Operand& attribute) {
- attributes_.push_back(attribute);
- }
- const std::vector<Operand>& options() const { return options_; }
- void add_option(const Operand& option) {
- options_.push_back(option);
- }
-
- private:
- string graph_name_;
- std::vector<Endpoint> endpoints_;
- std::vector<Operand> inputs_;
- std::vector<Operand> outputs_;
- std::vector<Operand> attributes_;
- std::vector<Operand> options_;
-};
-
-// A parser of ops proto definitions.
-//
-// This object parses the definition and the api of an TensorFlow operation to
-// produce a specification that can be used for Java source code rendering.
-class OpParser {
- public:
- OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
- const string& base_package);
- virtual ~OpParser() = default;
-
- // Produces an operation specification from its proto definitions.
- void Parse(std::unique_ptr<OpSpec>* op_ptr);
-
- private:
- OpDef op_def_;
- ApiDef op_api_;
- string lib_name_;
- string base_package_;
- std::map<string, Type> visited_attrs_;
- char next_generic_ = 0;
-
- void BuildEndpoints(OpSpec* op);
- void ParseInput(const OpDef_ArgDef& input_def,
- const ApiDef::Arg& input_api, OpSpec* op);
- void ParseOutput(const OpDef_ArgDef& output_def,
- const ApiDef::Arg& output_api, OpSpec* op);
- void ParseAttribute(const OpDef_AttrDef& attr_def,
- const ApiDef::Attr& attr_api, OpSpec* op);
- Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
- Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
- Type GetNextGenericTensorType(const AttrValue& allowed_values);
-};
-
-} // namespace java
-} // namespace tensorflow
-
-#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
new file mode 100644
index 0000000000..a727f7ae90
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -0,0 +1,390 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+
+namespace tensorflow {
+namespace java {
+namespace {
+
+inline bool IsRealNumbers(const AttrValue& values) {
+ if (!values.has_list()) {
+ return RealNumberTypes().Contains(values.type());
+ }
+ for (int i = 0; i < values.list().type_size(); ++i) {
+ if (!RealNumberTypes().Contains(values.list().type(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+class TypeResolver {
+ public:
+ explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
+
+ Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
+ Type TypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
+ bool IsAttributeVisited(const string& attr_name) {
+ return visited_attrs_.find(attr_name) != visited_attrs_.cend();
+ }
+ private:
+ const OpDef op_def_;
+ std::map<std::string, Type> visited_attrs_;
+ char next_generic_ = 'T';
+};
+
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
+ bool* iterable_out) {
+ *iterable_out = false;
+ if (!arg_def.number_attr().empty()) {
+ // when number_attr is set, argument has to be a list of tensors
+ *iterable_out = true;
+ visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
+ }
+ Type type = Type::Wildcard();
+ if (arg_def.type() != DataType::DT_INVALID) {
+ // resolve type from DataType
+ switch (arg_def.type()) {
+ case DataType::DT_BOOL:
+ type = Type::Class("Boolean");
+ break;
+ case DataType::DT_STRING:
+ type = Type::Class("String");
+ break;
+ case DataType::DT_FLOAT:
+ type = Type::Class("Float");
+ break;
+ case DataType::DT_DOUBLE:
+ type = Type::Class("Double");
+ break;
+ case DataType::DT_UINT8:
+ type = Type::Class("UInt8", "org.tensorflow.types");
+ break;
+ case DataType::DT_INT32:
+ type = Type::Class("Integer");
+ break;
+ case DataType::DT_INT64:
+ type = Type::Class("Long");
+ break;
+ case DataType::DT_RESOURCE:
+ // TODO(karllessard) create a Resource utility class that could be
+ // used to store a resource and its type (passed in a second argument).
+ // For now, we need to force a wildcard and we will unfortunately lose
+ // track of the resource type.
+ break;
+ default:
+ // Any other datatypes does not have a equivalent in Java and must
+ // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+ break;
+ }
+ } else if (!arg_def.type_attr().empty()) {
+ // resolve type from attribute (if already visited, retrieve its type)
+ if (IsAttributeVisited(arg_def.type_attr())) {
+ type = visited_attrs_.at(arg_def.type_attr());
+ } else {
+ for (const auto& attr_def : op_def_.attr()) {
+ if (attr_def.name() == arg_def.type_attr()) {
+ type = TypeOf(attr_def, iterable_out);
+ break;
+ }
+ }
+ }
+ } else if (!arg_def.type_list_attr().empty()) {
+ // type is a list of tensors that can be of different data types, so leave
+ // it as a list of wildcards
+ *iterable_out = true;
+ visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
+
+ } else {
+ LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
+ << "\" in operation \"" << op_def_.name() << "\"";
+ }
+ return type;
+}
+
+Type TypeResolver::TypeOf(const OpDef_AttrDef& attr_def,
+ bool* iterable_out) {
+ *iterable_out = false;
+ StringPiece attr_type = attr_def.type();
+ if (str_util::ConsumePrefix(&attr_type, "list(")) {
+ attr_type.remove_suffix(1); // remove closing brace
+ *iterable_out = true;
+ }
+ Type type = *iterable_out ? Type::Wildcard() : Type::Class("Object");
+ if (attr_type == "type") {
+ if (*iterable_out) {
+ type = Type::Enum("DataType", "org.tensorflow");
+ } else {
+ type = Type::Generic(string(1, next_generic_));
+ next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
+ if (IsRealNumbers(attr_def.allowed_values())) {
+ // enforce real numbers datasets by extending java.lang.Number
+ type.add_supertype(Type::Class("Number"));
+ }
+ }
+ } else if (attr_type == "string") {
+ type = Type::Class("String");
+
+ } else if (attr_type == "int") {
+ type = Type::Class("Integer");
+
+ } else if (attr_type == "float") {
+ type = Type::Class("Float");
+
+ } else if (attr_type == "bool") {
+ type = Type::Class("Boolean");
+
+ } else if (attr_type == "shape") {
+ type = Type::Class("Shape", "org.tensorflow");
+
+ } else if (attr_type == "tensor") {
+ type = Type::Class("Tensor", "org.tensorflow")
+ .add_parameter(Type::Wildcard());
+
+ } else {
+ LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
+ << "\" in operation \"" << op_def_.name() << "\"";
+ }
+ visited_attrs_.insert(std::make_pair(attr_def.name(), type));
+ return type;
+}
+
+string SnakeToCamelCase(const string& str, bool upper = false) {
+ string result;
+ bool cap = upper;
+ for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+ const char c = *it;
+ if (c == '_') {
+ cap = true;
+ } else if (cap) {
+ result += toupper(c);
+ cap = false;
+ } else {
+ result += c;
+ }
+ }
+ return result;
+}
+
+bool FindAndCut(re2::StringPiece* input, const RE2& expr,
+ re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
+ re2::StringPiece match;
+ bool matches =
+ expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1);
+ if (matches) {
+ before_match->set(input->data(), match.begin() - input->begin());
+ input->remove_prefix(match.end() - before_match->begin());
+ if (ret_match != nullptr) {
+ *ret_match = match;
+ }
+ } else {
+ *before_match = *input;
+ if (ret_match != nullptr) {
+ ret_match->set(nullptr, 0);
+ }
+ }
+ return matches;
+}
+
+string ParseDocumentation(const string& mdtext) {
+ std::stringstream javadoc_text;
+ re2::StringPiece input(mdtext);
+ re2::StringPiece text;
+ bool in_list = false;
+ do {
+ re2::StringPiece markup;
+ FindAndCut(&input,
+ "\n+\\*[[:blank:]]+|\n{2,}|`{3,}|`{1,2}|\\*{1,2}\\b|\\[",
+ &text, &markup);
+ javadoc_text << text;
+ if (markup.empty()) {
+ break; // we are done parsing
+ }
+ if (markup.starts_with("\n")) {
+ javadoc_text << "\n";
+ if (markup.contains("* ")) {
+ javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
+ in_list = true;
+ } else if (markup.starts_with("\n\n")) {
+ if (in_list) {
+ javadoc_text << "</li>\n</ul>\n";
+ in_list = false;
+ } else if (!input.starts_with("```")) {
+ javadoc_text << "<p>\n";
+ }
+ }
+ } else if (markup.starts_with("```") && text.empty()) {
+ re2::StringPiece language;
+ RE2::Consume(&input, "[\\w\\+]+", &language);
+ if (FindAndCut(&input, markup.ToString() + "\n*", &text)) {
+ javadoc_text << "<pre>\n{@code" << text << "}\n</pre>\n";
+ } else {
+ javadoc_text << markup << language;
+ }
+ } else if (markup.starts_with("`")) {
+ if (FindAndCut(&input, markup, &text)) {
+ javadoc_text << "{@code " << text << "}";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "**") {
+ if (FindAndCut(&input, "\\b\\*{2}", &text)) {
+ javadoc_text << "<b>" << text << "</b>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "*") {
+ if (FindAndCut(&input, "\\b\\*{1}", &text)) {
+ javadoc_text << "<i>" << text << "</i>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "[") {
+ string label;
+ string link;
+ if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) {
+ javadoc_text << "<a href=\"" << link << "\">" << label << "</a>";
+ } else {
+ javadoc_text << markup;
+ }
+ }
+ } while (!input.empty());
+
+ return javadoc_text.str();
+}
+
+ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
+ const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(input_def, &iterable);
+ Type var_type = Type::Interface("Operand", "org.tensorflow")
+ .add_parameter(type);
+ if (iterable) {
+ var_type = Type::IterableOf(var_type);
+ }
+ return ArgumentSpec(input_api_def.name(),
+ Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
+ type,
+ ParseDocumentation(input_api_def.description()),
+ iterable);
+}
+
+AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
+ const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(attr_def, &iterable);
+ // type attributes must be passed explicitly in methods as a Class<> parameter
+ bool is_explicit = type.kind() == Type::GENERIC && !iterable;
+ Type var_type = is_explicit ? Type::Class("Class").add_parameter(type) : type;
+ if (iterable) {
+ var_type = Type::ListOf(type);
+ }
+ return AttributeSpec(attr_api_def.name(),
+ Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
+ type,
+ ParseDocumentation(attr_api_def.description()),
+ iterable,
+ attr_api_def.has_default_value() && !is_explicit);
+}
+
+ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
+ const ApiDef::Arg& output_api, TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(output_def, &iterable);
+ Type var_type = Type::Class("Output", "org.tensorflow")
+ .add_parameter(type);
+ if (iterable) {
+ var_type = Type::ListOf(var_type);
+ }
+ return ArgumentSpec(output_api.name(),
+ Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
+ type,
+ ParseDocumentation(output_api.description()),
+ iterable);
+}
+
+EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
+ const ApiDef_Endpoint& endpoint_def) {
+
+ std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
+ string package;
+ string name;
+ if (name_tokens.size() > 1) {
+ package = name_tokens.at(0);
+ name = name_tokens.at(1);
+ } else {
+ package = "core"; // generate unclassified ops in the 'core' package
+ name = name_tokens.at(0);
+ }
+ return EndpointSpec(package,
+ name,
+ Javadoc::Create(ParseDocumentation(api_def.summary()))
+ .details(ParseDocumentation(api_def.description())),
+ endpoint_def.deprecation_version() > 0);
+}
+
+} // namespace
+
+OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
+ OpSpec op(api_def.graph_op_name(),
+ api_def.visibility() == ApiDef::HIDDEN,
+ op_def.deprecation().explanation());
+ TypeResolver type_resolver(op_def);
+ for (const string& next_input_name : api_def.arg_order()) {
+ for (int i = 0; i < op_def.input_arg().size(); ++i) {
+ if (op_def.input_arg(i).name() == next_input_name) {
+ op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
+ &type_resolver));
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < op_def.attr().size(); ++i) {
+ // do not parse attributes already visited, they have probably been inferred
+ // before as an input argument type
+ if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
+ AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
+ &type_resolver);
+ // attributes with a default value are optional
+ if (attr.optional()) {
+ op.optional_attributes_.push_back(attr);
+ } else {
+ op.attributes_.push_back(attr);
+ }
+ }
+ }
+ for (int i = 0; i < op_def.output_arg().size(); ++i) {
+ op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i),
+ &type_resolver));
+ }
+ for (const auto& endpoint_def : api_def.endpoint()) {
+ op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
+ }
+ return op;
+}
+
+} // namespace java
+} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
new file mode 100644
index 0000000000..55c2c3f307
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -0,0 +1,152 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+
+namespace tensorflow {
+namespace java {
+
+class EndpointSpec {
+ public:
+ // A specification for an operation endpoint
+ //
+ // package: package of this endpoint (from which also derives its package)
+ // name: name of this endpoint class
+ // javadoc: the endpoint class documentation
+ // deprecated: true if this endpoint is now deprecated
+ EndpointSpec(const string& package, const string& name,
+ const Javadoc& javadoc, bool deprecated)
+ : package_(package), name_(name), javadoc_(javadoc),
+ deprecated_(deprecated) {}
+
+ const string& package() const { return package_; }
+ const string& name() const { return name_; }
+ const Javadoc& javadoc() const { return javadoc_; }
+ bool deprecated() const { return deprecated_; }
+
+ private:
+ const string package_;
+ const string name_;
+ const Javadoc javadoc_;
+ const bool deprecated_;
+};
+
+class ArgumentSpec {
+ public:
+ // A specification for an operation argument
+ //
+ // op_def_name: argument name, as known by TensorFlow core
+ // var: a variable to represent this argument in Java
+ // type: the tensor type of this argument
+ // description: a description of this argument, in javadoc
+ // iterable: true if this argument is a list
+ ArgumentSpec(const string& op_def_name, const Variable& var,
+ const Type& type, const string& description, bool iterable)
+ : op_def_name_(op_def_name), var_(var), type_(type),
+ description_(description), iterable_(iterable) {}
+ virtual ~ArgumentSpec() = default;
+
+ const string& op_def_name() const { return op_def_name_; }
+ const Variable& var() const { return var_; }
+ const Type& type() const { return type_; }
+ const string& description() const { return description_; }
+ bool iterable() const { return iterable_; }
+
+ private:
+ const string op_def_name_;
+ const Variable var_;
+ const Type type_;
+ const string description_;
+ const bool iterable_;
+};
+
+class AttributeSpec : public ArgumentSpec {
+ public:
+ // A specification for an operation attribute
+ //
+ // op_def_name: attribute name, as known by TensorFlow core
+ // var: a variable to represent this attribute in Java
+ // type: the type of this attribute
+ // description: a description of this attribute, in javadoc
+ // iterable: true if this attribute is a list
+ // optional: true if this attribute does not require to be set explicitly
+ AttributeSpec(const string& op_def_name, const Variable& var,
+ const Type& type, const string& description, bool iterable,
+ bool optional)
+ : ArgumentSpec(op_def_name, var, type, description, iterable),
+ optional_(optional) {}
+ virtual ~AttributeSpec() = default;
+
+ bool optional() const { return optional_; }
+
+ private:
+ const bool optional_;
+};
+
+class OpSpec {
+ public:
+ // Parses an op definition and its API to produce a specification used for
+ // rendering its Java wrapper
+ //
+ // op_def: Op definition
+ // api_def: Op API definition
+ static OpSpec Create(const OpDef& op_def, const ApiDef& api_def);
+
+ const string& graph_op_name() const { return graph_op_name_; }
+ bool hidden() const { return hidden_; }
+ const string& deprecation_explanation() const {
+ return deprecation_explanation_;
+ }
+ const std::vector<EndpointSpec> endpoints() const { return endpoints_; }
+ const std::vector<ArgumentSpec>& inputs() const { return inputs_; }
+ const std::vector<ArgumentSpec>& outputs() const { return outputs_; }
+ const std::vector<AttributeSpec>& attributes() const { return attributes_; }
+ const std::vector<AttributeSpec>& optional_attributes() const {
+ return optional_attributes_;
+ }
+
+ private:
+ // A specification for an operation
+ //
+ // graph_op_name: name of this op, as known by TensorFlow core engine
+ // hidden: true if this op should not be visible through the Graph Ops API
+ // deprecation_explanation: message to show if all endpoints are deprecated
+ explicit OpSpec(const string& graph_op_name, bool hidden,
+ const string& deprecation_explanation)
+ : graph_op_name_(graph_op_name), hidden_(hidden),
+ deprecation_explanation_(deprecation_explanation) {}
+
+ const string graph_op_name_;
+ const bool hidden_;
+ const string deprecation_explanation_;
+ std::vector<EndpointSpec> endpoints_;
+ std::vector<ArgumentSpec> inputs_;
+ std::vector<ArgumentSpec> outputs_;
+ std::vector<AttributeSpec> attributes_;
+ std::vector<AttributeSpec> optional_attributes_;
+};
+
+} // namespace java
+} // namespace tensorflow
+
+#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index b1de5af6ba..7e427787f9 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h
index 1f0febe9a3..bcae33ccce 100644
--- a/tensorflow/java/src/gen/cc/source_writer.h
+++ b/tensorflow/java/src/gen/cc/source_writer.h
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc
index 8bd42d9d0e..875ad99ae2 100644
--- a/tensorflow/java/src/gen/cc/source_writer_test.cc
+++ b/tensorflow/java/src/gen/cc/source_writer_test.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -313,8 +313,7 @@ TEST(WriteType, SimpleClassWithDependencies) {
TEST(WriteType, AnnotatedAndDocumentedClass) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
- Javadoc clazz_doc;
- clazz_doc.brief("Javadoc test")
+ Javadoc clazz_doc = Javadoc::Create("Javadoc test")
.details("This is a\nmultiline description.");
clazz.add_annotation(Annotation::Create("Bean"));
clazz.add_annotation(Annotation::Create("SuppressWarnings")
@@ -329,7 +328,7 @@ TEST(WriteType, AnnotatedAndDocumentedClass) {
" * <p>\n"
" * This is a\n"
" * multiline description.\n"
- " **/\n"
+ " */\n"
"@Bean\n"
"@SuppressWarnings(\"rawtypes\")\n"
"public class Test {\n}\n";
@@ -378,8 +377,7 @@ TEST(WriteType, ParameterizedClassFields) {
Variable field1 = Variable::Create("field1", Type::Class("String"));
Variable field2 = Variable::Create("field2", Type::Class("String"));
Variable field3 = Variable::Create("field3", type_t);
- Javadoc field3_doc;
- field3_doc.brief("This variable is documented");
+ Javadoc field3_doc = Javadoc::Create("This variable is documented");
writer.BeginType(clazz, PUBLIC)
.WriteField(field1, STATIC | PUBLIC | FINAL)
@@ -464,8 +462,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
Method method = Method::Create("doNothing", Type::Void());
- Javadoc method_doc;
- method_doc.brief("Javadoc test")
+ Javadoc method_doc = Javadoc::Create("Javadoc test")
.details("This method has a\nmultiline description.");
method.add_annotation(Annotation::Create("Override"));
method.add_annotation(Annotation::Create("SuppressWarnings")
@@ -484,7 +481,7 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
" * <p>\n"
" * This method has a\n"
" * multiline description.\n"
- " **/\n"
+ " */\n"
" @Override\n"
" @SuppressWarnings(\"rawtypes\")\n"
" public void doNothing() {\n"
@@ -500,8 +497,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
Method method = Method::Create("boolToInt", Type::Int());
method.add_argument(Variable::Create("b", Type::Boolean()));
method.add_argument(reverse);
- Javadoc method_doc;
- method_doc.brief("Converts a boolean to an int")
+ Javadoc method_doc = Javadoc::Create("Converts a boolean to an int")
.details("This method will convert\na boolean to an int")
.add_param_tag(reverse.name(), "if true, value is reversed")
.add_tag("return", "int value for this boolean");
@@ -528,7 +524,7 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
" * \n"
" * @param reverse if true, value is reversed\n"
" * @return int value for this boolean\n"
- " **/\n"
+ " */\n"
" public int boolToInt(boolean b, boolean reverse) {\n"
" if (b && !reverse) {\n"
" return 1;\n"
diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl
index 1e7899cf7a..7017b52649 100644
--- a/tensorflow/java/src/gen/gen_ops.bzl
+++ b/tensorflow/java/src/gen/gen_ops.bzl
@@ -32,50 +32,52 @@ def tf_java_op_gen_srcjar(name,
api_def_srcs=[],
visibility=["//tensorflow/java:__pkg__"]):
- gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
srcs = api_def_srcs[:]
- # Construct an op generator binary for each ops library.
- for ops_lib in ops_libs:
- gen_lib = ops_lib[:ops_lib.rfind("_")]
- out_gen_tool = out_dir + ops_lib + "_gen_tool"
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ "$$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))")
+ api_def_args_str = ",".join(api_def_args)
- if not api_def_srcs:
- api_def_args_str = ","
- else:
- api_def_args = []
- for api_def_src in api_def_srcs:
- # Add directory of the first ApiDef source to args.
- # We are assuming all ApiDefs in a single api_def_src are in the
- # same directory.
- api_def_args.append(
- " $$(dirname $$(echo $(locations " + api_def_src +
- ") | cut -d\" \" -f1))")
- api_def_args_str = ",".join(api_def_args)
+ gen_tool_deps = [":java_op_gen_lib"]
+ for ops_lib in ops_libs:
+ gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib")
- tf_cc_binary(
- name=out_gen_tool,
- copts=tf_copts(),
- linkopts=["-lm"],
- linkstatic=1, # Faster to link this one-time-use binary dynamically
- deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
+ tf_cc_binary(
+ name=gen_tool,
+ srcs=[
+ "src/gen/cc/op_gen_main.cc",
+ ],
+ copts=tf_copts(),
+ linkopts=["-lm"],
+ linkstatic=1, # Faster to link this one-time-use binary dynamically
+ deps = gen_tool_deps)
- gen_tools += [":" + out_gen_tool]
- gen_cmds += ["$(location :" + out_gen_tool + ")" +
- " --output_dir=$(@D)/" + out_src_dir +
- " --lib_name=" + gen_lib +
- " --base_package=" + gen_base_package +
- " " + api_def_args_str]
+ gen_cmds += ["$(location :" + gen_tool + ")" +
+ " --output_dir=$(@D)/" + out_src_dir +
+ " --base_package=" + gen_base_package +
+ " --api_dirs=" + api_def_args_str]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
- gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
- gen_tools += tf_binary_additional_srcs()
+
native.genrule(
name=name,
srcs=srcs,
outs=[gen_srcjar],
- tools=gen_tools,
- cmd="&&".join(gen_cmds))
+ tools=[
+ "@local_jdk//:jar",
+ "@local_jdk//:jdk",
+ gen_tool
+ ] + tf_binary_additional_srcs(),
+ cmd=" && ".join(gen_cmds))
diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.java.snippet
index 90285ec669..90285ec669 100644
--- a/tensorflow/java/src/gen/resources/license.snippet.java
+++ b/tensorflow/java/src/gen/resources/license.java.snippet