aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar Karl Lessard <karl@kubx.ca>2018-03-29 05:28:16 +0800
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-05-03 22:40:42 -0400
commitfe9b2637cfe39cf11eb3d0494948a733b7fc1d7d (patch)
tree280fa51b3d80c9b9f4b0e8e98f55831a384f0851 /tensorflow/java
parent2dc7575123ffa0e6413fc3d2700968ef25f049de (diff)
Parse op definition and generate a Java Op class.
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD4
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h76
-rw-r--r--tensorflow/java/src/gen/cc/op_gen_main.cc22
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc406
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.h42
-rw-r--r--tensorflow/java/src/gen/cc/op_parser.cc417
-rw-r--r--tensorflow/java/src/gen/cc/op_parser.h137
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc127
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.h55
-rw-r--r--tensorflow/java/src/gen/cc/source_writer_test.cc82
-rw-r--r--tensorflow/java/src/gen/gen_ops.bzl29
-rw-r--r--tensorflow/java/src/gen/resources/license.snippet.java14
12 files changed, 1201 insertions, 210 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index ab7d698a45..635a4e807d 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -70,6 +70,7 @@ filegroup(
tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
+ api_def_srcs = ["//tensorflow/core/api_def:base_api_def"],
gen_base_package = "org.tensorflow.op",
gen_tool = "java_op_gen_tool",
ops_libs = [
@@ -111,11 +112,13 @@ cc_library(
name = "java_op_gen_lib",
srcs = [
"src/gen/cc/op_generator.cc",
+ "src/gen/cc/op_parser.cc",
"src/gen/cc/source_writer.cc",
],
hdrs = [
"src/gen/cc/java_defs.h",
"src/gen/cc/op_generator.h",
+ "src/gen/cc/op_parser.h",
"src/gen/cc/source_writer.h",
],
copts = tf_copts(),
@@ -124,6 +127,7 @@ cc_library(
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "//tensorflow/core:op_gen_lib",
],
)
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index 59f8beaee7..2065477f58 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -18,12 +18,15 @@ limitations under the License.
#include <string>
#include <list>
+#include <map>
+#include <utility>
namespace tensorflow {
namespace java {
// An enumeration of different modifiers commonly used in Java
enum Modifier {
+ PACKAGE = 0,
PUBLIC = (1 << 0),
PROTECTED = (1 << 1),
PRIVATE = (1 << 2),
@@ -72,6 +75,12 @@ class Type {
// Reflection API does
return Type(Type::PRIMITIVE, "void");
}
+ static Type Generic(const string& name) {
+ return Type(Type::GENERIC, name);
+ }
+ static Type Wildcard() {
+ return Type(Type::GENERIC, "");
+ }
static Type Class(const string& name, const string& package = "") {
return Type(Type::CLASS, name, package);
}
@@ -81,9 +90,6 @@ class Type {
static Type Enum(const string& name, const string& package = "") {
return Type(Type::ENUM, name, package);
}
- static Type Generic(const string& name = "") {
- return Type(Type::GENERIC, name);
- }
static Type ClassOf(const Type& type) {
return Class("Class").add_parameter(type);
}
@@ -96,11 +102,10 @@ class Type {
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
- const string& description() const { return description_; }
- Type& description(const string& description) {
- description_ = description;
- return *this;
+ const string full_name() const {
+ return package_.empty() ? name_ : package_ + "." + name_;
}
+ bool unknown() const { return name_.empty(); } // only wildcards has no name
const std::list<Type>& parameters() const { return parameters_; }
Type& add_parameter(const Type& parameter) {
parameters_.push_back(parameter);
@@ -120,14 +125,6 @@ class Type {
}
return *this;
}
- // Returns true if "type" is of a known collection type (only a few for now)
- bool IsCollection() const {
- return name_ == "List" || name_ == "Iterable";
- }
- // Returns true if this instance is a wildcard (<?>)
- bool IsWildcard() const {
- return kind_ == GENERIC && name_.empty();
- }
protected:
Type(Kind kind, const string& name, const string& package = "")
@@ -137,7 +134,6 @@ class Type {
Kind kind_;
string name_;
string package_;
- string description_;
std::list<Type> parameters_;
std::list<Annotation> annotations_;
std::list<Type> supertypes_;
@@ -180,16 +176,11 @@ class Variable {
const string& name() const { return name_; }
const Type& type() const { return type_; }
bool variadic() const { return variadic_; }
- const string& description() const { return description_; }
- Variable& description(const string& description) {
- description_ = description;
- return *this;
- }
+
private:
string name_;
Type type_;
bool variadic_;
- string description_;
Variable(const string& name, const Type& type, bool variadic)
: name_(name), type_(type), variadic_(variadic) {}
@@ -210,16 +201,6 @@ class Method {
bool constructor() const { return constructor_; }
const string& name() const { return name_; }
const Type& return_type() const { return return_type_; }
- const string& description() const { return description_; }
- Method& description(const string& description) {
- description_ = description;
- return *this;
- }
- const string& return_description() const { return return_description_; }
- Method& return_description(const string& description) {
- return_description_ = description;
- return *this;
- }
const std::list<Variable>& arguments() const { return arguments_; }
Method& add_argument(const Variable& var) {
arguments_.push_back(var);
@@ -235,8 +216,6 @@ class Method {
string name_;
Type return_type_;
bool constructor_;
- string description_;
- string return_description_;
std::list<Variable> arguments_;
std::list<Annotation> annotations_;
@@ -244,6 +223,35 @@ class Method {
: name_(name), return_type_(return_type), constructor_(constructor) {}
};
+// A definition of a documentation bloc for a Java element (JavaDoc)
+class Javadoc {
+ public:
+ static Javadoc Create(const string& brief = "") {
+ return Javadoc(brief);
+ }
+ const string& brief() const { return brief_; }
+ const string& details() const { return description_; }
+ Javadoc& details(const string description) {
+ description_ = description;
+ return *this;
+ }
+ const std::list<std::pair<string, string>> tags() const { return tags_; }
+ Javadoc& add_tag(const string& tag, const string& text) {
+ tags_.push_back(std::make_pair(tag, text));
+ return *this;
+ }
+ Javadoc& add_param_tag(const string& name, const string& text) {
+ return add_tag("param", name + " " + text);
+ }
+
+ private:
+ string brief_;
+ string description_;
+ std::list<std::pair<string, string>> tags_;
+
+ explicit Javadoc(const string& brief) : brief_(brief) {}
+};
+
} // namespace java
} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc
index bea99f3d7f..015200023f 100644
--- a/tensorflow/java/src/gen/cc/op_gen_main.cc
+++ b/tensorflow/java/src/gen/cc/op_gen_main.cc
@@ -48,8 +48,11 @@ const char kUsageHeader[] =
"through\n"
"the 'org.tensorflow.op.Ops' API as a group until the generated classes "
"are compiled using an appropriate annotation processor.\n\n"
- "Finally, the '--base_package' overrides the default parent package "
- "under which the generated subpackage and classes are to be located.\n\n";
+ "The '--base_package' overrides the default parent package under which "
+ "the generated subpackage and classes are to be located.\n\n"
+ "Finally, a list of directories of API proto definitions can be provided "
+ "to override default values found in the ops definitions, ordered by\n"
+ "priority (the last having precedence over the first).\n\n";
} // namespace java
} // namespace tensorflow
@@ -60,7 +63,7 @@ int main(int argc, char* argv[]) {
tensorflow::string base_package = "org.tensorflow.op";
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("output_dir", &output_dir,
- "Root directory into which output files are generated"),
+ "Root directory into which output files are generated"),
tensorflow::Flag(
"lib_name", &lib_name,
"A name, in snake_case, used to classify this set of operations"),
@@ -72,12 +75,15 @@ int main(int argc, char* argv[]) {
bool parsed_flags_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(parsed_flags_ok && !lib_name.empty() && !output_dir.empty()) << usage;
-
- tensorflow::java::OpGenerator generator;
+ std::vector<tensorflow::string> api_dirs;
+ if (argc > 1) {
+ api_dirs = tensorflow::str_util::Split(argv[1], ",",
+ tensorflow::str_util::SkipEmpty());
+ }
+ tensorflow::java::OpGenerator generator(base_package, output_dir, api_dirs);
tensorflow::OpList ops;
- tensorflow::OpRegistry::Global()->Export(true, &ops);
- tensorflow::Status status =
- generator.Run(ops, lib_name, base_package, output_dir);
+ tensorflow::OpRegistry::Global()->Export(false, &ops);
+ tensorflow::Status status = generator.Run(ops, lib_name);
TF_QCHECK_OK(status);
return 0;
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index def06baf2d..c9b57f5706 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -14,53 +14,409 @@ limitations under the License.
==============================================================================*/
#include <string>
+#include <map>
+#include <vector>
+#include <list>
+#include <memory>
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+#include "tensorflow/java/src/gen/cc/source_writer.h"
+#include "tensorflow/java/src/gen/cc/op_parser.h"
#include "tensorflow/java/src/gen/cc/op_generator.h"
namespace tensorflow {
namespace java {
namespace {
-string CamelCase(const string& str, char delimiter, bool upper) {
- string result;
- bool cap = upper;
- for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
- const char c = *it;
- if (c == delimiter) {
- cap = true;
- } else if (cap) {
- result += toupper(c);
- cap = false;
+const char* kLicenseSnippet =
+ "tensorflow/java/src/gen/resources/license.snippet.java";
+
+const std::map<string, Type> kPrimitiveAttrTypes = {
+ { "Boolean", Type::Boolean() },
+ { "Byte", Type::Byte() },
+ { "Character", Type::Byte() },
+ { "Float", Type::Float() },
+ { "Integer", Type::Long() },
+ { "Long", Type::Long() },
+ { "Short", Type::Long() },
+ { "Double", Type::Float() },
+};
+
+enum RenderMode {
+ DEFAULT,
+ SINGLE_OUTPUT,
+ SINGLE_LIST_OUTPUT
+};
+
+void CollectOpDependencies(const OpSpec& op, RenderMode mode,
+ std::list<Type>* out) {
+ out->push_back(Type::Class("Operation", "org.tensorflow"));
+ out->push_back(Type::Class("OperationBuilder", "org.tensorflow"));
+ out->push_back(Type::Class("Scope", "org.tensorflow.op"));
+ if (mode == SINGLE_OUTPUT) {
+ out->push_back(Type::Class("Output", "org.tensorflow"));
+ } else if (mode == SINGLE_LIST_OUTPUT) {
+ out->push_back(Type::Interface("Iterator", "java.util"));
+ }
+ // Don't pay attention to duplicate types in the dependency list, they will
+ // be filtered out by the SourceWriter.
+ for (const OpSpec::Operand& input : op.inputs()) {
+ out->push_back(input.var().type());
+ if (input.iterable()) {
+ out->push_back(Type::Class("Operands", "org.tensorflow.op"));
+ }
+ }
+ for (const OpSpec::Operand& output : op.outputs()) {
+ out->push_back(output.var().type());
+ if (output.iterable()) {
+ out->push_back(Type::Class("Arrays", "java.util"));
+ }
+ }
+ for (const OpSpec::Operand& attribute : op.attributes()) {
+ out->push_back(attribute.var().type());
+ if (attribute.var().type().name() == "Class") {
+ out->push_back(Type::Enum("DataType", "org.tensorflow"));
+ }
+ }
+ for (const OpSpec::Operand& option : op.options()) {
+ out->push_back(option.var().type());
+ }
+}
+
+void WriteSetAttrDirective(const OpSpec::Operand& attr, bool optional,
+ SourceWriter* writer) {
+ string var = optional ? "opts." + attr.var().name() : attr.var().name();
+ if (attr.iterable()) {
+ const Type& type = attr.data_type();
+ std::map<string, Type>::const_iterator it =
+ kPrimitiveAttrTypes.find(type.name());
+ if (it != kPrimitiveAttrTypes.end()) {
+ string array = attr.var().name() + "Array";
+ writer->AppendType(it->second)
+ .Append("[] " + array + " = new ")
+ .AppendType(it->second)
+ .Append("[" + var + ".size()];")
+ .EndLine();
+ writer->BeginBlock("for (int i = 0; i < " + array + ".length; ++i)")
+ .Append(array + "[i] = " + var + ".get(i);")
+ .EndLine()
+ .EndBlock()
+ .Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + array)
+ .Append(");")
+ .EndLine();
} else {
- result += c;
+ writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", " + var)
+ .Append(".toArray(new ")
+ .AppendType(type)
+ .Append("[" + var + ".size()]));")
+ .EndLine();
}
+ } else {
+ Type type = attr.var().type();
+ writer->Append("opBuilder.setAttr(\"" + attr.graph_name() + "\", ");
+ if (type.name() == "Class") {
+ writer->Append("DataType.fromClass(" + attr.var().name() + "));");
+ } else {
+ writer->Append(var + ");");
+ }
+ writer->EndLine();
}
- return result;
}
-} // namespace
+void RenderFactoryMethod(const OpSpec& op, const Type& op_class,
+ SourceWriter* writer) {
+ Method factory = Method::Create("create", op_class);
+ Javadoc factory_doc = Javadoc::Create(
+ "Factory method to create a class to wrap a new " + op_class.name()
+ + " operation to the graph.");
+ Variable scope =
+ Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
+ factory.add_argument(scope);
+ factory_doc.add_param_tag(scope.name(), "Current graph scope");
+ for (const OpSpec::Operand& input : op.inputs()) {
+ factory.add_argument(input.var());
+ factory_doc.add_param_tag(input.var().name(), input.description());
+ }
+ for (const OpSpec::Operand& attribute : op.attributes()) {
+ factory.add_argument(attribute.var());
+ factory_doc.add_param_tag(attribute.var().name(), attribute.description());
+ }
+ if (!op.options().empty()) {
+ factory.add_argument(Variable::Varargs("options", Type::Class("Options")));
+ factory_doc.add_param_tag("options", "carries optional attributes values");
+ }
+ factory_doc.add_tag("return", "a new instance of " + op_class.name());
+ writer->BeginMethod(factory, PUBLIC|STATIC, &factory_doc);
+ writer->Append("OperationBuilder opBuilder = scope.graph().opBuilder(\""
+ + op.graph_name() + "\", scope.makeOpName(\""
+ + op_class.name() + "\"));");
+ writer->EndLine();
-OpGenerator::OpGenerator() : env(Env::Default()) {}
+ for (const OpSpec::Operand& input : op.inputs()) {
+ if (input.iterable()) {
+ writer->Append("opBuilder.addInputList(Operands.asOutputs("
+ + input.var().name() + "));");
+ writer->EndLine();
+ } else {
+ writer->Append("opBuilder.addInput(" + input.var().name()
+ + ".asOutput());");
+ writer->EndLine();
+ }
+ }
+ for (const OpSpec::Operand& attribute : op.attributes()) {
+ WriteSetAttrDirective(attribute, false, writer);
+ }
+ if (!op.options().empty()) {
+ writer->BeginBlock("if (options != null)")
+ .BeginBlock("for (Options opts : options)");
+ for (const OpSpec::Operand& option : op.options()) {
+ writer->BeginBlock("if (opts." + option.var().name() + " != null)");
+ WriteSetAttrDirective(option, true, writer);
+ writer->EndBlock();
+ }
+ writer->EndBlock().EndBlock();
+ }
+ writer->Append("return new ")
+ .AppendType(op_class)
+ .Append("(opBuilder.build());")
+ .EndLine();
+ writer->EndMethod();
+}
-OpGenerator::~OpGenerator() {}
+void RenderConstructor(const OpSpec& op, const Type& op_class,
+ SourceWriter* writer) {
+ Method constructor = Method::ConstructorFor(op_class)
+ .add_argument(
+ Variable::Create("operation",
+ Type::Class("Operation", "org.tensorflow")));
+ for (const OpSpec::Operand& output : op.outputs()) {
+ if (output.iterable() && !output.data_type().unknown()) {
+ constructor.add_annotation(
+ Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
+ break;
+ }
+ }
+ writer->BeginMethod(constructor, PRIVATE)
+ .Append("super(operation);")
+ .EndLine();
+ if (op.outputs().size() > 0) {
+ writer->Append("int outputIdx = 0;")
+ .EndLine();
+ for (const OpSpec::Operand& output : op.outputs()) {
+ if (output.iterable()) {
+ string var_length = output.var().name() + "Length";
+ writer->Append("int " + var_length)
+ .Append(" = operation.outputListLength(\"" + output.graph_name()
+ + "\");")
+ .EndLine()
+ .Append(output.var().name() + " = Arrays.asList(");
+ if (!output.data_type().unknown()) {
+ writer->Append("(")
+ .AppendType(output.var().type().parameters().front())
+ .Append("[])");
+ }
+ writer->Append("operation.outputList(outputIdx, " + var_length + "));")
+ .EndLine()
+ .Append("outputIdx += " + var_length + ";")
+ .EndLine();
+ } else {
+ writer->Append(output.var().name()
+ + " = operation.output(outputIdx++);")
+ .EndLine();
+ }
+ }
+ }
+ writer->EndMethod();
+}
-Status OpGenerator::Run(const OpList& ops, const string& lib_name,
- const string& base_package, const string& output_dir) {
- const string package =
- base_package + '.' + str_util::StringReplace(lib_name, "_", "", true);
- const string package_path =
- output_dir + '/' + str_util::StringReplace(package, ".", "/", true);
- const string group = CamelCase(lib_name, '_', false);
+void RenderGettersAndSetters(const OpSpec& op, SourceWriter* writer) {
+ for (const OpSpec::Operand& option : op.options()) {
+ Method setter = Method::Create(option.var().name(), Type::Class("Options"))
+ .add_argument(option.var());
+ Javadoc setter_doc = Javadoc::Create()
+ .add_param_tag(option.var().name(), option.description());
+ writer->BeginMethod(setter, PUBLIC|STATIC, &setter_doc)
+ .Append("return new Options()." + option.var().name() + "("
+ + option.var().name() + ");")
+ .EndLine()
+ .EndMethod();
+ }
+ for (const OpSpec::Operand& output : op.outputs()) {
+ Method getter = Method::Create(output.var().name(), output.var().type());
+ Javadoc getter_doc = Javadoc::Create(output.description());
+ writer->BeginMethod(getter, PUBLIC, &getter_doc)
+ .Append("return " + output.var().name() + ";")
+ .EndLine()
+ .EndMethod();
+ }
+}
+
+void RenderInterfaceImpl(const OpSpec& op, RenderMode mode,
+ SourceWriter* writer) {
+ OpSpec::Operand output = op.outputs().front();
+
+ if (mode == SINGLE_OUTPUT) {
+ bool cast2obj = output.data_type().unknown();
+ Type return_type = Type::Class("Output", "org.tensorflow")
+ .add_parameter(cast2obj ? Type::Class("Object") : output.data_type());
+ Method as_output = Method::Create("asOutput", return_type)
+ .add_annotation(Annotation::Create("Override"));
+ if (cast2obj) {
+ as_output.add_annotation(
+ Annotation::Create("SuppressWarnings").attributes("\"unchecked\""));
+ }
+ writer->BeginMethod(as_output, PUBLIC);
+ if (cast2obj) {
+ writer->Append("return (").AppendType(return_type).Append(") ");
+ } else {
+ writer->Append("return ");
+ }
+ writer->Append(output.var().name() + ";")
+ .EndLine()
+ .EndMethod();
+
+ } else if (mode == SINGLE_LIST_OUTPUT) {
+ Type operand = Type::Interface("Operand", "org.tensorflow");
+ if (output.data_type().unknown()) {
+ operand.add_parameter(Type::Class("Object"));
+ } else {
+ operand.add_parameter(output.data_type());
+ }
+ Type return_type = Type::Interface("Iterator", "java.util")
+ .add_parameter(operand);
+ Method iterator = Method::Create("iterator", return_type)
+ .add_annotation(Annotation::Create("Override"))
+ .add_annotation(Annotation::Create("SuppressWarnings")
+ .attributes("{\"rawtypes\", \"unchecked\"}"));
+ // cast the output list using a raw List
+ writer->BeginMethod(iterator, PUBLIC)
+ .Append("return (" + return_type.name() + ") ")
+ .Append(output.var().name() + ".iterator();")
+ .EndLine()
+ .EndMethod();
+ }
+}
+
+void RenderOptionsClass(const OpSpec& op, SourceWriter* writer) {
+ Type options_class = Type::Class("Options");
+ Javadoc options_doc = Javadoc::Create(
+ "Class holding optional attributes of this operation");
+ writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc);
+ for (const OpSpec::Operand& option : op.options()) {
+ Method setter = Method::Create(option.var().name(), options_class)
+ .add_argument(option.var());
+ Javadoc setter_doc = Javadoc::Create()
+ .add_param_tag(option.var().name(), option.description());
+ writer->BeginMethod(setter, PUBLIC, &setter_doc)
+ .Append("this." + option.var().name() + " = " + option.var().name()
+ + ";")
+ .EndLine()
+ .Append("return this;")
+ .EndLine()
+ .EndMethod();
+ }
+ writer->EndLine();
+ for (const OpSpec::Operand& option : op.options()) {
+ writer->WriteField(option.var(), PRIVATE);
+ }
+ Method constructor = Method::ConstructorFor(options_class);
+ writer->BeginMethod(constructor, PRIVATE).EndMethod();
+ writer->EndType();
+}
- if (!env->FileExists(package_path).ok()) {
- TF_CHECK_OK(env->RecursivelyCreateDir(package_path));
+void RenderEndpoint(const OpSpec& op, const OpSpec::Endpoint& endpoint,
+ SourceWriter* writer) {
+ RenderMode mode = DEFAULT;
+ if (op.outputs().size() == 1) {
+ mode = op.outputs().front().iterable() ? SINGLE_LIST_OUTPUT : SINGLE_OUTPUT;
+ }
+ std::list<Type> dependencies;
+ CollectOpDependencies(op, mode, &dependencies);
+ const Type& op_class = endpoint.type();
+ writer->WriteFromFile(kLicenseSnippet)
+ .EndLine()
+ .Append("// This file is machine generated, DO NOT EDIT!")
+ .EndLine()
+ .EndLine()
+ .BeginType(op_class, PUBLIC|FINAL, &dependencies, &endpoint.javadoc());
+ if (!op.options().empty()) {
+ RenderOptionsClass(op, writer);
}
+ RenderFactoryMethod(op, op_class, writer);
+ RenderGettersAndSetters(op, writer);
+ if (mode != DEFAULT) {
+ RenderInterfaceImpl(op, mode, writer);
+ }
+ writer->EndLine();
+ for (const OpSpec::Operand& output : op.outputs()) {
+ writer->WriteField(output.var(), PRIVATE);
+ }
+ RenderConstructor(op, op_class, writer);
+ writer->EndType();
+}
+
+} // namespace
+
+OpGenerator::OpGenerator(const string& base_package, const string& output_dir,
+ const std::vector<string>& api_dirs, Env* env)
+ : base_package_(base_package), output_dir_(output_dir), api_dirs_(api_dirs),
+ env_(env) {
+}
+Status OpGenerator::Run(const OpList& op_list, const string& lib_name) {
LOG(INFO) << "Generating Java wrappers for '" << lib_name << "' operations";
- // TODO(karllessard) generate wrappers from list of ops
+ ApiDefMap api_map(op_list);
+ if (!api_dirs_.empty()) {
+ // Only load api files that correspond to the requested "op_list"
+ for (const auto& op : op_list.op()) {
+ for (const auto& api_def_dir : api_dirs_) {
+ const std::string api_def_file_pattern =
+ io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt");
+ if (env_->FileExists(api_def_file_pattern).ok()) {
+ TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern));
+ }
+ }
+ }
+ }
+ api_map.UpdateDocs();
+ for (const auto& op_def : op_list.op()) {
+ const ApiDef* api_def = api_map.GetApiDef(op_def.name());
+ if (api_def->visibility() != ApiDef::SKIP) {
+ Status status = GenerateOp(op_def, *api_def, lib_name);
+ if (status != Status::OK()) {
+ LOG(ERROR) << "Fail to generate Java wrapper for operation \""
+ << op_def.name() << "\"";
+ }
+ }
+ }
+ return Status::OK();
+}
+
+Status OpGenerator::GenerateOp(const OpDef& op_def, const ApiDef& api_def,
+ const string& lib_name) {
+ std::unique_ptr<OpSpec> op;
+ OpParser op_parser(op_def, api_def, lib_name, base_package_);
+ op_parser.Parse(&op);
+ for (const OpSpec::Endpoint& endpoint : op->endpoints()) {
+ string package_path = io::JoinPath(output_dir_,
+ str_util::StringReplace(endpoint.type().package(), ".", "/", true));
+ if (!env_->FileExists(package_path).ok()) {
+ TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(package_path));
+ }
+ string file_path =
+ io::JoinPath(package_path, endpoint.type().name() + ".java");
+ std::unique_ptr<tensorflow::WritableFile> file;
+ TF_CHECK_OK(env_->NewWritableFile(file_path, &file));
+ SourceFileWriter writer(file.get());
+ RenderEndpoint(*op, endpoint, &writer);
+ }
return Status::OK();
}
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index 4b55ed3ed9..19d8db95fb 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -17,34 +17,42 @@ limitations under the License.
#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_GENERATOR_H_
#include <string>
+#include <vector>
-#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
namespace java {
-/// \brief A generator of Java operation wrappers.
-///
-/// Such generator is normally ran only once per executable, outputting
-/// wrappers for the all registered operations it has been compiled with.
-/// Nonetheless, it is designed to support multiple runs, giving a different
-/// list of operations on each cycle.
+// A generator of Java operation wrappers.
+//
+// Such generator is normally ran only once per executable, outputting
+// wrappers for the all registered operations it has been compiled with.
+// Nonetheless, it is designed to support multiple runs, giving a different
+// list of operations on each cycle.
class OpGenerator {
public:
- OpGenerator();
- virtual ~OpGenerator();
+ OpGenerator(const string& base_package, const string& output_dir,
+ const std::vector<string>& api_dirs, Env* env = Env::Default());
+ virtual ~OpGenerator() = default;
- /// \brief Generates wrappers for the given list of 'ops'.
- ///
- /// Output files are generated in <output_dir>/<base_package>/<lib_package>,
- /// where 'lib_package' is derived from 'lib_name'.
- Status Run(const OpList& ops, const string& lib_name,
- const string& base_package, const string& output_dir);
+ // Generates wrappers for the given list of 'ops'.
+ //
+ // Output files are generated in <output_dir>/<base_package>/<lib_package>,
+ // where 'lib_package' is derived from 'lib_name'.
+ Status Run(const OpList& op_list, const string& lib_name);
private:
- Env* env;
+ string base_package_;
+ string output_dir_;
+ std::vector<string> api_dirs_;
+ Env* env_;
+
+ Status GenerateOp(const OpDef& op_def, const ApiDef& api_def,
+ const string& lib_name);
};
} // namespace java
diff --git a/tensorflow/java/src/gen/cc/op_parser.cc b/tensorflow/java/src/gen/cc/op_parser.cc
new file mode 100644
index 0000000000..0541e343d8
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_parser.cc
@@ -0,0 +1,417 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <map>
+#include <vector>
+#include <utility>
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/java/src/gen/cc/op_parser.h"
+
+namespace tensorflow {
+namespace java {
+namespace {
+
+string SnakeToCamelCase(const string& str, bool upper = false) {
+ string result;
+ bool cap = upper;
+ for (string::const_iterator it = str.begin(); it != str.end(); ++it) {
+ const char c = *it;
+ if (c == '_') {
+ cap = true;
+ } else if (cap) {
+ result += toupper(c);
+ cap = false;
+ } else {
+ result += c;
+ }
+ }
+ return result;
+}
+
+bool IsRealNumber(DataType type) {
+ for (DataType dt : RealNumberTypes()) {
+ if (type == dt) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool IsRealNumbers(const AttrValue& values) {
+ if (values.has_list()) {
+ for (int i = 0; i < values.list().type_size(); ++i) {
+ if (!IsRealNumber(values.list().type(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return IsRealNumber(values.type());
+}
+
+string ParseDocumentation(const string& text) {
+ std::stringstream javadoc_text;
+ string::const_iterator c_iter = text.cbegin();
+ bool code = false;
+ bool emphasis = false;
+ bool list = false;
+ while (c_iter != text.cend()) {
+ char c = *c_iter++;
+ int count = 1;
+ switch (c) {
+ case '\n':
+ if (!code) {
+ // consumes all subsequent newlines, if there are more than one,
+ // then there are two choices:
+ // - if the next line starts with an asterisk, we are enumerating
+ // a list of items
+ // - otherwise, we are starting a new paragraph
+ for (; c_iter != text.cend() && *c_iter == '\n'; ++count, ++c_iter) {}
+ if (c_iter != text.cend()) {
+ if (count > 1) {
+ if (*c_iter != '*' && list) {
+ javadoc_text << "</li>\n</ul>\n";
+ list = false;
+ } else if (*c_iter == '*' && !list) {
+ javadoc_text << "\n<ul>\n<li>";
+ list = true;
+ c_iter++;
+ } else {
+ javadoc_text << "\n<p>\n";
+ }
+ } else if (list && *c_iter == '*') {
+ javadoc_text << "</li>\n<li>";
+ c_iter++;
+ } else {
+ javadoc_text << '\n';
+ }
+ }
+ }
+ break;
+ case '`':
+ // consumes all subsequent backquotes, those are use enclose code.
+ // if there are more than 3, we are dealing with a pre-formatted block,
+ // otherwise it is a single-line code snippet
+ for (; c_iter != text.cend() && *c_iter == '`'; ++count, ++c_iter) {}
+ if (count >= 3) {
+ javadoc_text << (code ? "\n}</pre>" : "<pre>{@code\n");
+ } else {
+ javadoc_text << (code ? "}" : "{@code ");
+ }
+ code = !code;
+ break;
+ case '*':
+ if (!code) {
+ // consumes all subsequent asterisks, if there are more than one, then
+ // we put the text in bold, otherwise in italic
+ for (; c_iter != text.cend() && *c_iter == '*'; ++count, ++c_iter) {}
+ if (count > 1) {
+ javadoc_text << (emphasis ? "</b>" : "<b>");
+ } else {
+ javadoc_text << (emphasis ? "</i>" : "<i>");
+ }
+ emphasis = !emphasis;
+ } else {
+ javadoc_text << '*';
+ }
+ break;
+ default:
+ javadoc_text << c;
+ break;
+ }
+ }
+ return javadoc_text.str();
+}
+
+} // namespace
+
+OpParser::OpParser(const OpDef& op_def, const ApiDef& api_def,
+ const string& lib_name, const string& base_package)
+ : op_def_(op_def), op_api_(api_def), lib_name_(lib_name),
+ base_package_(base_package) {
+}
+
+void OpParser::Parse(std::unique_ptr<OpSpec>* op_ptr) {
+ visited_attrs_.clear();
+ next_generic_ = 'T';
+ op_ptr->reset(new OpSpec(op_api_.graph_op_name()));
+ for (const string& next_input_name : op_api_.arg_order()) {
+ for (int i = 0; i < op_def_.input_arg().size(); ++i) {
+ if (op_def_.input_arg(i).name() == next_input_name) {
+ ParseInput(op_def_.input_arg(i), op_api_.in_arg(i), op_ptr->get());
+ break;
+ }
+ }
+ }
+ for (int i = 0; i < op_def_.attr().size(); ++i) {
+ ParseAttribute(op_def_.attr(i), op_api_.attr(i), op_ptr->get());
+ }
+ for (int i = 0; i < op_def_.output_arg().size(); ++i) {
+ ParseOutput(op_def_.output_arg(i), op_api_.out_arg(i), op_ptr->get());
+ }
+ BuildEndpoints(op_ptr->get());
+}
+
+void OpParser::BuildEndpoints(OpSpec* op) {
+ Javadoc op_doc = Javadoc::Create(ParseDocumentation(op_api_.summary()))
+ .details(ParseDocumentation(op_api_.description()));
+ std::vector<Type> op_supertypes;
+ op_supertypes.push_back(Type::Class("PrimitiveOp", "org.tensorflow.op"));
+ std::map<string, const Type*> op_generics;
+ for (const OpSpec::Operand& output : op->outputs()) {
+ // declare generic output parameters at the Op class level
+ const Type& data_type = output.data_type();
+ if (data_type.kind() == Type::GENERIC && !data_type.unknown()
+ && op_generics.find(data_type.name()) == op_generics.end()) {
+ op_generics.insert(std::make_pair(data_type.name(), &data_type));
+ op_doc.add_param_tag("<" + data_type.name() + ">",
+ "data type of output '" + output.var().name() + "'");
+ }
+ // implement the Op as an (iteration of) Operand if it has only one output
+ if (op->outputs().size() == 1) {
+ Type operand_inf(Type::Interface("Operand", "org.tensorflow"));
+ operand_inf.add_parameter(data_type.unknown() ?
+ Type::Class("Object") : data_type);
+ op_supertypes.push_back(output.iterable() ?
+ Type::IterableOf(operand_inf) : operand_inf);
+ }
+ }
+ for (const auto& endpoint_def : op_api_.endpoint()) {
+ std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
+ // if the endpoint specifies a package, use it, otherwise derive it from the
+ // op library name.
+ string name;
+ string package;
+ if (name_tokens.size() > 1) {
+ package = str_util::Lowercase(name_tokens.at(0));
+ name = name_tokens.at(1);
+ } else {
+ package = str_util::StringReplace(lib_name_, "_", "", true);
+ name = name_tokens.at(0);
+ }
+ Type endpoint(Type::Class(name, base_package_ + "." + package));
+ Javadoc endpoint_doc(op_doc);
+ for (const auto& parameter : op_generics) {
+ endpoint.add_parameter(*parameter.second);
+ }
+ for (const Type& supertype : op_supertypes) {
+ endpoint.add_supertype(supertype);
+ }
+ if (endpoint_def.deprecation_version() > 0) {
+ string explanation;
+ if (op_api_.endpoint(0).deprecation_version() == 0) {
+ explanation = ", use {@link "
+ + op->endpoints().at(0).type().full_name()
+ + "} instead";
+ } else {
+ explanation = op_def_.deprecation().explanation();
+ }
+ endpoint_doc.add_tag("deprecated", explanation);
+ endpoint.add_annotation(Annotation::Create("Deprecated"));
+ }
+ // only visible ops should be annotated for exposure in the Ops Graph API
+ if (op_api_.visibility() != ApiDef::HIDDEN) {
+ string group_name = SnakeToCamelCase(lib_name_);
+ endpoint.add_annotation(
+ Annotation::Create("Operator", "org.tensorflow.op.annotation")
+ .attributes("group = \"" + group_name + "\""));
+ }
+ op->add_endpoint(endpoint, endpoint_doc);
+ }
+}
+
+void OpParser::ParseInput(const OpDef_ArgDef& input_def,
+ const ApiDef::Arg& input_api, OpSpec* op) {
+ bool iterable = false;
+ Type data_type = DataTypeOf(input_def, &iterable);
+ Type type = Type::Interface("Operand", "org.tensorflow")
+ .add_parameter(data_type);
+ if (iterable) {
+ type = Type::IterableOf(type);
+ }
+ op->add_input(OpSpec::Operand(input_api.name(),
+ Variable::Create(SnakeToCamelCase(input_api.rename_to()), type),
+ data_type,
+ ParseDocumentation(input_api.description()),
+ iterable));
+}
+
+void OpParser::ParseOutput(const OpDef_ArgDef& output_def,
+ const ApiDef::Arg& output_api, OpSpec* op) {
+ bool iterable = false;
+ Type data_type = DataTypeOf(output_def, &iterable);
+ Type type = Type::Class("Output", "org.tensorflow")
+ .add_parameter(data_type);
+ if (iterable) {
+ type = Type::ListOf(type);
+ }
+ op->add_output(OpSpec::Operand(output_api.name(),
+ Variable::Create(SnakeToCamelCase(output_api.rename_to()), type),
+ data_type,
+ ParseDocumentation(output_api.description()),
+ iterable));
+}
+
+void OpParser::ParseAttribute(const OpDef_AttrDef& attr_def,
+ const ApiDef::Attr& attr_api, OpSpec* op) {
+ // do not parse attributes already visited, they have probably been inferred
+ // before as an input argument type
+ if (visited_attrs_.find(attr_def.name()) != visited_attrs_.cend()) {
+ return;
+ }
+ bool iterable = false;
+ Type data_type = DataTypeOf(attr_def, &iterable);
+ // generic attributes should be passed as an explicit type
+ bool explicit_type = data_type.kind() == Type::GENERIC && !iterable;
+ Type type = explicit_type ?
+ Type::Class("Class").add_parameter(data_type) : data_type;
+ if (iterable) {
+ type = Type::ListOf(data_type);
+ }
+ OpSpec::Operand attr(attr_api.name(),
+ Variable::Create(SnakeToCamelCase(attr_api.rename_to()), type),
+ data_type,
+ ParseDocumentation(attr_api.description()),
+ iterable);
+ // attributes with a default value are optional
+ if (attr_api.has_default_value() && !explicit_type) {
+ op->add_option(attr);
+ } else {
+ op->add_attribute(attr);
+ }
+ visited_attrs_.insert(std::make_pair(attr_api.name(), data_type));
+}
+
+Type OpParser::DataTypeOf(const OpDef_ArgDef& arg, bool* iterable_out) {
+ if (!arg.number_attr().empty()) {
+ visited_attrs_.insert(std::make_pair(arg.number_attr(), Type::Int()));
+ *iterable_out = true;
+ }
+ if (arg.type() != DataType::DT_INVALID) {
+ // resolve type from DataType
+ switch (arg.type()) {
+ case DataType::DT_BOOL:
+ return Type::Class("Boolean");
+
+ case DataType::DT_STRING:
+ return Type::Class("String");
+
+ case DataType::DT_FLOAT:
+ return Type::Class("Float");
+
+ case DataType::DT_DOUBLE:
+ return Type::Class("Double");
+
+ case DataType::DT_UINT8:
+ return Type::Class("UInt8", "org.tensorflow.types");
+
+ case DataType::DT_INT32:
+ return Type::Class("Integer");
+
+ case DataType::DT_INT64:
+ return Type::Class("Long");
+
+ case DataType::DT_RESOURCE:
+ // TODO(karllessard) create a Resource utility class that could be
+ // used to store a resource and its type (passed in a second argument).
+ // For now, we need to force a wildcard and we will unfortunately lose
+ // track of the resource type.
+ return Type::Wildcard();
+
+ default:
+ break;
+ }
+ } else {
+ // resolve type from type attribute
+ string attr_name = arg.type_attr();
+ if (attr_name.empty()) {
+ attr_name = arg.type_list_attr();
+ if (!attr_name.empty()) {
+ *iterable_out = true;
+ Type type = Type::Wildcard();
+ visited_attrs_.insert(std::make_pair(attr_name, type));
+ return type;
+ }
+ }
+ for (const auto& attr : op_def_.attr()) {
+ if (attr.name() == attr_name) {
+ Type type = DataTypeOf(attr, iterable_out);
+ visited_attrs_.insert(std::make_pair(attr_name, type));
+ return type;
+ }
+ }
+ }
+ LOG(WARNING) << "Data type for arg \"" << arg.name() << "\" is unknown";
+ return Type::Wildcard();
+}
+
+Type OpParser::DataTypeOf(const OpDef_AttrDef& attr, bool* iterable_out) {
+ std::map<string, Type>::const_iterator it = visited_attrs_.find(attr.name());
+ if (it != visited_attrs_.cend()) {
+ return it->second;
+ }
+ string attr_type = attr.type();
+ if (attr.type().compare(0, 5, "list(") == 0) {
+ attr_type = attr_type.substr(5, attr.type().find_last_of(')') - 5);
+ *iterable_out = true;
+ }
+ if (attr_type == "type") {
+ if (*iterable_out) {
+ return Type::Enum("DataType", "org.tensorflow");
+ }
+ return GetNextGenericTensorType(attr.allowed_values());
+ }
+ if (attr_type == "string") {
+ return Type::Class("String");
+ }
+ if (attr_type == "int") {
+ return Type::Class("Integer");
+ }
+ if (attr_type == "float") {
+ return Type::Class("Float");
+ }
+ if (attr_type == "bool") {
+ return Type::Class("Boolean");
+ }
+ if (attr_type == "shape") {
+ return Type::Class("Shape", "org.tensorflow");
+ }
+ if (attr_type == "tensor") {
+ return Type::Class("Tensor", "org.tensorflow")
+ .add_parameter(Type::Wildcard());
+ }
+ LOG(WARNING) << "Data type for attribute \"" << attr_type << "\" is unknown";
+ return *iterable_out ? Type::Wildcard() : Type::Class("Object");
+}
+
+Type OpParser::GetNextGenericTensorType(const AttrValue& allowed_values) {
+ Type generic = Type::Generic(string(1, next_generic_));
+ next_generic_ = (next_generic_ == 'Z') ? 'A' : next_generic_ + 1;
+
+ // when only real numbers are allowed, enforce that restriction in the Java by
+ // extending the generic from java.lang.Number
+ if (IsRealNumbers(allowed_values)) {
+ generic.add_supertype(Type::Class("Number"));
+ }
+ return generic;
+}
+
+} // namespace java
+} // namespace tensorflow
diff --git a/tensorflow/java/src/gen/cc/op_parser.h b/tensorflow/java/src/gen/cc/op_parser.h
new file mode 100644
index 0000000000..42855127cc
--- /dev/null
+++ b/tensorflow/java/src/gen/cc/op_parser.h
@@ -0,0 +1,137 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
+#define TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/api_def.pb.h"
+#include "tensorflow/java/src/gen/cc/java_defs.h"
+
+namespace tensorflow {
+namespace java {
+
+// Specification of a TensorFlow operation to generate.
+//
+// This is the result of an operation definition parsing, see OpParser::Parse().
+class OpSpec {
+ public:
+ class Endpoint {
+ public:
+ Endpoint(const Type& type, const Javadoc& javadoc)
+ : type_(type), javadoc_(javadoc) {}
+ const Type& type() const { return type_; }
+ const Javadoc& javadoc() const { return javadoc_; }
+
+ private:
+ Type type_;
+ Javadoc javadoc_;
+ };
+
+ class Operand {
+ public:
+ Operand(const string& graph_name, const Variable& var,
+ const Type& data_type, const string& description, bool iterable)
+ : graph_name_(graph_name), var_(var), data_type_(data_type),
+ description_(description), iterable_(iterable) {}
+ const string& graph_name() const { return graph_name_; }
+ const Variable& var() const { return var_; }
+ Variable* var_ptr() { return &var_; }
+ const Type& data_type() const { return data_type_; }
+ const string& description() const { return description_; }
+ bool iterable() const { return iterable_; }
+
+ private:
+ string graph_name_;
+ Variable var_;
+ Type data_type_;
+ string description_;
+ bool iterable_;
+ };
+
+ explicit OpSpec(const string& graph_name) : graph_name_(graph_name) {}
+ const string& graph_name() const { return graph_name_; }
+ const std::vector<Endpoint> endpoints() const { return endpoints_; }
+ void add_endpoint(const Type& type, const Javadoc& javadoc) {
+ endpoints_.push_back(Endpoint(type, javadoc));
+ }
+ const std::vector<Operand>& inputs() const { return inputs_; }
+ void add_input(const Operand& input) {
+ inputs_.push_back(input);
+ }
+ const std::vector<Operand>& outputs() const { return outputs_; }
+ void add_output(const Operand& output) {
+ outputs_.push_back(output);
+ }
+ const std::vector<Operand>& attributes() const { return attributes_; }
+ void add_attribute(const Operand& attribute) {
+ attributes_.push_back(attribute);
+ }
+ const std::vector<Operand>& options() const { return options_; }
+ void add_option(const Operand& option) {
+ options_.push_back(option);
+ }
+
+ private:
+ string graph_name_;
+ std::vector<Endpoint> endpoints_;
+ std::vector<Operand> inputs_;
+ std::vector<Operand> outputs_;
+ std::vector<Operand> attributes_;
+ std::vector<Operand> options_;
+};
+
+// A parser of ops proto definitions.
+//
+// This object parses the definition and the api of an TensorFlow operation to
+// produce a specification that can be used for Java source code rendering.
+class OpParser {
+ public:
+ OpParser(const OpDef& op_def, const ApiDef& api_def, const string& lib_name,
+ const string& base_package);
+ virtual ~OpParser() = default;
+
+ // Produces an operation specification from its proto definitions.
+ void Parse(std::unique_ptr<OpSpec>* op_ptr);
+
+ private:
+ OpDef op_def_;
+ ApiDef op_api_;
+ string lib_name_;
+ string base_package_;
+ std::map<string, Type> visited_attrs_;
+ char next_generic_ = 0;
+
+ void BuildEndpoints(OpSpec* op);
+ void ParseInput(const OpDef_ArgDef& input_def,
+ const ApiDef::Arg& input_api, OpSpec* op);
+ void ParseOutput(const OpDef_ArgDef& output_def,
+ const ApiDef::Arg& output_api, OpSpec* op);
+ void ParseAttribute(const OpDef_AttrDef& attr_def,
+ const ApiDef::Attr& attr_api, OpSpec* op);
+ Type DataTypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
+ Type DataTypeOf(const OpDef_AttrDef& attr_def, bool *iterable_out);
+ Type GetNextGenericTensorType(const AttrValue& allowed_values);
+};
+
+} // namespace java
+} // namespace tensorflow
+
+#endif // TENSORFLOW_JAVA_SRC_GEN_CC_OP_PARSER_H_
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index a02f75ad6e..b1de5af6ba 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include <string>
#include <algorithm>
-#include <deque>
+#include <list>
#include "tensorflow/java/src/gen/cc/source_writer.h"
@@ -83,20 +83,20 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) {
}
SourceWriter& SourceWriter::AppendType(const Type& type) {
- if (type.kind() == Type::Kind::GENERIC && type.name().empty()) {
+ if (type.unknown()) {
Append("?");
} else {
Append(type.name());
- }
- if (!type.parameters().empty()) {
- Append("<");
- for (const Type& t : type.parameters()) {
- if (&t != &type.parameters().front()) {
- Append(", ");
+ if (!type.parameters().empty()) {
+ Append("<");
+ for (const Type& t : type.parameters()) {
+ if (&t != &type.parameters().front()) {
+ Append(", ");
+ }
+ AppendType(t);
}
- AppendType(t);
+ Append(">");
}
- Append(">");
}
return *this;
}
@@ -107,7 +107,21 @@ SourceWriter& SourceWriter::EndLine() {
return *this;
}
-SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
+SourceWriter& SourceWriter::BeginBlock(const string& expression) {
+ if (!expression.empty()) {
+ Append(expression + " {");
+ } else {
+ Append(newline_ ? "{" : " {");
+ }
+ return EndLine().Indent(2);
+}
+
+SourceWriter& SourceWriter::EndBlock() {
+ return Indent(-2).Append("}").EndLine();
+}
+
+SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers,
+ const Javadoc* javadoc) {
GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
if (!method.constructor()) {
generic_namespace->Visit(method.return_type());
@@ -116,8 +130,9 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers) {
generic_namespace->Visit(v.type());
}
EndLine();
- WriteDoc(method.description(), method.return_description(),
- &method.arguments());
+ if (javadoc != nullptr) {
+ WriteJavadoc(*javadoc);
+ }
if (!method.annotations().empty()) {
WriteAnnotations(method.annotations());
}
@@ -145,29 +160,35 @@ SourceWriter& SourceWriter::EndMethod() {
return *this;
}
-SourceWriter& SourceWriter::BeginType(const Type& type,
- const std::list<Type>* dependencies, int modifiers) {
+SourceWriter& SourceWriter::BeginType(const Type& type, int modifiers,
+ const std::list<Type>* extra_dependencies, const Javadoc* javadoc) {
if (!type.package().empty()) {
Append("package ").Append(type.package()).Append(";").EndLine();
}
- if (dependencies != nullptr && !dependencies->empty()) {
- TypeImporter type_importer(type.package());
- for (const Type& t : *dependencies) {
+ TypeImporter type_importer(type.package());
+ type_importer.Visit(type);
+ if (extra_dependencies != nullptr) {
+ for (const Type& t : *extra_dependencies) {
type_importer.Visit(t);
}
+ }
+ if (!type_importer.imports().empty()) {
EndLine();
for (const string& s : type_importer.imports()) {
Append("import ").Append(s).Append(";").EndLine();
}
}
- return BeginInnerType(type, modifiers);
+ return BeginInnerType(type, modifiers, javadoc);
}
-SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers) {
+SourceWriter& SourceWriter::BeginInnerType(const Type& type, int modifiers,
+ const Javadoc* javadoc) {
GenericNamespace* generic_namespace = PushGenericNamespace(modifiers);
generic_namespace->Visit(type);
EndLine();
- WriteDoc(type.description());
+ if (javadoc != nullptr) {
+ WriteJavadoc(*javadoc);
+ }
if (!type.annotations().empty()) {
WriteAnnotations(type.annotations());
}
@@ -200,14 +221,15 @@ SourceWriter& SourceWriter::EndType() {
return *this;
}
-SourceWriter& SourceWriter::WriteFields(const std::list<Variable>& fields,
- int modifiers) {
- EndLine();
- for (const Variable& v : fields) {
- WriteModifiers(modifiers);
- AppendType(v.type()).Append(" ").Append(v.name()).Append(";");
- EndLine();
+SourceWriter& SourceWriter::WriteField(const Variable& field, int modifiers,
+ const Javadoc* javadoc) {
+ // If present, write field javadoc only as one brief line
+ if (javadoc != nullptr && !javadoc->brief().empty()) {
+ Append("/** ").Append(javadoc->brief()).Append(" */").EndLine();
}
+ WriteModifiers(modifiers);
+ AppendType(field.type()).Append(" ").Append(field.name()).Append(";");
+ EndLine();
return *this;
}
@@ -228,39 +250,33 @@ SourceWriter& SourceWriter::WriteModifiers(int modifiers) {
return *this;
}
-SourceWriter& SourceWriter::WriteDoc(const string& description,
- const string& return_description, const std::list<Variable>* parameters) {
- if (description.empty() && return_description.empty()
- && (parameters == nullptr || parameters->empty())) {
- return *this; // no doc to write
- }
+SourceWriter& SourceWriter::WriteJavadoc(const Javadoc& javadoc) {
+ Append("/**").Prefix(" * ").EndLine();
bool do_line_break = false;
- Append("/**").EndLine().Prefix(" * ");
- if (!description.empty()) {
- Write(description).EndLine();
+ if (!javadoc.brief().empty()) {
+ Write(javadoc.brief()).EndLine();
do_line_break = true;
}
- if (parameters != nullptr && !parameters->empty()) {
+ if (!javadoc.details().empty()) {
if (do_line_break) {
- EndLine();
- do_line_break = false;
- }
- for (const Variable& v : *parameters) {
- Append("@param ").Append(v.name());
- if (!v.description().empty()) {
- Append(" ").Write(v.description());
- }
- EndLine();
+ Append("<p>").EndLine();
}
+ Write(javadoc.details()).EndLine();
+ do_line_break = true;
}
- if (!return_description.empty()) {
+ if (!javadoc.tags().empty()) {
if (do_line_break) {
EndLine();
- do_line_break = false;
}
- Append("@return ").Write(return_description).EndLine();
+ for (const auto& p : javadoc.tags()) {
+ Append("@" + p.first);
+ if (!p.second.empty()) {
+ Append(" ").Write(p.second);
+ }
+ EndLine();
+ }
}
- return Prefix("").Append(" **/").EndLine();
+ return Prefix("").Append(" */").EndLine();
}
SourceWriter& SourceWriter::WriteAnnotations(
@@ -311,20 +327,19 @@ void SourceWriter::PopGenericNamespace() {
void SourceWriter::TypeVisitor::Visit(const Type& type) {
DoVisit(type);
for (const Type& t : type.parameters()) {
- DoVisit(t);
+ Visit(t);
}
for (const Annotation& t : type.annotations()) {
DoVisit(t);
}
for (const Type& t : type.supertypes()) {
- DoVisit(t);
+ Visit(t);
}
}
void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
// ignore non-generic parameters, wildcards and generics already declared
- if (type.kind() == Type::GENERIC
- && !type.IsWildcard()
+ if (type.kind() == Type::GENERIC && !type.unknown()
&& generic_names_.find(type.name()) == generic_names_.end()) {
declared_types_.push_back(&type);
generic_names_.insert(type.name());
@@ -333,7 +348,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) {
void SourceWriter::TypeImporter::DoVisit(const Type& type) {
if (!type.package().empty() && type.package() != current_package_) {
- imports_.insert(type.package() + '.' + type.name());
+ imports_.insert(type.full_name());
}
}
diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h
index f011acd30a..1f0febe9a3 100644
--- a/tensorflow/java/src/gen/cc/source_writer.h
+++ b/tensorflow/java/src/gen/cc/source_writer.h
@@ -93,25 +93,22 @@ class SourceWriter {
// This method appends a new opening brace to the current data and indent the
// next lines according to Google Java Style Guide. The block can optionally
// be preceded by an expression (e.g. Append("if(true)").BeginBlock();)
- SourceWriter& BeginBlock() {
- return Append(newline_ ? "{" : " {").EndLine().Indent(2);
- }
+ SourceWriter& BeginBlock(const string& expr = "");
// Ends the current block of source code.
//
// This method appends a new closing brace to the current data and outdent the
// next lines back to the margin used before BeginBlock() was invoked.
- SourceWriter& EndBlock() {
- return Indent(-2).Append("}").EndLine();
- }
+ SourceWriter& EndBlock();
// Begins to write a method.
//
// This method outputs the signature of the Java method from the data passed
- // in the 'method' parameter and starts a new block. Additionnal modifiers can
- // also be passed in parameter to define the accesses and the scope of this
- // method.
- SourceWriter& BeginMethod(const Method& method, int modifiers = 0);
+ // in the 'method' parameter and starts a new block. Modifiers are also passed
+ // in parameter to define the access scope of this method and, optionally,
+ // a Javadoc.
+ SourceWriter& BeginMethod(const Method& method, int modifiers,
+ const Javadoc* javadoc = nullptr);
// Ends the current method.
//
@@ -122,22 +119,24 @@ class SourceWriter {
// Begins to write the main type of a source file.
//
// This method outputs the declaration of the Java type from the data passed
- // in the 'type' parameter and starts a new block. Additionnal modifiers can
- // also be passed in parameter to define the accesses and the scope of this
- // type.
+ // in the 'type' parameter and starts a new block. Modifiers are also passed
+ // in parameter to define the access scope of this type and, optionally,
+ // a Javadoc.
//
- // If not null, all types found in the 'dependencies' list will be imported
- // before declaring the new type.
- SourceWriter& BeginType(const Type& clazz,
- const std::list<Type>* dependencies, int modifiers = 0);
+ // If not null, all types found in the 'extra_dependencies' list will be
+ // imported before declaring the new type.
+ SourceWriter& BeginType(const Type& clazz, int modifiers,
+ const std::list<Type>* extra_dependencies = nullptr,
+ const Javadoc* javadoc = nullptr);
// Begins to write a new inner type.
//
// This method outputs the declaration of the Java type from the data passed
- // in the 'type' parameter and starts a new block. Additionnal modifiers can
- // also be passed in parameter to define the accesses and the scope of this
- // type.
- SourceWriter& BeginInnerType(const Type& type, int modifiers = 0);
+ // in the 'type' parameter and starts a new block. Modifiers are also passed
+ // in parameter to define the accesses and the scope of this type and,
+ // optionally, a Javadoc.
+ SourceWriter& BeginInnerType(const Type& type, int modifiers,
+ const Javadoc* javadoc = nullptr);
// Ends the current type.
//
@@ -145,13 +144,13 @@ class SourceWriter {
// BeginType() or BeginInnerType() prior to this.
SourceWriter& EndType();
- // Writes a list of variables as fields of a type.
+ // Writes a variable as fields of a type.
//
// This method must be called within the definition of a type (see BeginType()
- // or BeginInnerType()). Additional modifiers can also be passed in parameter
- // to define the accesses and the scope of those fields.
- SourceWriter& WriteFields(const std::list<Variable>& fields,
- int modifiers = 0);
+ // or BeginInnerType()). Modifiers are also be passed in parameter to define
+ // the accesses and the scope of this field and, optionally, a Javadoc.
+ SourceWriter& WriteField(const Variable& field, int modifiers,
+ const Javadoc* javadoc = nullptr);
protected:
virtual void DoAppend(const StringPiece& str) = 0;
@@ -207,9 +206,7 @@ class SourceWriter {
std::stack<GenericNamespace*> generic_namespaces_;
SourceWriter& WriteModifiers(int modifiers);
- SourceWriter& WriteDoc(const string& description,
- const string& return_description = "",
- const std::list<Variable>* parameters = nullptr);
+ SourceWriter& WriteJavadoc(const Javadoc& javadoc);
SourceWriter& WriteAnnotations(const std::list<Annotation>& annotations);
SourceWriter& WriteGenerics(const std::list<const Type*>& generics);
GenericNamespace* PushGenericNamespace(int modifiers);
diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc
index 4bce2fea70..8bd42d9d0e 100644
--- a/tensorflow/java/src/gen/cc/source_writer_test.cc
+++ b/tensorflow/java/src/gen/cc/source_writer_test.cc
@@ -250,7 +250,7 @@ TEST(StreamTest, Types) {
.AppendType(generic).Append(", ")
.AppendType(Type::ListOf(generic)).Append(", ")
.AppendType(Type::ListOf(Type::IterableOf(generic))).Append(", ")
- .AppendType(Type::ListOf(Type::Generic()));
+ .AppendType(Type::ListOf(Type::Wildcard()));
const char* expected =
"int, String, T, List<T>, List<Iterable<T>>, List<?>";
@@ -282,7 +282,7 @@ TEST(WriteType, SimpleClass) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -300,7 +300,7 @@ TEST(WriteType, SimpleClassWithDependencies) {
deps.push_back(Type::Class("SamePackageType", "org.tensorflow"));
deps.push_back(Type::Class("NoPackageType"));
- writer.BeginType(clazz, &deps, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC, &deps).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -313,18 +313,21 @@ TEST(WriteType, SimpleClassWithDependencies) {
TEST(WriteType, AnnotatedAndDocumentedClass) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
- clazz.description("This class has a\n<p>\nmultiline description.");
+ Javadoc clazz_doc;
+ clazz_doc.brief("Javadoc test")
+ .details("This is a\nmultiline description.");
clazz.add_annotation(Annotation::Create("Bean"));
clazz.add_annotation(Annotation::Create("SuppressWarnings")
.attributes("\"rawtypes\""));
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC, nullptr, &clazz_doc).EndType();
const char* expected =
"package org.tensorflow;\n\n"
"/**\n"
- " * This class has a\n"
+ " * Javadoc test\n"
" * <p>\n"
+ " * This is a\n"
" * multiline description.\n"
" **/\n"
"@Bean\n"
@@ -339,7 +342,7 @@ TEST(WriteType, ParameterizedClass) {
clazz.add_parameter(Type::Generic("T"));
clazz.add_parameter(Type::Generic("U").add_supertype(Type::Class("Number")));
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -358,7 +361,7 @@ TEST(WriteType, ParameterizedClassAndSupertypes) {
clazz.add_supertype(Type::Interface("Runnable"));
clazz.add_supertype(Type::Class("SuperTest").add_parameter(type_t));
- writer.BeginType(clazz, nullptr, PUBLIC).EndType();
+ writer.BeginType(clazz, PUBLIC).EndType();
const char* expected =
"package org.tensorflow;\n\n"
@@ -372,24 +375,24 @@ TEST(WriteType, ParameterizedClassFields) {
Type clazz = Type::Class("Test", "org.tensorflow");
Type type_t = Type::Generic("T").add_supertype(Type::Class("Number"));
clazz.add_parameter(type_t);
- std::list<Variable> static_fields;
- static_fields.push_back(Variable::Create("field1", Type::Class("String")));
- std::list<Variable> member_fields;
- member_fields.push_back(Variable::Create("field2", Type::Class("String")));
- member_fields.push_back(Variable::Create("field3", type_t));
-
- writer.BeginType(clazz, nullptr, PUBLIC)
- .WriteFields(static_fields, STATIC | PUBLIC | FINAL)
- .WriteFields(member_fields, PRIVATE)
+ Variable field1 = Variable::Create("field1", Type::Class("String"));
+ Variable field2 = Variable::Create("field2", Type::Class("String"));
+ Variable field3 = Variable::Create("field3", type_t);
+ Javadoc field3_doc;
+ field3_doc.brief("This variable is documented");
+
+ writer.BeginType(clazz, PUBLIC)
+ .WriteField(field1, STATIC | PUBLIC | FINAL)
+ .WriteField(field2, PRIVATE)
+ .WriteField(field3, PRIVATE, &field3_doc)
.EndType();
const char* expected =
"package org.tensorflow;\n\n"
"public class Test<T extends Number> {\n"
- " \n"
" public static final String field1;\n"
- " \n"
" private String field2;\n"
+ " /** This variable is documented */\n"
" private T field3;\n"
"}\n";
ASSERT_STREQ(expected, writer.str().data());
@@ -400,7 +403,7 @@ TEST(WriteType, SimpleInnerClass) {
Type clazz = Type::Class("Test", "org.tensorflow");
Type inner_class = Type::Class("InnerTest");
- writer.BeginType(clazz, nullptr, PUBLIC)
+ writer.BeginType(clazz, PUBLIC)
.BeginInnerType(inner_class, PUBLIC)
.EndType()
.EndType();
@@ -423,7 +426,7 @@ TEST(WriteType, StaticParameterizedInnerClass) {
Type inner_class = Type::Class("InnerTest");
inner_class.add_parameter(type_t);
- writer.BeginType(clazz, nullptr, PUBLIC)
+ writer.BeginType(clazz, PUBLIC)
.BeginInnerType(inner_class, PUBLIC | STATIC)
.EndType()
.EndType();
@@ -443,7 +446,7 @@ TEST(WriteMethod, SimpleMethod) {
Type clazz = Type::Class("Test", "org.tensorflow");
Method method = Method::Create("doNothing", Type::Void());
- writer.BeginType(clazz, nullptr, PUBLIC)
+ writer.BeginType(clazz, PUBLIC)
.BeginMethod(method, PUBLIC).EndMethod()
.EndType();
@@ -461,13 +464,15 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
Method method = Method::Create("doNothing", Type::Void());
- method.description("This method has a\n<p>\nmultiline description.");
+ Javadoc method_doc;
+ method_doc.brief("Javadoc test")
+ .details("This method has a\nmultiline description.");
method.add_annotation(Annotation::Create("Override"));
method.add_annotation(Annotation::Create("SuppressWarnings")
.attributes("\"rawtypes\""));
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC).EndMethod()
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC, &method_doc).EndMethod()
.EndType();
const char* expected =
@@ -475,8 +480,9 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
"public class Test {\n"
" \n"
" /**\n"
- " * This method has a\n"
+ " * Javadoc test\n"
" * <p>\n"
+ " * This method has a\n"
" * multiline description.\n"
" **/\n"
" @Override\n"
@@ -490,16 +496,18 @@ TEST(WriteMethod, AnnotatedAndDocumentedMethod) {
TEST(WriteMethod, DocumentedMethodWithArguments) {
SourceBufferWriter writer;
Type clazz = Type::Class("Test", "org.tensorflow");
+ Variable reverse = Variable::Create("reverse", Type::Boolean());
Method method = Method::Create("boolToInt", Type::Int());
- method.description("Converts a boolean to an int");
- method.return_description("int value for this boolean");
method.add_argument(Variable::Create("b", Type::Boolean()));
- Variable reverse = Variable::Create("reverse", Type::Boolean());
- reverse.description("if true, value is reversed");
method.add_argument(reverse);
-
- writer.BeginType(clazz, nullptr, PUBLIC)
- .BeginMethod(method, PUBLIC)
+ Javadoc method_doc;
+ method_doc.brief("Converts a boolean to an int")
+ .details("This method will convert\na boolean to an int")
+ .add_param_tag(reverse.name(), "if true, value is reversed")
+ .add_tag("return", "int value for this boolean");
+
+ writer.BeginType(clazz, PUBLIC)
+ .BeginMethod(method, PUBLIC, &method_doc)
.Append("if (b && !reverse)")
.BeginBlock()
.Append("return 1;").EndLine()
@@ -514,8 +522,10 @@ TEST(WriteMethod, DocumentedMethodWithArguments) {
" \n"
" /**\n"
" * Converts a boolean to an int\n"
+ " * <p>\n"
+ " * This method will convert\n"
+ " * a boolean to an int\n"
" * \n"
- " * @param b\n"
" * @param reverse if true, value is reversed\n"
" * @return int value for this boolean\n"
" **/\n"
@@ -536,7 +546,7 @@ TEST(WriteMethod, ParameterizedMethod) {
clazz.add_parameter(type_t);
Method method = Method::Create("doNothing", type_t);
- writer.BeginType(clazz, nullptr, PUBLIC)
+ writer.BeginType(clazz, PUBLIC)
.BeginMethod(method, PUBLIC)
.Append("return null;").EndLine()
.EndMethod()
@@ -560,7 +570,7 @@ TEST(WriteMethod, StaticParameterizedMethod) {
clazz.add_parameter(type_t);
Method method = Method::Create("doNothing", type_t);
- writer.BeginType(clazz, nullptr, PUBLIC)
+ writer.BeginType(clazz, PUBLIC)
.BeginMethod(method, PUBLIC | STATIC)
.Append("return null;").EndLine()
.EndMethod()
diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl
index a6650fc4ea..1e7899cf7a 100644
--- a/tensorflow/java/src/gen/gen_ops.bzl
+++ b/tensorflow/java/src/gen/gen_ops.bzl
@@ -1,9 +1,11 @@
# -*- Python -*-
-load("//tensorflow:tensorflow.bzl",
- "tf_binary_additional_srcs",
- "tf_cc_binary",
- "tf_copts")
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_binary_additional_srcs",
+ "tf_cc_binary",
+ "tf_copts",
+)
# Given a list of "ops_libs" (a list of files in the core/ops directory
# without their .cc extensions), generate Java wrapper code for all operations
@@ -27,16 +29,31 @@ def tf_java_op_gen_srcjar(name,
ops_libs_pkg="//tensorflow/core",
out_dir="ops/",
out_src_dir="src/main/java/",
+ api_def_srcs=[],
visibility=["//tensorflow/java:__pkg__"]):
gen_tools = []
gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files
+ srcs = api_def_srcs[:]
# Construct an op generator binary for each ops library.
for ops_lib in ops_libs:
gen_lib = ops_lib[:ops_lib.rfind("_")]
out_gen_tool = out_dir + ops_lib + "_gen_tool"
+ if not api_def_srcs:
+ api_def_args_str = ","
+ else:
+ api_def_args = []
+ for api_def_src in api_def_srcs:
+ # Add directory of the first ApiDef source to args.
+ # We are assuming all ApiDefs in a single api_def_src are in the
+ # same directory.
+ api_def_args.append(
+ " $$(dirname $$(echo $(locations " + api_def_src +
+ ") | cut -d\" \" -f1))")
+ api_def_args_str = ",".join(api_def_args)
+
tf_cc_binary(
name=out_gen_tool,
copts=tf_copts(),
@@ -48,7 +65,8 @@ def tf_java_op_gen_srcjar(name,
gen_cmds += ["$(location :" + out_gen_tool + ")" +
" --output_dir=$(@D)/" + out_src_dir +
" --lib_name=" + gen_lib +
- " --base_package=" + gen_base_package]
+ " --base_package=" + gen_base_package +
+ " " + api_def_args_str]
# Generate a source archive containing generated code for these ops.
gen_srcjar = out_dir + name + ".srcjar"
@@ -57,6 +75,7 @@ def tf_java_op_gen_srcjar(name,
gen_tools += tf_binary_additional_srcs()
native.genrule(
name=name,
+ srcs=srcs,
outs=[gen_srcjar],
tools=gen_tools,
cmd="&&".join(gen_cmds))
diff --git a/tensorflow/java/src/gen/resources/license.snippet.java b/tensorflow/java/src/gen/resources/license.snippet.java
new file mode 100644
index 0000000000..90285ec669
--- /dev/null
+++ b/tensorflow/java/src/gen/resources/license.snippet.java
@@ -0,0 +1,14 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/