aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-05-29 00:50:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 00:53:00 -0700
commit8ecf1ebc5d83e66b29a07113b53c49ef8264703c (patch)
tree5b6c00f7ae5f1beeae52482bb5411057d4735fff /tensorflow/java
parent23c70a240219da7201806e4ebc1354f568581557 (diff)
Re-apply #18192.
PiperOrigin-RevId: 198358055
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD41
-rw-r--r--tensorflow/java/build_defs.bzl1
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h80
-rw-r--r--tensorflow/java/src/gen/cc/op_gen_main.cc46
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc464
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.h37
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc423
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h173
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc142
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.h55
-rw-r--r--tensorflow/java/src/gen/cc/source_writer_test.cc155
-rw-r--r--tensorflow/java/src/gen/gen_ops.bzl74
12 files changed, 1363 insertions, 328 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 78596d147a..19d2133a55 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -60,9 +60,7 @@ java_library(
filegroup(
name = "java_op_sources",
- srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [
- ":java_op_gen_sources",
- ],
+ srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]) + [":java_op_gen_sources"],
visibility = [
"//tensorflow/java:__pkg__",
],
@@ -70,43 +68,27 @@ filegroup(
tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
- gen_base_package = "org.tensorflow.op",
- gen_tool = "java_op_gen_tool",
- ops_libs = [
- "array_ops",
- "candidate_sampling_ops",
- "control_flow_ops",
- "data_flow_ops",
- "image_ops",
- "io_ops",
- "linalg_ops",
- "logging_ops",
- "math_ops",
- "nn_ops",
- "no_op",
- "parsing_ops",
- "random_ops",
- "sparse_ops",
- "state_ops",
- "string_ops",
- "training_ops",
- "user_ops",
+ api_def_srcs = [
+ "//tensorflow/core/api_def:base_api_def",
],
+ base_package = "org.tensorflow.op",
+ gen_tool = ":java_op_gen_tool",
)
-# Build the gen tool as a library, as it will be linked to a core/ops binary
-# file before making it an executable. See tf_java_op_gen_srcjar().
-cc_library(
+tf_cc_binary(
name = "java_op_gen_tool",
srcs = [
"src/gen/cc/op_gen_main.cc",
],
copts = tf_copts(),
+ linkopts = ["-lm"],
+ linkstatic = 1,
deps = [
":java_op_gen_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
+ "//tensorflow/core:ops",
],
)
@@ -114,11 +96,13 @@ cc_library(
name = "java_op_gen_lib",
srcs = [
"src/gen/cc/op_generator.cc",
+ "src/gen/cc/op_specs.cc",
"src/gen/cc/source_writer.cc",
],
hdrs = [
"src/gen/cc/java_defs.h",
"src/gen/cc/op_generator.h",
+ "src/gen/cc/op_specs.h",
"src/gen/cc/source_writer.h",
],
copts = tf_copts(),
@@ -127,6 +111,9 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
+ "//tensorflow/core:protos_all_cc",
+ "@com_googlesource_code_re2//:re2",
],
)
diff --git a/tensorflow/java/build_defs.bzl b/tensorflow/java/build_defs.bzl
index ab7f60d03d..e1916ca4d9 100644
--- a/tensorflow/java/build_defs.bzl
+++ b/tensorflow/java/build_defs.bzl
@@ -15,6 +15,7 @@ JAVA_VERSION_OPTS = [
XLINT_OPTS = [
"-Werror",
"-Xlint:all",
+ "-Xlint:-processing",
"-Xlint:-serial",
"-Xlint:-try",
"-Xlint:-classfile", # see b/32750402, go/javac-warnings#classfile
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index 59f8beaee7..d9d6f8adc8 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -16,19 +16,22 @@ limitations under the License.
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
-#include <string>
#include <list>
+#include <map>
+#include <string>
+#include <utility>
namespace tensorflow {
namespace java {
// An enumeration of different modifiers commonly used in Java
enum Modifier {
- PUBLIC = (1 << 0),
+ PACKAGE = 0,
+ PUBLIC = (1 << 0),
PROTECTED = (1 << 1),
- PRIVATE = (1 << 2),
- STATIC = (1 << 3),
- FINAL = (1 << 4),
+ PRIVATE = (1 << 2),
+ STATIC = (1 << 3),
+ FINAL = (1 << 4),
};
class Annotation;
@@ -72,6 +75,8 @@ class Type {
// Reflection API does
return Type(Type::PRIMITIVE, "void");
}
+ static Type Generic(const string& name) { return Type(Type::GENERIC, name); }
+ static Type Wildcard() { return Type(Type::GENERIC, ""); }
static Type Class(const string& name, const string& package = "") {
return Type(Type::CLASS, name, package);
}
@@ -81,9 +86,6 @@ class Type {
static Type Enum(const string& name, const string& package = "") {
return Type(Type::ENUM, name, package);
}
- static Type Generic(const string& name = "") {
- return Type(Type::GENERIC, name);
- }
static Type ClassOf(const Type& type) {
return Class("Class").add_parameter(type);
}
@@ -96,11 +98,10 @@ class Type {
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
- const string& description() const { return description_; }
- Type& description(const string& description) {
- description_ = description;
- return *this;
+ const string canonical_name() const {
+ return package_.empty() ? name_ : package_ + "." + name_;
}
+ bool wildcard() const { return name_.empty(); } // only wildcards has no name
const std::list<Type>& parameters() const { return parameters_; }
Type& add_parameter(const Type& parameter) {
parameters_.push_back(parameter);
@@ -120,14 +121,6 @@ class Type {
}
return *this;
}
- // Returns true if "type" is of a known collection type (only a few for now)
- bool IsCollection() const {
- return name_ == "List" || name_ == "Iterable";
- }
- // Returns true if this instance is a wildcard (<?>)
- bool IsWildcard() const {
- return kind_ == GENERIC && name_.empty();
- }
protected:
Type(Kind kind, const string& name, const string& package = "")
@@ -137,7 +130,6 @@ class Type {
Kind kind_;
string name_;
string package_;
- string description_;
std::list<Type> parameters_;
std::list<Annotation> annotations_;
std::list<Type> supertypes_;
@@ -180,16 +172,11 @@ class Variable {
const string& name() const { return name_; }
const Type& type() const { return type_; }
bool variadic() const { return variadic_; }
- const string& description() const { return description_; }
- Variable& description(const string& description) {
- description_ = description;
- return *this;
- }
+
private:
string name_;
Type type_;
bool variadic_;
- string description_;
Variable(const string& name, const Type& type, bool variadic)
: name_(name), type_(type), variadic_(variadic) {}
@@ -210,16 +197,6 @@ class Method {
bool constructor() const { return constructor_; }
const string& name() const { return name_; }
const Type& return_type() const { return return_type_; }
- const string& description() const { return description_; }
- Method& description(const string& description) {
- description_ = description;
- return *this;
- }
- const string& return_description() const { return return_description_; }
- Method& return_description(const string& description) {
- return_description_ = description;
- return *this;
- }
const std::list<Variable>& arguments() const { return arguments_; }
Method& add_argument(const Variable& var) {
arguments_.push_back(var);
@@ -235,8 +212,6 @@ class Method {
string name_;
Type return_type_;
bool constructor_;
- string description_;
- string return_description_;
std::list<Variable> arguments_;
std::list<Annotation> annotations_;
@@ -244,6 +219,33 @@ class Method {
: name_(name), return_type_(return_type), constructor_(constructor) {}
};
+// A definition of a documentation bloc for a Java element (JavaDoc)
+class Javadoc {
+ public:
+ static Javadoc Create(const string& brief = "") { return Javadoc(brief); }
+ const string& brief() const { return brief_; }
+ const string& details() const { return details_; }
+ Javadoc& details(const string& details) {
+ details_ = details;
+ return *this;
+ }
+ const std::list<std::pair<string, string>>& tags() const { return tags_; }
+ Javadoc& add_tag(const string& tag, const string& text) {
+ tags_.push_back(std::make_pair(tag, text));
+ return *this;
+ }
+ Javadoc& add_param_tag(const string& name, const string& text) {
+ return add_tag("param", name + " " + text);
+ }
+
+ private:
+ string brief_;
+ string details_;
+ std::list<std::pair<string, string>> tags_;
+
+ explicit Javadoc(const string& brief) : brief_(brief) {}
+};
+
} // namespace java
} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc
index bea99f3d7f..0d9e0883af 100644
--- a/tensorflow/java/src/gen/cc/op_gen_main.cc
+++ b/tensorflow/java/src/gen/cc/op_gen_main.cc
@@ -36,49 +36,43 @@ const char kUsageHeader[] =
"Operation wrappers are generated under the path specified by the "
"'--output_dir' argument. This path can be absolute or relative to the\n"
"current working directory and will be created if it does not exists.\n\n"
- "The '--lib_name' argument is used to classify the set of operations. If "
- "the chosen name contains more than one word, it must be provided in \n"
- "snake_case. This value is declined into other meaningful names, such as "
- "the group and package of the generated operations. For example,\n"
- "'--lib_name=my_lib' generates the operations under the "
- "'org.tensorflow.op.mylib' package and add them to the 'myLib()' operator\n"
- "group.\n\n"
- "Note that the operator group assigned to the generated wrappers is just "
- "an annotation tag at this stage. Operations will not be available "
- "through\n"
- "the 'org.tensorflow.op.Ops' API as a group until the generated classes "
- "are compiled using an appropriate annotation processor.\n\n"
- "Finally, the '--base_package' overrides the default parent package "
- "under which the generated subpackage and classes are to be located.\n\n";
+ "Note that the operations will not be available through the "
+ "'org.tensorflow.op.Ops' API until the generated classes are compiled\n"
+ "using an appropriate annotation processor.\n\n"
+ "The '--base_package' overrides the default parent package under which "
+ "the generated subpackage and classes are to be located.\n\n"
+ "Finally, the `--api_dirs` argument takes a list of comma-separated "
+ "directories of API definitions can be provided to override default\n"
+ "values found in the ops definitions. Directories are ordered by priority "
+ "(the last having precedence over the first).\n\n";
} // namespace java
} // namespace tensorflow
int main(int argc, char* argv[]) {
- tensorflow::string lib_name;
tensorflow::string output_dir;
tensorflow::string base_package = "org.tensorflow.op";
+ tensorflow::string api_dirs_str;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("output_dir", &output_dir,
"Root directory into which output files are generated"),
tensorflow::Flag(
- "lib_name", &lib_name,
- "A name, in snake_case, used to classify this set of operations"),
- tensorflow::Flag(
"base_package", &base_package,
- "Package parent to the generated subpackage and classes")};
+ "Package parent to the generated subpackage and classes"),
+ tensorflow::Flag(
+ "api_dirs", &api_dirs_str,
+ "List of directories that contains the ops api definitions")};
tensorflow::string usage = tensorflow::java::kUsageHeader;
usage += tensorflow::Flags::Usage(argv[0], flag_list);
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
- QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
-
- tensorflow::java::OpGenerator generator;
+ QCHECK(parsed_flags_ok && !output_dir.empty()) << usage;
+ std::vector<tensorflow::string> api_dirs = tensorflow::str_util::Split(
+ api_dirs_str, ",", tensorflow::str_util::SkipEmpty());
+ tensorflow::java::OpGenerator generator(api_dirs);
tensorflow::OpList ops;
- tensorflow::OpRegistry::Global()->Export(true, &ops);
- tensorflow::Status status =
- generator.Run(ops, lib_name, base_package, output_dir);
- TF_QCHECK_OK(status);
+ tensorflow::OpRegistry::Global()->Export(false, &ops);
+ TF_CHECK_OK(generator.Run(ops, base_package, output_dir));
return 0;
}
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index def06baf2d..debd95fc62 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -13,54 +13,466 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <list>
+#include <map>
+#include <memory>
+#include <set>
#include <string>
+#include <vector>
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
#include "tensorflow/java/src/gen/cc/op_generator.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+#include "tensorflow/java/src/gen/cc/source_writer.h"
namespace tensorflow {
namespace java {
namespace {
-string CamelCase(const string& str, char delimiter, bool upper) {
- string result;
- bool cap = upper;
- for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
- const char c = *it;
- if (c == delimiter) {
- cap = true;
- } else if (cap) {
- result += toupper(c);
- cap = false;
+const char* kLicense =
+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+ "\n"
+ "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ "you may not use this file except in compliance with the License.\n"
+ "You may obtain a copy of the License at\n"
+ "\n"
+ " http://www.apache.org/licenses/LICENSE-2.0\n"
+ "\n"
+ "Unless required by applicable law or agreed to in writing, software\n"
+ "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ "See the License for the specific language governing permissions and\n"
+ "limitations under the License.\n"
+ "=======================================================================*/"
+ "\n";
+
+// There is three different modes to render an op class, depending on the
+// number and type of outputs it has:
+//
+// DEFAULT: This mode does not provide any specialization for the op class, it
+// is applied when the operation does not comply with any other mode
+//
+// OPERAND: The op class implements the Operand<T> interface, allowing an
+// instance to be passed directly in input to another operation
+//
+// LIST_OPERAND: The op class implements the Iterable<Operand<T>> interface,
+// allowing an instance to be passed directly as a list input to
+// another operation
+//
+enum RenderMode { DEFAULT, OPERAND, LIST_OPERAND };
+
+void AddArgument(const Variable& var, const string& description,
+ Method* method_out, Javadoc* javadoc_out) {
+ method_out->add_argument(var);
+ javadoc_out->add_param_tag(var.name(), description);
+}
+
+void CollectOpDependencies(const OpSpec& op, RenderMode mode,
+ std::list<Type>* out) {
+ out->push_back(Type::Class("Operation", "org.tensorflow"));
+ out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
+ out->push_back(Type::Class("Scope", "org.tensorflow.op"));
+ if (mode == OPERAND) {
+ out->push_back(Type::Class("Output", "org.tensorflow"));
+ } else if (mode == LIST_OPERAND) {
+ out->push_back(Type::Interface("Iterator", "java.util"));
+ }
+ // Don't pay attention to duplicate types in the dependency list, they will
+ // be filtered out by the SourceWriter.
+ for (const ArgumentSpec& input : op.inputs()) {
+ out->push_back(input.var().type());
+ if (input.iterable()) {
+ out->push_back(Type::Class("Operands", "org.tensorflow.op"));
+ }
+ }
+ for (const ArgumentSpec& output : op.outputs()) {
+ out->push_back(output.var().type());
+ if (output.iterable()) {
+ out->push_back(Type::Class("Arrays", "java.util"));
+ }
+ }
+ for (const AttributeSpec& attribute : op.attributes()) {
+ out->push_back(attribute.var().type());
+ out->push_back(attribute.jni_type());
+ }
+ for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+ out->push_back(optional_attribute.var().type());
+ }
+}
+
+void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
+ SourceWriter* writer) {
+ string var_name = optional ? "opts." + attr.var().name() : attr.var().name();
+ if (attr.iterable()) {
+ string array_name = attr.var().name() + "Array";
+ writer->AppendType(attr.jni_type())
+ .Append("[] " + array_name + " = new ")
+ .AppendType(attr.jni_type())
+ .Append("[" + var_name + ".size()];")
+ .EndLine()
+ .BeginBlock("for (int i = 0; i < " + array_name + ".length; ++i)")
+ .Append(array_name + "[i] = ");
+ if (attr.type().kind() == Type::GENERIC) {
+ writer->Append("DataType.fromClass(" + var_name + ".get(i));");
+ } else {
+ writer->Append(var_name + ".get(i);");
+ }
+ writer->EndLine()
+ .EndBlock()
+ .Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ")
+ .Append(array_name + ");")
+ .EndLine();
+ } else {
+ writer->Append("opBuilder.setAttr(\"" + attr.op_def_name() + "\", ");
+ if (attr.var().type().name() == "Class") {
+ writer->Append("DataType.fromClass(" + var_name + "));");
} else {
- result += c;
+ writer->Append(var_name + ");");
}
+ writer->EndLine();
}
- return result;
}
-} // namespace
+void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
+ SourceWriter* writer) {
+ Method factory = Method::Create("create", op_class);
+ Javadoc factory_doc =
+ Javadoc::Create("Factory method to create a class to wrap a new " +
+ op_class.name() + " operation to the graph.");
+ Variable scope =
+ Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
+ AddArgument(scope, "current graph scope", &factory, &factory_doc);
+ for (const ArgumentSpec& input : op.inputs()) {
+ AddArgument(input.var(), input.description(), &factory, &factory_doc);
+ }
+ for (const AttributeSpec& attr : op.attributes()) {
+ AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
+ }
+ if (!op.optional_attributes().empty()) {
+ AddArgument(Variable::Varargs("options", Type::Class("Options")),
+ "carries optional attributes values", &factory, &factory_doc);
+ }
+ factory_doc.add_tag("return", "a new instance of " + op_class.name());
-OpGenerator::OpGenerator() : env(Env::Default()) {}
+ writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
+ writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\"" +
+ op.graph_op_name() + "\", scope.makeOpName(\"" +
+ op_class.name() + "\"));");
+ writer->EndLine();
+ for (const ArgumentSpec& input : op.inputs()) {
+ if (input.iterable()) {
+ writer->Append("opBuilder.addInputList(Operands.asOutputs(" +
+ input.var().name() + "));");
+ writer->EndLine();
+ } else {
+ writer->Append("opBuilder.addInput(" + input.var().name() +
+ ".asOutput());");
+ writer->EndLine();
+ }
+ }
+ for (const AttributeSpec& attribute : op.attributes()) {
+ WriteSetAttrDirective(attribute, false, writer);
+ }
+ if (!op.optional_attributes().empty()) {
+ writer->BeginBlock("if (options != null)")
+ .BeginBlock("for (Options opts : options)");
+ for (const AttributeSpec& attribute : op.optional_attributes()) {
+ writer->BeginBlock("if (opts." + attribute.var().name() + " != null)");
+ WriteSetAttrDirective(attribute, true, writer);
+ writer->EndBlock();
+ }
+ writer->EndBlock().EndBlock();
+ }
+ writer->Append("return new ")
+ .AppendType(op_class)
+ .Append("(opBuilder.build());")
+ .EndLine();
+ writer->EndMethod();
+}
-OpGenerator::~OpGenerator() {}
+void RenderConstructor(const OpSpec& op, const Type& op_class,
+ SourceWriter* writer) {
+ Variable operation =
+ Variable::Create("operation", Type::Class("Operation", "org.tensorflow"));
+ Method constructor = Method::ConstructorFor(op_class).add_argument(operation);
+ for (const ArgumentSpec& output : op.outputs()) {
+ if (output.iterable() && !output.type().wildcard()) {
+ constructor.add_annotation(
+ Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
+ break;
+ }
+ }
+ writer->BeginMethod(constructor, PRIVATE)
+ .Append("super(operation);")
+ .EndLine();
+ if (!op.outputs().empty()) {
+ writer->Append("int outputIdx = 0;").EndLine();
+ for (const ArgumentSpec& output : op.outputs()) {
+ if (output.iterable()) {
+ string var_length = output.var().name() + "Length";
+ writer->Append("int " + var_length)
+ .Append(" = operation.outputListLength(\"" + output.op_def_name() +
+ "\");")
+ .EndLine()
+ .Append(output.var().name() + " = Arrays.asList(");
+ if (!output.type().wildcard()) {
+ writer->Append("(")
+ .AppendType(output.var().type().parameters().front())
+ .Append("[])");
+ }
+ writer->Append("operation.outputList(outputIdx, " + var_length + "));")
+ .EndLine()
+ .Append("outputIdx += " + var_length + ";")
+ .EndLine();
+ } else {
+ writer
+ ->Append(output.var().name() + " = operation.output(outputIdx++);")
+ .EndLine();
+ }
+ }
+ }
+ writer->EndMethod();
+}
-Status OpGenerator::Run(const OpList& ops, const string& lib_name,
- const string& base_package, const string& output_dir) {
- const string package =
- base_package + '.' + str_util::StringReplace(lib_name, "_", "", true);
- const string package_path =
- output_dir + '/' + str_util::StringReplace(package, ".", "/", true);
- const string group = CamelCase(lib_name, '_', false);
+void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
+ for (const AttributeSpec& attr : op.optional_attributes()) {
+ Method setter = Method::Create(attr.var().name(), Type::Class("Options"));
+ Javadoc setter_doc = Javadoc::Create();
+ AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
+ writer->BeginMethod(setter, PUBLIC | STATIC, &setter_doc)
+ .Append("return new Options()." + attr.var().name() + "(" +
+ attr.var().name() + ");")
+ .EndLine()
+ .EndMethod();
+ }
+ for (const ArgumentSpec& output : op.outputs()) {
+ Method getter = Method::Create(output.var().name(), output.var().type());
+ Javadoc getter_doc = Javadoc::Create(output.description());
+ writer->BeginMethod(getter, PUBLIC, &getter_doc)
+ .Append("return " + output.var().name() + ";")
+ .EndLine()
+ .EndMethod();
+ }
+}
- if (!env->FileExists(package_path).ok()) {
- TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
+void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
+ SourceWriter* writer) {
+ ArgumentSpec output = op.outputs().front();
+
+ if (mode == OPERAND) {
+ bool cast2obj = output.type().wildcard();
+ Type return_type =
+ Type::Class("Output", "org.tensorflow")
+ .add_parameter(cast2obj ? Type::Class("Object") : output.type());
+ Method as_output = Method::Create("asOutput", return_type)
+ .add_annotation(Annotation::Create("Override"));
+ if (cast2obj) {
+ as_output.add_annotation(
+ Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
+ }
+ writer->BeginMethod(as_output, PUBLIC);
+ if (cast2obj) {
+ writer->Append("return (").AppendType(return_type).Append(") ");
+ } else {
+ writer->Append("return ");
+ }
+ writer->Append(output.var().name() + ";").EndLine().EndMethod();
+
+ } else if (mode == LIST_OPERAND) {
+ Type operand = Type::Interface("Operand", "org.tensorflow");
+ if (output.type().wildcard()) {
+ operand.add_parameter(Type::Class("Object"));
+ } else {
+ operand.add_parameter(output.type());
+ }
+ Type return_type =
+ Type::Interface("Iterator", "java.util").add_parameter(operand);
+ Method iterator =
+ Method::Create("iterator", return_type)
+ .add_annotation(Annotation::Create("Override"))
+ .add_annotation(Annotation::Create("SuppressWarnings")
+ .attributes("{\"rawtypes\", \"unchecked\"}"));
+ // cast the output list using a raw List
+ writer->BeginMethod(iterator, PUBLIC)
+ .Append("return (" + return_type.name() + ") ")
+ .Append(output.var().name() + ".iterator();")
+ .EndLine()
+ .EndMethod();
+ }
+}
+
+void RenderOptionsClass(const OpSpec& op, const Type& op_class,
+ SourceWriter* writer) {
+ Type options_class = Type::Class("Options");
+ Javadoc options_doc = Javadoc::Create("Optional attributes for {@link " +
+ op_class.canonical_name() + "}");
+ writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
+ for (const AttributeSpec& attr : op.optional_attributes()) {
+ Method setter = Method::Create(attr.var().name(), options_class);
+ Javadoc setter_doc = Javadoc::Create();
+ AddArgument(attr.var(), attr.description(), &setter, &setter_doc);
+ writer->BeginMethod(setter, PUBLIC, &setter_doc)
+ .Append("this." + attr.var().name() + " = " + attr.var().name() + ";")
+ .EndLine()
+ .Append("return this;")
+ .EndLine()
+ .EndMethod();
+ }
+ writer->EndLine();
+ for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
+ writer->WriteField(optional_attribute.var(), PRIVATE);
}
+ Method constructor = Method::ConstructorFor(options_class);
+ writer->BeginMethod(constructor, PRIVATE).EndMethod();
+ writer->EndType();
+}
+
+inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) {
+ return Type::Class(
+ endpoint.name(),
+ base_package + "." + str_util::Lowercase(endpoint.package()));
+}
- LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
- // TODO(karllessard) generate wrappers from list of ops
+void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
+ const string& base_package, const string& output_dir,
+ Env* env) {
+ Type op_class(
+ ClassOf(endpoint, base_package)
+ .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op")));
+ Javadoc op_javadoc(endpoint.javadoc());
+ // op interfaces
+ RenderMode mode = DEFAULT;
+ if (op.outputs().size() == 1) {
+ const ArgumentSpec& output = op.outputs().front();
+ Type operand_type(output.type().wildcard() ? Type::Class("Object")
+ : output.type());
+ Type operand_inf(Type::Interface("Operand", "org.tensorflow")
+ .add_parameter(operand_type));
+ if (output.iterable()) {
+ mode = LIST_OPERAND;
+ op_class.add_supertype(Type::IterableOf(operand_inf));
+ } else {
+ mode = OPERAND;
+ op_class.add_supertype(operand_inf);
+ }
+ }
+ // op generic parameters
+ std::set<string> generics;
+ for (const ArgumentSpec& output : op.outputs()) {
+ if (output.type().kind() == Type::GENERIC && !output.type().wildcard() &&
+ generics.find(output.type().name()) == generics.end()) {
+ op_class.add_parameter(output.type());
+ op_javadoc.add_param_tag(
+ "<" + output.type().name() + ">",
+ "data type for {@code " + output.var().name() + "()} output");
+ generics.insert(output.type().name());
+ }
+ }
+ // op annotations
+ op_class.add_annotation(
+ Annotation::Create("Generated", "javax.annotation")
+ .attributes("value = \"TensorFlow Java Op Generator\""));
+ if (endpoint.deprecated()) {
+ op_class.add_annotation(Annotation::Create("Deprecated"));
+ string explanation;
+ if (!op.endpoints().front().deprecated()) {
+ explanation =
+ "use {@link " +
+ ClassOf(op.endpoints().front(), base_package).canonical_name() +
+ "} instead";
+ } else {
+ explanation = op.deprecation_explanation();
+ }
+ op_javadoc.add_tag("deprecated", explanation);
+ }
+ if (!op.hidden()) {
+ // expose the op in the Ops Graph API only if it is visible
+ op_class.add_annotation(
+ Annotation::Create("Operator", "org.tensorflow.op.annotation")
+ .attributes("group = \"" + endpoint.package() + "\""));
+ }
+ // create op class file
+ const string op_dir_name = io::JoinPath(
+ output_dir, str_util::StringReplace(op_class.package(), ".", "/", true));
+ if (!env->FileExists(op_dir_name).ok()) {
+ TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name))
+ << op_dir_name;
+ }
+ const string op_file_name = op_class.name() + ".java";
+ std::unique_ptr<tensorflow::WritableFile> op_file;
+ TF_CHECK_OK(
+ env->NewWritableFile(io::JoinPath(op_dir_name, op_file_name), &op_file))
+ << op_file_name;
+
+ // render endpoint source code
+ SourceFileWriter writer(op_file.get());
+ std::list<Type> dependencies;
+ CollectOpDependencies(op, mode, &dependencies);
+ writer.Write(kLicense).EndLine().BeginType(op_class, PUBLIC | FINAL,
+ &dependencies, &op_javadoc);
+ if (!op.optional_attributes().empty()) {
+ RenderOptionsClass(op, op_class, &writer);
+ }
+ RenderFactoryMethods(op, op_class, &writer);
+ RenderGettersAndSetters(op, &writer);
+ if (mode != DEFAULT) {
+ RenderInterfaceImpl(op, mode, &writer);
+ }
+ writer.EndLine();
+ for (const ArgumentSpec& output : op.outputs()) {
+ writer.WriteField(output.var(), PRIVATE);
+ }
+ RenderConstructor(op, op_class, &writer);
+ writer.EndType();
+}
+
+bool CanGenerateOp(const OpDef& op_def, const ApiDef& api_def) {
+ if (api_def.visibility() == ApiDef::SKIP) {
+ return false;
+ }
+ for (const auto& attr : op_def.attr()) {
+ if (attr.type() == "func") {
+ return false; // TODO(karllessard) add support for function attributes
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+Status OpGenerator::Run(const OpList& op_list, const string& base_package,
+ const string& output_dir) {
+ ApiDefMap api_map(op_list);
+ if (!api_dirs_.empty()) {
+ // Only load api files that correspond to the requested "op_list"
+ for (const auto& op : op_list.op()) {
+ for (const auto& api_def_dir : api_dirs_) {
+ const std::string api_def_file_pattern =
+ io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
+ if (env_->FileExists(api_def_file_pattern).ok()) {
+ TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern))
+ << api_def_file_pattern;
+ }
+ }
+ }
+ }
+ api_map.UpdateDocs();
+ for (const auto& op_def : op_list.op()) {
+ const ApiDef* api_def = api_map.GetApiDef(op_def.name());
+ if (CanGenerateOp(op_def, *api_def)) {
+ OpSpec op(OpSpec::Create(op_def, *api_def));
+ for (const EndpointSpec& endpoint : op.endpoints()) {
+ GenerateOp(op, endpoint, base_package, output_dir, env_);
+ }
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index 4b55ed3ed9..05decd6b54 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -17,34 +17,39 @@ limitations under the License.
#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
#include <string>
+#include <vector>
-#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
namespace tensorflow {
namespace java {
-/// \brief A generator of Java operation wrappers.
-///
-/// Such generator is normally ran only once per executable, outputting
-/// wrappers for the all registered operations it has been compiled with.
-/// Nonetheless, it is designed to support multiple runs, giving a different
-/// list of operations on each cycle.
+// A generator of Java operation wrappers.
+//
+// This generator takes a list of ops definitions in input and outputs
+// a Java Op wrapper for each of them in the provided directory. The same
+// generator instance can be invoked multiple times with a different list of
+// ops definitions.
class OpGenerator {
public:
- OpGenerator();
- virtual ~OpGenerator();
+ explicit OpGenerator(const std::vector<string>& api_dirs,
+ Env* env = Env::Default())
+ : api_dirs_(api_dirs), env_(env) {}
- /// \brief Generates wrappers for the given list of 'ops'.
- ///
- /// Output files are generated in <output_dir>/<base_package>/<lib_package>,
- /// where 'lib_package' is derived from 'lib_name'.
- Status Run(const OpList& ops, const string& lib_name,
- const string& base_package, const string& output_dir);
+ // Generates wrappers for the given list of 'ops'.
+ //
+ // Output files are generated in <output_dir>/<base_package>/<op_package>,
+ // where 'op_package' is derived from ops endpoints.
+ Status Run(const OpList& op_list, const string& base_package,
+ const string& output_dir);
private:
- Env* env;
+ const std::vector<string> api_dirs_;
+ Env* env_;
};
} // namespace java
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
new file mode 100644
index 0000000000..181fd4c5e3
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -0,0 +1,423 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "re2/re2.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_specs.h"
+
+namespace tensorflow {
+namespace java {
+namespace {
+
+inline bool IsRealNumbers(const AttrValue& values) {
+ if (!values.has_list()) {
+ return RealNumberTypes().Contains(values.type());
+ }
+ for (int i = 0; i < values.list().type_size(); ++i) {
+ if (!RealNumberTypes().Contains(values.list().type(i))) {
+ return false;
+ }
+ }
+ return true;
+}
+
+class TypeResolver {
+ public:
+ explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {}
+
+ // Returns the class type of an input/output argument
+ //
+ // For example, if the argument's datatype is DT_STRING, this method will
+ // return "java.lang.String", so the argument can become "Operand<String>"
+ // in the Ops API
+ Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
+
+ // Returns types of an input attribute
+ //
+ // The first element of the pair is the class type of this attribute while
+ // the second is its JNI/primitive type equivalent, required for explicit
+ // unboxing.
+ //
+ // For example, if the attribute is of type "float", this method will return
+ // <java.lang.Float, float>, so the attribute can be used as a "Float" object
+ // in the Ops API and casted to a "float" when passing through the JNI layer.
+ std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
+ bool* iterable_out);
+
+ // Returns true if the type of this attribute has already been resolved
+ bool IsAttributeVisited(const string& attr_name) {
+ return visited_attrs_.find(attr_name) != visited_attrs_.cend();
+ }
+
+ private:
+ const OpDef op_def_;
+ std::map<std::string, Type> visited_attrs_;
+ char next_generic_letter_ = 'T';
+
+ std::pair<Type, Type> MakeTypePair(const Type& type, const Type& jni_type) {
+ return std::make_pair(type, jni_type);
+ }
+ std::pair<Type, Type> MakeTypePair(const Type& type) {
+ return std::make_pair(type, type);
+ }
+ Type NextGeneric() {
+ char generic_letter = next_generic_letter_++;
+ if (next_generic_letter_ > 'Z') {
+ next_generic_letter_ = 'A';
+ }
+ return Type::Generic(string(1, generic_letter));
+ }
+};
+
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
+ *iterable_out = false;
+ if (!arg_def.number_attr().empty()) {
+ // when number_attr is set, argument has to be a list of tensors
+ *iterable_out = true;
+ visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
+ }
+ Type type = Type::Wildcard();
+ if (arg_def.type() != DataType::DT_INVALID) {
+ // resolve type from DataType
+ switch (arg_def.type()) {
+ case DataType::DT_BOOL:
+ type = Type::Class("Boolean");
+ break;
+ case DataType::DT_STRING:
+ type = Type::Class("String");
+ break;
+ case DataType::DT_FLOAT:
+ type = Type::Class("Float");
+ break;
+ case DataType::DT_DOUBLE:
+ type = Type::Class("Double");
+ break;
+ case DataType::DT_UINT8:
+ type = Type::Class("UInt8", "org.tensorflow.types");
+ break;
+ case DataType::DT_INT32:
+ type = Type::Class("Integer");
+ break;
+ case DataType::DT_INT64:
+ type = Type::Class("Long");
+ break;
+ case DataType::DT_RESOURCE:
+ // TODO(karllessard) create a Resource utility class that could be
+ // used to store a resource and its type (passed in a second argument).
+ // For now, we need to force a wildcard and we will unfortunately lose
+ // track of the resource type.
+ break;
+ default:
+ // Any other datatypes does not have a equivalent in Java and must
+ // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+ break;
+ }
+ } else if (!arg_def.type_attr().empty()) {
+ // resolve type from attribute (if already visited, retrieve its type)
+ if (IsAttributeVisited(arg_def.type_attr())) {
+ type = visited_attrs_.at(arg_def.type_attr());
+ } else {
+ for (const auto& attr_def : op_def_.attr()) {
+ if (attr_def.name() == arg_def.type_attr()) {
+ type = TypesOf(attr_def, iterable_out).first;
+ break;
+ }
+ }
+ }
+ } else if (!arg_def.type_list_attr().empty()) {
+ // type is a list of tensors that can be of different data types, so leave
+ // it as a list of wildcards
+ *iterable_out = true;
+ visited_attrs_.insert(std::make_pair(arg_def.type_list_attr(), type));
+
+ } else {
+ LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
+ << "\" in operation \"" << op_def_.name() << "\"";
+ }
+ return type;
+}
+
+std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
+ bool* iterable_out) {
+ std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
+ *iterable_out = false;
+ StringPiece attr_type = attr_def.type();
+ if (str_util::ConsumePrefix(&attr_type, "list(")) {
+ attr_type.remove_suffix(1); // remove closing brace
+ *iterable_out = true;
+ }
+ if (attr_type == "string") {
+ types = MakeTypePair(Type::Class("String"));
+
+ } else if (attr_type == "int") {
+ types = MakeTypePair(Type::Class("Long"), Type::Long());
+
+ } else if (attr_type == "float") {
+ types = MakeTypePair(Type::Class("Float"), Type::Float());
+
+ } else if (attr_type == "bool") {
+ types = MakeTypePair(Type::Class("Boolean"), Type::Boolean());
+
+ } else if (attr_type == "shape") {
+ types = MakeTypePair(Type::Class("Shape", "org.tensorflow"));
+
+ } else if (attr_type == "tensor") {
+ types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
+ .add_parameter(Type::Wildcard()));
+
+ } else if (attr_type == "type") {
+ Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
+ if (IsRealNumbers(attr_def.allowed_values())) {
+ type.add_supertype(Type::Class("Number"));
+ }
+ types = MakeTypePair(type, Type::Enum("DataType", "org.tensorflow"));
+
+ } else {
+ LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
+ << "\" in operation \"" << op_def_.name() << "\"";
+ }
+ visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
+ return types;
+}
+
+string SnakeToCamelCase(const string& str, bool upper = false) {
+ string result;
+ bool cap = upper;
+ for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+ const char c = *it;
+ if (c == '_') {
+ cap = true;
+ } else if (cap) {
+ result += toupper(c);
+ cap = false;
+ } else {
+ result += c;
+ }
+ }
+ return result;
+}
+
+bool FindAndCut(string* input, const RE2& expr, string* before_match,
+ string* ret_match = nullptr) {
+ string match;
+ if (!RE2::PartialMatch(*input, expr, &match)) return false;
+ *before_match = input->substr(0, input->find(match));
+ *input = input->substr(before_match->size() + match.size());
+ if (ret_match != nullptr) *ret_match = match;
+ return true;
+}
+
+string ParseDocumentation(const string& inp) {
+ std::stringstream javadoc_text;
+
+ // TODO(karllessard) This is a very minimalist utility method for converting
+ // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
+ // for alternatives to increase the level of support for markups.
+ std::vector<string> markups_subexpr;
+ markups_subexpr.push_back("\n+\\*\\s+"); // lists
+ markups_subexpr.push_back("\n{2,}"); // paragraphs
+ markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks
+ markups_subexpr.push_back("`+"); // inlined code and code blocks
+ markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis
+ markups_subexpr.push_back("\\["); // hyperlinks
+ const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")");
+
+ bool in_list = false;
+ string input = inp;
+ while (true) {
+ string text, markup;
+ if (!FindAndCut(&input, markup_expr, &text, &markup)) {
+ javadoc_text << input;
+ break; // end of loop
+ }
+ javadoc_text << text;
+ if (str_util::StartsWith(markup, "\n")) {
+ javadoc_text << "\n";
+ if (str_util::StrContains(markup, "*")) {
+ // new list item
+ javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
+ in_list = true;
+ } else if (in_list) {
+ // end of list
+ javadoc_text << "</li>\n</ul>\n";
+ in_list = false;
+ } else if (!str_util::StartsWith(input, "```")) {
+ // new paragraph (not required if a <pre> block follows)
+ javadoc_text << "<p>\n";
+ }
+ } else if (str_util::StartsWith(markup, "```")) {
+ // code blocks
+ if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
+ javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (str_util::StartsWith("(" + markup + ")", "`")) {
+ // inlined code
+ if (FindAndCut(&input, markup, &text)) {
+ javadoc_text << "{@code " << text << "}";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "**") {
+ // text emphasis (strong)
+ if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
+ javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (markup == "*") {
+ // text emphasis (normal)
+ if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
+ javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else if (str_util::StartsWith(markup, "[")) {
+ // hyperlinks
+ string label;
+ string link;
+ if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
+ &link) &&
+ str_util::StartsWith(input, label + link)) {
+ input = input.substr(label.size() + link.size());
+ javadoc_text << "<a href=\"" << link << "\">"
+ << ParseDocumentation(label) << "</a>";
+ } else {
+ javadoc_text << markup;
+ }
+ } else {
+ // safe fallback
+ javadoc_text << markup;
+ }
+ }
+ return javadoc_text.str();
+}
+
+ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
+ const ApiDef::Arg& input_api_def,
+ TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(input_def, &iterable);
+ Type var_type =
+ Type::Interface("Operand", "org.tensorflow").add_parameter(type);
+ if (iterable) {
+ var_type = Type::IterableOf(var_type);
+ }
+ return ArgumentSpec(
+ input_api_def.name(),
+ Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
+ type, ParseDocumentation(input_api_def.description()), iterable);
+}
+
+AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
+ const ApiDef::Attr& attr_api_def,
+ TypeResolver* type_resolver) {
+ bool iterable = false;
+ std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
+ Type var_type = types.first.kind() == Type::GENERIC
+ ? Type::Class("Class").add_parameter(types.first)
+ : types.first;
+ if (iterable) {
+ var_type = Type::ListOf(var_type);
+ }
+ return AttributeSpec(
+ attr_api_def.name(),
+ Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
+ types.first, types.second, ParseDocumentation(attr_api_def.description()),
+ iterable, attr_api_def.has_default_value());
+}
+
+ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
+ const ApiDef::Arg& output_api,
+ TypeResolver* type_resolver) {
+ bool iterable = false;
+ Type type = type_resolver->TypeOf(output_def, &iterable);
+ Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
+ if (iterable) {
+ var_type = Type::ListOf(var_type);
+ }
+ return ArgumentSpec(
+ output_api.name(),
+ Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
+ type, ParseDocumentation(output_api.description()), iterable);
+}
+
+EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
+ const ApiDef_Endpoint& endpoint_def) {
+ std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
+ string package;
+ string name;
+ if (name_tokens.size() > 1) {
+ package = name_tokens.at(0);
+ name = name_tokens.at(1);
+ } else {
+ package = "core"; // generate unclassified ops in the 'core' package
+ name = name_tokens.at(0);
+ }
+ return EndpointSpec(package, name,
+ Javadoc::Create(ParseDocumentation(api_def.summary()))
+ .details(ParseDocumentation(api_def.description())));
+}
+
+} // namespace
+
+OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
+ OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
+ op_def.deprecation().explanation());
+ TypeResolver type_resolver(op_def);
+ for (const string& next_input_name : api_def.arg_order()) {
+ for (int i = 0; i < op_def.input_arg().size(); ++i) {
+ if (op_def.input_arg(i).name() == next_input_name) {
+ op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
+ &type_resolver));
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < op_def.attr().size(); ++i) {
+ // do not parse attributes already visited, they have probably been inferred
+ // before as an input argument type
+ if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
+ AttributeSpec attr =
+ CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
+ // attributes with a default value are optional
+ if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
+ op.optional_attributes_.push_back(attr);
+ } else {
+ op.attributes_.push_back(attr);
+ }
+ }
+ }
+ for (int i = 0; i < op_def.output_arg().size(); ++i) {
+ op.outputs_.push_back(
+ CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
+ }
+ for (const auto& endpoint_def : api_def.endpoint()) {
+ op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
+ }
+ return op;
+}
+
+} // namespace java
+} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
new file mode 100644
index 0000000000..ca0ba16745
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -0,0 +1,173 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+
+namespace tensorflow {
+namespace java {
+
+class EndpointSpec {
+ public:
+ // A specification for an operation endpoint
+ //
+ // package: package of this endpoint (from which also derives its package)
+ // name: name of this endpoint class
+ // javadoc: the endpoint class documentation
+ // TODO(annarev): hardcode depcreated to false until deprecated is possible
+ EndpointSpec(const string& package, const string& name,
+ const Javadoc& javadoc)
+ : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {}
+
+ const string& package() const { return package_; }
+ const string& name() const { return name_; }
+ const Javadoc& javadoc() const { return javadoc_; }
+ bool deprecated() const { return deprecated_; }
+
+ private:
+ const string package_;
+ const string name_;
+ const Javadoc javadoc_;
+ const bool deprecated_;
+};
+
+class ArgumentSpec {
+ public:
+ // A specification for an operation argument
+ //
+ // op_def_name: argument name, as known by TensorFlow core
+ // var: a variable to represent this argument in Java
+ // type: the tensor type of this argument
+ // description: a description of this argument, in javadoc
+ // iterable: true if this argument is a list
+ ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type,
+ const string& description, bool iterable)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable) {}
+
+ const string& op_def_name() const { return op_def_name_; }
+ const Variable& var() const { return var_; }
+ const Type& type() const { return type_; }
+ const string& description() const { return description_; }
+ bool iterable() const { return iterable_; }
+
+ private:
+ const string op_def_name_;
+ const Variable var_;
+ const Type type_;
+ const string description_;
+ const bool iterable_;
+};
+
+class AttributeSpec {
+ public:
+ // A specification for an operation attribute
+ //
+ // op_def_name: attribute name, as known by TensorFlow core
+ // var: a variable to represent this attribute in Java
+ // type: the type of this attribute
+ // jni_type: the type of this attribute in JNI layer (see OperationBuilder)
+ // description: a description of this attribute, in javadoc
+ // iterable: true if this attribute is a list
+ // has_default_value: true if this attribute has a default value if not set
+ AttributeSpec(const string& op_def_name, const Variable& var,
+ const Type& type, const Type& jni_type,
+ const string& description, bool iterable,
+ bool has_default_value)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable),
+ jni_type_(jni_type),
+ has_default_value_(has_default_value) {}
+
+ const string& op_def_name() const { return op_def_name_; }
+ const Variable& var() const { return var_; }
+ const Type& type() const { return type_; }
+ const string& description() const { return description_; }
+ bool iterable() const { return iterable_; }
+ const Type& jni_type() const { return jni_type_; }
+ bool has_default_value() const { return has_default_value_; }
+
+ private:
+ const string op_def_name_;
+ const Variable var_;
+ const Type type_;
+ const string description_;
+ const bool iterable_;
+ const Type jni_type_;
+ const bool has_default_value_;
+};
+
+class OpSpec {
+ public:
+ // Parses an op definition and its API to produce a specification used for
+ // rendering its Java wrapper
+ //
+ // op_def: Op definition
+ // api_def: Op API definition
+ static OpSpec Create(const OpDef& op_def, const ApiDef& api_def);
+
+ const string& graph_op_name() const { return graph_op_name_; }
+ bool hidden() const { return hidden_; }
+ const string& deprecation_explanation() const {
+ return deprecation_explanation_;
+ }
+ const std::vector<EndpointSpec> endpoints() const { return endpoints_; }
+ const std::vector<ArgumentSpec>& inputs() const { return inputs_; }
+ const std::vector<ArgumentSpec>& outputs() const { return outputs_; }
+ const std::vector<AttributeSpec>& attributes() const { return attributes_; }
+ const std::vector<AttributeSpec>& optional_attributes() const {
+ return optional_attributes_;
+ }
+
+ private:
+ // A specification for an operation
+ //
+ // graph_op_name: name of this op, as known by TensorFlow core engine
+ // hidden: true if this op should not be visible through the Graph Ops API
+ // deprecation_explanation: message to show if all endpoints are deprecated
+ explicit OpSpec(const string& graph_op_name, bool hidden,
+ const string& deprecation_explanation)
+ : graph_op_name_(graph_op_name),
+ hidden_(hidden),
+ deprecation_explanation_(deprecation_explanation) {}
+
+ const string graph_op_name_;
+ const bool hidden_;
+ const string deprecation_explanation_;
+ std::vector<EndpointSpec> endpoints_;
+ std::vector<ArgumentSpec> inputs_;
+ std::vector<ArgumentSpec> outputs_;
+ std::vector<AttributeSpec> attributes_;
+ std::vector<AttributeSpec> optional_attributes_;
+};
+
+} // namespace java
+} // namespace tensorflow
+
+#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_SPECS_H_
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index a02f75ad6e..66401bdba7 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <string>
#include <algorithm>
-#include <deque>
+#include <list>
+#include <string>
#include "tensorflow/java/src/gen/cc/source_writer.h"
@@ -83,20 +83,22 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
}
SourceWriter& SourceWriter::AppendType(const Type& type) {
- if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
+ if (type.wildcard()) {
Append("?");
} else {
Append(type.name());
- }
- if (!type.parameters().empty()) {
- Append("<");
- for (const Type& t : type.parameters()) {
- if (&t != &type.parameters().front()) {
- Append(", ");
+ if (!type.parameters().empty()) {
+ Append("<");
+ bool first = true;
+ for (const Type& t : type.parameters()) {
+ if (!first) {
+ Append(", ");
+ }
+ AppendType(t);
+ first = false;
}
- AppendType(t);
+ Append(">");
}
- Append(">");
}
return *this;
}
@@ -107,7 +109,21 @@ SourceWriter& SourceWriter::EndLine() {
return *this;
}
-SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
+SourceWriter& SourceWriter::BeginBlock(const string& expression) {
+ if (!expression.empty()) {
+ Append(expression + " {");
+ } else {
+ Append(newline_ ? "{" : " {");
+ }
+ return EndLine().Indent(2);
+}
+
+SourceWriter& SourceWriter::EndBlock() {
+ return Indent(-2).Append("}").EndLine();
+}
+
+SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
+ const Javadoc* javadoc) {
GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
if (!method.constructor()) {
generic_namespace->Visit(method.return_type());
@@ -116,8 +132,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
generic_namespace->Visit(v.type());
}
EndLine();
- WriteDoc(method.description(), method.return_description(),
- &method.arguments());
+ if (javadoc != nullptr) {
+ WriteJavadoc(*javadoc);
+ }
if (!method.annotations().empty()) {
WriteAnnotations(method.annotations());
}
@@ -130,11 +147,13 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
AppendType(method.return_type()).Append(" ");
}
Append(method.name()).Append("(");
+ bool first = true;
for (const Variable& v : method.arguments()) {
- if (&v != &method.arguments().front()) {
+ if (!first) {
Append(", ");
}
AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name());
+ first = false;
}
return Append(")").BeginBlock();
}
@@ -145,29 +164,36 @@ SourceWriter& SourceWriter::EndMethod() {
return *this;
}
-SourceWriter& SourceWriter::BeginType(const Type& type,
- const std::list<Type>* dependencies, int modifiers) {
+SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
+ const std::list<Type>* extra_dependencies,
+ const Javadoc* javadoc) {
if (!type.package().empty()) {
Append("package ").Append(type.package()).Append(";").EndLine();
}
- if (dependencies != nullptr && !dependencies->empty()) {
- TypeImporter type_importer(type.package());
- for (const Type& t : *dependencies) {
+ TypeImporter type_importer(type.package());
+ type_importer.Visit(type);
+ if (extra_dependencies != nullptr) {
+ for (const Type& t : *extra_dependencies) {
type_importer.Visit(t);
}
+ }
+ if (!type_importer.imports().empty()) {
EndLine();
for (const string& s : type_importer.imports()) {
Append("import ").Append(s).Append(";").EndLine();
}
}
- return BeginInnerType(type, modifiers);
+ return BeginInnerType(type, modifiers, javadoc);
}
-SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
+SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
+ const Javadoc* javadoc) {
GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
generic_namespace->Visit(type);
EndLine();
- WriteDoc(type.description());
+ if (javadoc != nullptr) {
+ WriteJavadoc(*javadoc);
+ }
if (!type.annotations().empty()) {
WriteAnnotations(type.annotations());
}
@@ -200,14 +226,15 @@ SourceWriter& SourceWriter::EndType() {
return *this;
}
-SourceWriter& SourceWriter::WriteFields(const std::list<Variable>& fields,
- int modifiers) {
- EndLine();
- for (const Variable& v : fields) {
- WriteModifiers(modifiers);
- AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
- EndLine();
+SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
+ const Javadoc* javadoc) {
+ // If present, write field javadoc only as one brief line
+ if (javadoc != nullptr && !javadoc->brief().empty()) {
+ Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
}
+ WriteModifiers(modifiers);
+ AppendType(field.type()).Append(" ").Append(field.name()).Append(";");
+ EndLine();
return *this;
}
@@ -228,39 +255,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
return *this;
}
-SourceWriter& SourceWriter::WriteDoc(const string& description,
- const string& return_description, const std::list<Variable>* parameters) {
- if (description.empty() && return_description.empty()
- && (parameters == nullptr || parameters->empty())) {
- return *this; // no doc to write
- }
+SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) {
+ Append("/**").Prefix(" * ").EndLine();
bool do_line_break = false;
- Append("/**").EndLine().Prefix(" * ");
- if (!description.empty()) {
- Write(description).EndLine();
+ if (!javadoc.brief().empty()) {
+ Write(javadoc.brief()).EndLine();
do_line_break = true;
}
- if (parameters != nullptr && !parameters->empty()) {
+ if (!javadoc.details().empty()) {
if (do_line_break) {
- EndLine();
- do_line_break = false;
- }
- for (const Variable& v : *parameters) {
- Append("@param ").Append(v.name());
- if (!v.description().empty()) {
- Append(" ").Write(v.description());
- }
- EndLine();
+ Append("<p>").EndLine();
}
+ Write(javadoc.details()).EndLine();
+ do_line_break = true;
}
- if (!return_description.empty()) {
+ if (!javadoc.tags().empty()) {
if (do_line_break) {
EndLine();
- do_line_break = false;
}
- Append("@return ").Write(return_description).EndLine();
+ for (const auto& p : javadoc.tags()) {
+ Append("@" + p.first);
+ if (!p.second.empty()) {
+ Append(" ").Write(p.second);
+ }
+ EndLine();
+ }
}
- return Prefix("").Append(" **/").EndLine();
+ return Prefix("").Append(" */").EndLine();
}
SourceWriter& SourceWriter::WriteAnnotations(
@@ -278,14 +299,16 @@ SourceWriter& SourceWriter::WriteAnnotations(
SourceWriter& SourceWriter::WriteGenerics(
const std::list<const Type*>& generics) {
Append("<");
+ bool first = true;
for (const Type* pt : generics) {
- if (pt != generics.front()) {
+ if (!first) {
Append(", ");
}
Append(pt->name());
if (!pt->supertypes().empty()) {
Append(" extends ").AppendType(pt->supertypes().front());
}
+ first = false;
}
return Append(">");
}
@@ -311,21 +334,20 @@ void SourceWriter::PopGenericNamespace() {
void SourceWriter::TypeVisitor::Visit(const Type& type) {
DoVisit(type);
for (const Type& t : type.parameters()) {
- DoVisit(t);
+ Visit(t);
}
for (const Annotation& t : type.annotations()) {
DoVisit(t);
}
for (const Type& t : type.supertypes()) {
- DoVisit(t);
+ Visit(t);
}
}
void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
// ignore non-generic parameters, wildcards and generics already declared
- if (type.kind() == Type::GENERIC
- && !type.IsWildcard()
- && generic_names_.find(type.name()) == generic_names_.end()) {
+ if (type.kind() == Type::GENERIC && !type.wildcard() &&
+ generic_names_.find(type.name()) == generic_names_.end()) {
declared_types_.push_back(&type);
generic_names_.insert(type.name());
}
@@ -333,7 +355,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
void SourceWriter::TypeImporter::DoVisit(const Type& type) {
if (!type.package().empty() && type.package() != current_package_) {
- imports_.insert(type.package() + '.' + type.name());
+ imports_.insert(type.canonical_name());
}
}
diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h
index f011acd30a..de0113bd5b 100644
--- a/tensorflow/java/src/gen/cc/source_writer.h
+++ b/tensorflow/java/src/gen/cc/source_writer.h
@@ -93,25 +93,22 @@ class SourceWriter {
// This method appends a new opening brace to the current data and indent the
// next lines according to Google Java Style Guide. The block can optionally
// be preceded by an expression (e.g. Append("if(true)").BeginBlock();)
- SourceWriter& BeginBlock() {
- return Append(newline_ ? "{" : " {").EndLine().Indent(2);
- }
+ SourceWriter& BeginBlock(const string& expression = "");
// Ends the current block of source code.
//
// This method appends a new closing brace to the current data and outdent the
// next lines back to the margin used before BeginBlock() was invoked.
- SourceWriter& EndBlock() {
- return Indent(-2).Append("}").EndLine();
- }
+ SourceWriter& EndBlock();
// Begins to write a method.
//
// This method outputs the signature of the Java method from the data passed
- // in the 'method' parameter and starts a new block. Additionnal modifiers can
- // also be passed in parameter to define the accesses and the scope of this
- // method.
- SourceWriter& BeginMethod(const Method& method, int modifiers = 0);
+ // in the 'method' parameter and starts a new block. Modifiers are also passed
+ // in parameter to define the access scope of this method and, optionally,
+ // a Javadoc.
+ SourceWriter& BeginMethod(const Method& method, int modifiers,
+ const Javadoc* javadoc = nullptr);
// Ends the current method.
//
@@ -122,22 +119,24 @@ class SourceWriter {
// Begins to write the main type of a source file.
//
// This method outputs the declaration of the Java type from the data passed
- // in the 'type' parameter and starts a new block. Additionnal modifiers can
- // also be passed in parameter to define the accesses and the scope of this
- // type.
+ // in the 'type' parameter and starts a new block. Modifiers are also passed
+ // in parameter to define the access scope of this type and, optionally,
+ // a Javadoc.
//
- // If not null, all types found in the 'dependencies' list will be imported
- // before declaring the new type.
- SourceWriter& BeginType(const Type& clazz,
- const std::list<Type>* dependencies, int modifiers = 0);
+ // If not null, all types found in the 'extra_dependencies' list will be
+ // imported before declaring the new type.
+ SourceWriter& BeginType(const Type& type, int modifiers,
+ const std::list<Type>* extra_dependencies = nullptr,
+ const Javadoc* javadoc = nullptr);
// Begins to write a new inner type.
//
// This method outputs the declaration of the Java type from the data passed
- // in the 'type' parameter and starts a new block. Additionnal modifiers can
- // also be passed in parameter to define the accesses and the scope of this
- // type.
- SourceWriter& BeginInnerType(const Type& type, int modifiers = 0);
+ // in the 'type' parameter and starts a new block. Modifiers are also passed
+ // in parameter to define the accesses and the scope of this type and,
+ // optionally, a Javadoc.
+ SourceWriter& BeginInnerType(const Type& type, int modifiers,
+ const Javadoc* javadoc = nullptr);
// Ends the current type.
//
@@ -145,13 +144,13 @@ class SourceWriter {
// BeginType() or BeginInnerType() prior to this.
SourceWriter& EndType();
- // Writes a list of variables as fields of a type.
+ // Writes a variable as fields of a type.
//
// This method must be called within the definition of a type (see BeginType()
- // or BeginInnerType()). Additional modifiers can also be passed in parameter
- // to define the accesses and the scope of those fields.
- SourceWriter& WriteFields(const std::list<Variable>& fields,
- int modifiers = 0);
+ // or BeginInnerType()). Modifiers are also be passed in parameter to define
+ // the accesses and the scope of this field and, optionally, a Javadoc.
+ SourceWriter& WriteField(const Variable& field, int modifiers,
+ const Javadoc* javadoc = nullptr);
protected:
virtual void DoAppend(const StringPiece& str) = 0;
@@ -207,9 +206,7 @@ class SourceWriter {
std::stack<GenericNamespace*> generic_namespaces_;
SourceWriter& WriteModifiers(int modifiers);
- SourceWriter& WriteDoc(const string& description,
- const string& return_description = "",
- const std::list<Variable>* parameters = nullptr);
+ SourceWriter& WriteJavadoc(const Javadoc& javadoc);
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
GenericNamespace* PushGenericNamespace(int modifiers);
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc
index 4bce2fea70..fb8fc64dff 100644
--- a/tensorflow/java/src/gen/cc/source_writer_test.cc
+++ b/tensorflow/java/src/gen/cc/source_writer_test.cc
@@ -245,12 +245,17 @@ TEST(StreamTest, Types) {
SourceBufferWriter writer;
Type generic = Type::Generic("T").add_supertype(Type::Class("Number"));
- writer.AppendType(Type::Int()).Append(", ")
- .AppendType(Type::Class("String")).Append(", ")
- .AppendType(generic).Append(", ")
- .AppendType(Type::ListOf(generic)).Append(", ")
- .AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ")
- .AppendType(Type::ListOf(Type::Generic()));
+ writer.AppendType(Type::Int())
+ .Append(", ")
+ .AppendType(Type::Class("String"))
+ .Append(", ")
+ .AppendType(generic)
+ .Append(", ")
+ .AppendType(Type::ListOf(generic))
+ .Append(", ")
+ .AppendType(Type::ListOf(Type::IterableOf(generic)))
+ .Append(", ")
+ .AppendType(Type::ListOf(Type::Wildcard()));
const char* expected =
"int, String, T, List<T>, List<Iterable<T>>, List<?>";
@@ -282,7 +287,7 @@ TEST(WriteType, SimpleClass) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -300,7 +305,7 @@ TEST(WriteType, SimpleClassWithDependencies) {
deps.push_back(Type::Class("SamePackageType", "org.tensorflow"));
deps.push_back(Type::Class("NoPackageType"));
- writer.BeginType(clazz, &deps, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC, &deps).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -313,20 +318,22 @@ TEST(WriteType, SimpleClassWithDependencies) {
TEST(WriteType, AnnotatedAndDocumentedClass) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
- clazz.description("This class has a\n<p>\nmultiline description.");
+ Javadoc clazz_doc = Javadoc::Create("Javadoc test")
+ .details("This is a\nmultiline description.");
clazz.add_annotation(Annotation::Create("Bean"));
clazz.add_annotation(Annotation::Create("SuppressWarnings")
.attributes("\"rawtypes\""));
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType();
const char* expected =
"package org.tensorflow;\n\n"
"/**\n"
- " * This class has a\n"
+ " * Javadoc test\n"
" * <p>\n"
+ " * This is a\n"
" * multiline description.\n"
- " **/\n"
+ " */\n"
"@Bean\n"
"@SuppressWarnings(\"rawtypes\")\n"
"public class Test {\n}\n";
@@ -339,7 +346,7 @@ TEST(WriteType, ParameterizedClass) {
clazz.add_parameter(Type::Generic("T"));
clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number")));
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -358,7 +365,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) {
clazz.add_supertype(Type::Interface("Runnable"));
clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t));
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -372,24 +379,23 @@ TEST(WriteType, ParameterizedClassFields) {
Type clazz = Type::Class("Test", "org.tensorflow");
Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
clazz.add_parameter(type_t);
- std::list<Variable> static_fields;
- static_fields.push_back(Variable::Create("field1", Type::Class("String")));
- std::list<Variable> member_fields;
- member_fields.push_back(Variable::Create("field2", Type::Class("String")));
- member_fields.push_back(Variable::Create("field3", type_t));
+ Variable field1 = Variable::Create("field1", Type::Class("String"));
+ Variable field2 = Variable::Create("field2", Type::Class("String"));
+ Variable field3 = Variable::Create("field3", type_t);
+ Javadoc field3_doc = Javadoc::Create("This variable is documented");
- writer.BeginType(clazz, nullptr, PUBLIC)
- .WriteFields(static_fields, STATIC | PUBLIC | FINAL)
- .WriteFields(member_fields, PRIVATE)
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .WriteField(field1, STATIC | PUBLIC | FINAL)
+ .WriteField(field2, PRIVATE)
+ .WriteField(field3, PRIVATE, &field3_doc)
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
"public class Test<T extends Number> {\n"
- " \n"
" public static final String field1;\n"
- " \n"
" private String field2;\n"
+ " /** This variable is documented */\n"
" private T field3;\n"
"}\n";
ASSERT_STREQ(expected, writer.str().data());
@@ -400,10 +406,10 @@ TEST(WriteType, SimpleInnerClass) {
Type clazz = Type::Class("Test", "org.tensorflow");
Type inner_class = Type::Class("InnerTest");
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginInnerType(inner_class, PUBLIC)
- .EndType()
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .BeginInnerType(inner_class, PUBLIC)
+ .EndType()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -423,10 +429,10 @@ TEST(WriteType, StaticParameterizedInnerClass) {
Type inner_class = Type::Class("InnerTest");
inner_class.add_parameter(type_t);
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginInnerType(inner_class, PUBLIC | STATIC)
- .EndType()
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .BeginInnerType(inner_class, PUBLIC | STATIC)
+ .EndType()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -443,9 +449,10 @@ TEST(WriteMethod, SimpleMethod) {
Type clazz = Type::Class("Test", "org.tensorflow");
Method method = Method::Create("doNothing", Type::Void());
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC).EndMethod()
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC)
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -461,24 +468,28 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
Method method = Method::Create("doNothing", Type::Void());
- method.description("This method has a\n<p>\nmultiline description.");
+ Javadoc method_doc =
+ Javadoc::Create("Javadoc test")
+ .details("This method has a\nmultiline description.");
method.add_annotation(Annotation::Create("Override"));
method.add_annotation(Annotation::Create("SuppressWarnings")
.attributes("\"rawtypes\""));
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC).EndMethod()
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC, &method_doc)
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
"public class Test {\n"
" \n"
" /**\n"
- " * This method has a\n"
+ " * Javadoc test\n"
" * <p>\n"
+ " * This method has a\n"
" * multiline description.\n"
- " **/\n"
+ " */\n"
" @Override\n"
" @SuppressWarnings(\"rawtypes\")\n"
" public void doNothing() {\n"
@@ -490,23 +501,27 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
TEST(WriteMethod, DocumentedMethodWithArguments) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
+ Variable reverse = Variable::Create("reverse", Type::Boolean());
Method method = Method::Create("boolToInt", Type::Int());
- method.description("Converts a boolean to an int");
- method.return_description("int value for this boolean");
method.add_argument(Variable::Create("b", Type::Boolean()));
- Variable reverse = Variable::Create("reverse", Type::Boolean());
- reverse.description("if true, value is reversed");
method.add_argument(reverse);
-
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC)
- .Append("if (b && !reverse)")
- .BeginBlock()
- .Append("return 1;").EndLine()
- .EndBlock()
- .Append("return 0;").EndLine()
- .EndMethod()
- .EndType();
+ Javadoc method_doc =
+ Javadoc::Create("Converts a boolean to an int")
+ .details("This method will convert\na boolean to an int")
+ .add_param_tag(reverse.name(), "if true, value is reversed")
+ .add_tag("return", "int value for this boolean");
+
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC, &method_doc)
+ .Append("if (b && !reverse)")
+ .BeginBlock()
+ .Append("return 1;")
+ .EndLine()
+ .EndBlock()
+ .Append("return 0;")
+ .EndLine()
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -514,11 +529,13 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
" \n"
" /**\n"
" * Converts a boolean to an int\n"
+ " * <p>\n"
+ " * This method will convert\n"
+ " * a boolean to an int\n"
" * \n"
- " * @param b\n"
" * @param reverse if true, value is reversed\n"
" * @return int value for this boolean\n"
- " **/\n"
+ " */\n"
" public int boolToInt(boolean b, boolean reverse) {\n"
" if (b && !reverse) {\n"
" return 1;\n"
@@ -536,11 +553,12 @@ TEST(WriteMethod, ParameterizedMethod) {
clazz.add_parameter(type_t);
Method method = Method::Create("doNothing", type_t);
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC)
- .Append("return null;").EndLine()
- .EndMethod()
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC)
+ .Append("return null;")
+ .EndLine()
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -560,11 +578,12 @@ TEST(WriteMethod, StaticParameterizedMethod) {
clazz.add_parameter(type_t);
Method method = Method::Create("doNothing", type_t);
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC | STATIC)
- .Append("return null;").EndLine()
- .EndMethod()
- .EndType();
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC | STATIC)
+ .Append("return null;")
+ .EndLine()
+ .EndMethod()
+ .EndType();
const char* expected =
"package org.tensorflow;\n\n"
diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl
index a6650fc4ea..f4ff34ea03 100644
--- a/tensorflow/java/src/gen/gen_ops.bzl
+++ b/tensorflow/java/src/gen/gen_ops.bzl
@@ -1,62 +1,62 @@
# -*- Python -*-
-load("//tensorflow:tensorflow.bzl",
- "tf_binary_additional_srcs",
- "tf_cc_binary",
- "tf_copts")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_binary_additional_srcs",
+)
-# Given a list of "ops_libs" (a list of files in the core/ops directory
-# without their .cc extensions), generate Java wrapper code for all operations
-# found in the ops files.
-# Then, combine all those source files into a single archive (.srcjar).
+# Generate Java wrapper classes for all registered core operations and package
+# them into a single source archive (.srcjar).
#
# For example:
-# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ])
+# tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package")
#
-# will create a genrule named "gen_sources" that first generate source files:
-# ops/src/main/java/my/package/array/*.java
-# ops/src/main/java/my/package/math/*.java
+# will create a genrule named "gen_sources" that generates source files under
+# ops/src/main/java/my/package/**/*.java
#
-# and then archive those source files in:
+# and then archive those source files into
# ops/gen_sources.srcjar
#
def tf_java_op_gen_srcjar(name,
gen_tool,
- gen_base_package,
- ops_libs=[],
- ops_libs_pkg="//tensorflow/core",
+ base_package,
+ api_def_srcs=[],
out_dir="ops/",
out_src_dir="src/main/java/",
visibility=["//tensorflow/java:__pkg__"]):
- gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
+ srcs = api_def_srcs[:]
- # Construct an op generator binary for each ops library.
- for ops_lib in ops_libs:
- gen_lib = ops_lib[:ops_lib.rfind("_")]
- out_gen_tool = out_dir + ops_lib + "_gen_tool"
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ "$$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))")
+ api_def_args_str = ",".join(api_def_args)
- tf_cc_binary(
- name=out_gen_tool,
- copts=tf_copts(),
- linkopts=["-lm"],
- linkstatic=1, # Faster to link this one-time-use binary dynamically
- deps=[gen_tool, ops_libs_pkg + ":" + ops_lib + "_op_lib"])
-
- gen_tools += [":" + out_gen_tool]
- gen_cmds += ["$(location :" + out_gen_tool + ")" +
- " --output_dir=$(@D)/" + out_src_dir +
- " --lib_name=" + gen_lib +
- " --base_package=" + gen_base_package]
+ gen_cmds += ["$(location " + gen_tool + ")" +
+ " --output_dir=$(@D)/" + out_src_dir +
+ " --base_package=" + base_package +
+ " --api_dirs=" + api_def_args_str]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
gen_cmds += ["$(location @local_jdk//:jar) cMf $(location :" + gen_srcjar + ") -C $(@D) src"]
- gen_tools += ["@local_jdk//:jar"] + ["@local_jdk//:jdk"]
- gen_tools += tf_binary_additional_srcs()
+
native.genrule(
name=name,
+ srcs=srcs,
outs=[gen_srcjar],
- tools=gen_tools,
- cmd="&&".join(gen_cmds))
+ tools=[
+ "@local_jdk//:jar",
+ "@local_jdk//:jdk",
+ gen_tool
+ ] + tf_binary_additional_srcs(),
+ cmd=" && ".join(gen_cmds))