diff options
author | karl@kubx.ca <karl@kubx.ca> | 2018-05-03 22:39:35 -0400 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-05-03 22:40:44 -0400 |
commit | dd1ef8fa8f6861e53e8a7953c171b3e9253043ed (patch) | |
tree | 92fa94bb358d2a43ccabaa664abcb0abbc0d113b /tensorflow/java/src | |
parent | eac1479f04181fb107c85af29a709eb369831972 (diff) |
Second code review
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r-- | tensorflow/java/src/gen/cc/java_defs.h | 6 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_gen_main.cc | 2 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_generator.cc | 77 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_generator.h | 2 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_specs.cc | 25 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_specs.h | 17 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer.cc | 20 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer.h | 2 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer_test.cc | 2 | ||||
-rw-r--r-- | tensorflow/java/src/gen/gen_ops.bzl | 41 |
10 files changed, 108 insertions, 86 deletions
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index 81ac67eb2f..62575f6683 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -102,10 +102,10 @@ class Type { const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } - const string full_name() const { + const string canonical_name() const { return package_.empty() ? name_ : package_ + "." + name_; } - bool unknown() const { return name_.empty(); } // only wildcards has no name + bool wildcard() const { return name_.empty(); } // only wildcards has no name const std::list<Type>& parameters() const { return parameters_; } Type& add_parameter(const Type& parameter) { parameters_.push_back(parameter); diff --git a/tensorflow/java/src/gen/cc/op_gen_main.cc b/tensorflow/java/src/gen/cc/op_gen_main.cc index a508c96516..6c35cd9595 100644 --- a/tensorflow/java/src/gen/cc/op_gen_main.cc +++ b/tensorflow/java/src/gen/cc/op_gen_main.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 2327a4daf1..7355b3a395 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ limitations under the License. #include <list> #include <memory> #include <set> -#include <ctime> #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -39,13 +38,26 @@ namespace { const char* kLicenseSnippet = "tensorflow/java/src/gen/resources/license.java.snippet"; +// There is three different modes to render an op class, depending on the +// number and type of outputs it has: +// +// DEFAULT: This mode does not provide any specialization for the op class, it +// is applied when the operation does not comply with any other mode +// +// OPERAND: The op class implements the Operand<T> interface, allowing an +// instance to be passed directly in input to another operation +// +// LIST_OPERAND: The op class implements the Iterable<Operand<T>> interface, +// allowing an instance to be passed directly as a list input to +// another operation +// enum RenderMode { DEFAULT, - SINGLE_OUTPUT, - SINGLE_LIST_OUTPUT + OPERAND, + LIST_OPERAND }; -inline void AddArgument(const Variable& var, const string& description, +void AddArgument(const Variable& var, const string& description, Method* method_out, Javadoc* javadoc_out) { method_out->add_argument(var); javadoc_out->add_param_tag(var.name(), description); @@ -56,9 +68,9 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, out->push_back(Type::Class("Operation", "org.tensorflow")); out->push_back(Type::Class("OperationBuilder", "org.tensorflow")); out->push_back(Type::Class("Scope", "org.tensorflow.op")); - if (mode == SINGLE_OUTPUT) { + if (mode == OPERAND) { out->push_back(Type::Class("Output", "org.tensorflow")); - } else if (mode == SINGLE_LIST_OUTPUT) { + } else if (mode == LIST_OPERAND) { out->push_back(Type::Interface("Iterator", "java.util")); } // Don't pay attention to duplicate types in the dependency list, they will @@ -180,7 +192,7 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, Variable::Create("operation", Type::Class("Operation", "org.tensorflow")); Method constructor = Method::ConstructorFor(op_class).add_argument(operation); for (const ArgumentSpec& output : op.outputs()) { - if (output.iterable() && !output.type().unknown()) { + if (output.iterable() && !output.type().wildcard()) { constructor.add_annotation( Annotation::Create("SuppressWarnings").attributes("\"unchecked\"")); break; @@ -200,7 +212,7 @@ void RenderConstructor(const OpSpec& op, const Type& op_class, + "\");") .EndLine() .Append(output.var().name() + " = Arrays.asList("); - if (!output.type().unknown()) { + if (!output.type().wildcard()) { writer->Append("(") .AppendType(output.var().type().parameters().front()) .Append("[])"); @@ -245,8 +257,8 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, SourceWriter* writer) { ArgumentSpec output = op.outputs().front(); - if (mode == SINGLE_OUTPUT) { - bool cast2obj = output.type().unknown(); + if (mode == OPERAND) { + bool cast2obj = output.type().wildcard(); Type return_type = Type::Class("Output", "org.tensorflow") .add_parameter(cast2obj ? Type::Class("Object") : output.type()); Method as_output = Method::Create("asOutput", return_type) @@ -265,9 +277,9 @@ void RenderInterfaceImpl(const OpSpec& op, RenderMode mode, .EndLine() .EndMethod(); - } else if (mode == SINGLE_LIST_OUTPUT) { + } else if (mode == LIST_OPERAND) { Type operand = Type::Interface("Operand", "org.tensorflow"); - if (output.type().unknown()) { + if (output.type().wildcard()) { operand.add_parameter(Type::Class("Object")); } else { operand.add_parameter(output.type()); @@ -291,7 +303,7 @@ void RenderOptionsClass(const OpSpec& op, const Type& op_class, SourceWriter* writer) { Type options_class = Type::Class("Options"); Javadoc options_doc = Javadoc::Create( - "Optional attributes for {@link " + op_class.full_name() + "}"); + "Optional attributes for {@link " + op_class.canonical_name() + "}"); writer->BeginInnerType(options_class, PUBLIC | STATIC, &options_doc); for (const AttributeSpec& attr : op.optional_attributes()) { Method setter = Method::Create(attr.var().name(), options_class); @@ -319,8 +331,7 @@ inline Type ClassOf(const EndpointSpec& endpoint, const string& base_package) { } void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, - const string& base_package, const string& output_dir, Env* env, - const std::tm* timestamp) { + const string& base_package, const string& output_dir, Env* env) { Type op_class(ClassOf(endpoint, base_package) .add_supertype(Type::Class("PrimitiveOp", "org.tensorflow.op"))); Javadoc op_javadoc(endpoint.javadoc()); @@ -329,22 +340,22 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, RenderMode mode = DEFAULT; if (op.outputs().size() == 1) { const ArgumentSpec& output = op.outputs().front(); - Type operand_type(output.type().unknown() ? + Type operand_type(output.type().wildcard() ? Type::Class("Object") : output.type()); Type operand_inf(Type::Interface("Operand", "org.tensorflow") .add_parameter(operand_type)); if (output.iterable()) { - mode = SINGLE_LIST_OUTPUT; + mode = LIST_OPERAND; op_class.add_supertype(Type::IterableOf(operand_inf)); } else { - mode = SINGLE_OUTPUT; + mode = OPERAND; op_class.add_supertype(operand_inf); } } // op generic parameters std::set<string> generics; for (const ArgumentSpec& output : op.outputs()) { - if (output.type().kind() == Type::GENERIC && !output.type().unknown() + if (output.type().kind() == Type::GENERIC && !output.type().wildcard() && generics.find(output.type().name()) == generics.end()) { op_class.add_parameter(output.type()); op_javadoc.add_param_tag("<" + output.type().name() + ">", @@ -353,16 +364,15 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } } // op annotations - char date[20]; - strftime(date, sizeof date, "%FT%TZ", timestamp); - op_class.add_annotation(Annotation::Create("Generated", "javax.annotation") - .attributes(string("value = \"op_generator\", date = \"") + date + "\"")); + op_class.add_annotation( + Annotation::Create("Generated", "javax.annotation") + .attributes("value = \"TensorFlow Java Op Generator\"")); if (endpoint.deprecated()) { op_class.add_annotation(Annotation::Create("Deprecated")); string explanation; if (!op.endpoints().front().deprecated()) { explanation = "use {@link " + - ClassOf(op.endpoints().front(), base_package).full_name() + ClassOf(op.endpoints().front(), base_package).canonical_name() + "} instead"; } else { explanation = op.deprecation_explanation(); @@ -376,14 +386,16 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, .attributes("group = \"" + endpoint.package() + "\"")); } // create op class file - string op_dir = io::JoinPath(output_dir, + const string op_dir_name = io::JoinPath(output_dir, str_util::StringReplace(op_class.package(), ".", "/", true)); - if (!env->FileExists(op_dir).ok()) { - TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir)); + if (!env->FileExists(op_dir_name).ok()) { + TF_CHECK_OK(Env::Default()->RecursivelyCreateDir(op_dir_name)) + << op_dir_name; } + const string op_file_name = op_class.name() + ".java"; std::unique_ptr<tensorflow::WritableFile> op_file; TF_CHECK_OK(env->NewWritableFile( - io::JoinPath(op_dir, op_class.name() + ".java"), &op_file)); + io::JoinPath(op_dir_name, op_file_name), &op_file)) << op_file_name; // render endpoint source code SourceFileWriter writer(op_file.get()); @@ -420,20 +432,19 @@ Status OpGenerator::Run(const OpList& op_list, const string& base_package, const std::string api_def_file_pattern = io::JoinPath(api_def_dir, "api_def_" + op.name() + ".pbtxt"); if (env_->FileExists(api_def_file_pattern).ok()) { - TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)); + TF_CHECK_OK(api_map.LoadFile(env_, api_def_file_pattern)) + << api_def_file_pattern; } } } } api_map.UpdateDocs(); - time_t now; - time(&now); for (const auto& op_def : op_list.op()) { const ApiDef* api_def = api_map.GetApiDef(op_def.name()); if (api_def->visibility() != ApiDef::SKIP) { OpSpec op(OpSpec::Create(op_def, *api_def)); for (const EndpointSpec& endpoint : op.endpoints()) { - GenerateOp(op, endpoint, base_package, output_dir, env_, gmtime(&now)); + GenerateOp(op, endpoint, base_package, output_dir, env_); } } } diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index b789e11fa9..cfe842070a 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index dcc6388614..081062ceaf 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -45,9 +45,26 @@ class TypeResolver { public: explicit TypeResolver(const OpDef& op_def) : op_def_(op_def) {} + // Returns the class type of an input/output argument + // + // For example, if the argument's datatype is DT_STRING, this method will + // return "java.lang.String", so the argument can become "Operand<String>" + // in the Ops API Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); - std::pair<Type, Type> TypeOf(const OpDef_AttrDef& attr_def, + + // Returns types of an input attribute + // + // The first element of the pair is the class type of this attribute while + // the second is its JNI/primitive type equivalent, required for explicit + // unboxing. + // + // For example, if the attribute is of type "float", this method will return + // <java.lang.Float, float>, so the attribute can be used as a "Float" object + // in the Ops API and casted to a "float" when passing through the JNI layer. + std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def, bool *iterable_out); + + // Returns true if the type of this attribute has already been resolved bool IsAttributeVisited(const string& attr_name) { return visited_attrs_.find(attr_name) != visited_attrs_.cend(); } @@ -123,7 +140,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { for (const auto& attr_def : op_def_.attr()) { if (attr_def.name() == arg_def.type_attr()) { - type = TypeOf(attr_def, iterable_out).first; + type = TypesOf(attr_def, iterable_out).first; break; } } @@ -141,7 +158,7 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, return type; } -std::pair<Type, Type> TypeResolver::TypeOf(const OpDef_AttrDef& attr_def, +std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, bool* iterable_out) { std::pair<Type, Type> types = MakeTypePair(Type::Wildcard()); *iterable_out = false; @@ -319,7 +336,7 @@ ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { bool iterable = false; - std::pair<Type, Type> types = type_resolver->TypeOf(attr_def, &iterable); + std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable); Type var_type = types.first.kind() == Type::GENERIC ? Type::Class("Class").add_parameter(types.first) : types.first; if (iterable) { diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 7d64391446..81582ea207 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -65,7 +65,6 @@ class ArgumentSpec { const Type& type, const string& description, bool iterable) : op_def_name_(op_def_name), var_(var), type_(type), description_(description), iterable_(iterable) {} - virtual ~ArgumentSpec() = default; const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -81,7 +80,7 @@ class ArgumentSpec { const bool iterable_; }; -class AttributeSpec : public ArgumentSpec { +class AttributeSpec { public: // A specification for an operation attribute // @@ -95,14 +94,24 @@ class AttributeSpec : public ArgumentSpec { AttributeSpec(const string& op_def_name, const Variable& var, const Type& type, const Type& jni_type, const string& description, bool iterable, bool has_default_value) - : ArgumentSpec(op_def_name, var, type, description, iterable), + : op_def_name_(op_def_name), var_(var), type_(type), + description_(description), iterable_(iterable), jni_type_(jni_type), has_default_value_(has_default_value) {} - virtual ~AttributeSpec() = default; + const string& op_def_name() const { return op_def_name_; } + const Variable& var() const { return var_; } + const Type& type() const { return type_; } + const string& description() const { return description_; } + bool iterable() const { return iterable_; } const Type& jni_type() const { return jni_type_; } bool has_default_value() const { return has_default_value_; } private: + const string op_def_name_; + const Variable var_; + const Type type_; + const string description_; + const bool iterable_; const Type jni_type_; const bool has_default_value_; }; diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 7e427787f9..56806cbb6d 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -83,17 +83,19 @@ SourceWriter& SourceWriter::Append(const StringPiece& str) { } SourceWriter& SourceWriter::AppendType(const Type& type) { - if (type.unknown()) { + if (type.wildcard()) { Append("?"); } else { Append(type.name()); if (!type.parameters().empty()) { Append("<"); + bool first = true; for (const Type& t : type.parameters()) { - if (&t != &type.parameters().front()) { + if (!first) { Append(", "); } AppendType(t); + first = false; } Append(">"); } @@ -145,11 +147,13 @@ SourceWriter& SourceWriter::BeginMethod(const Method& method, int modifiers, AppendType(method.return_type()).Append(" "); } Append(method.name()).Append("("); + bool first = true; for (const Variable& v : method.arguments()) { - if (&v != &method.arguments().front()) { + if (!first) { Append(", "); } AppendType(v.type()).Append(v.variadic() ? "... " : " ").Append(v.name()); + first = false; } return Append(")").BeginBlock(); } @@ -294,14 +298,16 @@ SourceWriter& SourceWriter::WriteAnnotations( SourceWriter& SourceWriter::WriteGenerics( const std::list<const Type*>& generics) { Append("<"); + bool first = true; for (const Type* pt : generics) { - if (pt != generics.front()) { + if (!first) { Append(", "); } Append(pt->name()); if (!pt->supertypes().empty()) { Append(" extends ").AppendType(pt->supertypes().front()); } + first = false; } return Append(">"); } @@ -339,7 +345,7 @@ void SourceWriter::TypeVisitor::Visit(const Type& type) { void SourceWriter::GenericNamespace::DoVisit(const Type& type) { // ignore non-generic parameters, wildcards and generics already declared - if (type.kind() == Type::GENERIC && !type.unknown() + if (type.kind() == Type::GENERIC && !type.wildcard() && generic_names_.find(type.name()) == generic_names_.end()) { declared_types_.push_back(&type); generic_names_.insert(type.name()); @@ -348,7 +354,7 @@ void SourceWriter::GenericNamespace::DoVisit(const Type& type) { void SourceWriter::TypeImporter::DoVisit(const Type& type) { if (!type.package().empty() && type.package() != current_package_) { - imports_.insert(type.full_name()); + imports_.insert(type.canonical_name()); } } diff --git a/tensorflow/java/src/gen/cc/source_writer.h b/tensorflow/java/src/gen/cc/source_writer.h index bcae33ccce..1f0febe9a3 100644 --- a/tensorflow/java/src/gen/cc/source_writer.h +++ b/tensorflow/java/src/gen/cc/source_writer.h @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/cc/source_writer_test.cc b/tensorflow/java/src/gen/cc/source_writer_test.cc index 875ad99ae2..b9a5fee9be 100644 --- a/tensorflow/java/src/gen/cc/source_writer_test.cc +++ b/tensorflow/java/src/gen/cc/source_writer_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/tensorflow/java/src/gen/gen_ops.bzl b/tensorflow/java/src/gen/gen_ops.bzl index 7017b52649..f4ff34ea03 100644 --- a/tensorflow/java/src/gen/gen_ops.bzl +++ b/tensorflow/java/src/gen/gen_ops.bzl @@ -3,33 +3,26 @@ load( "//tensorflow:tensorflow.bzl", "tf_binary_additional_srcs", - "tf_cc_binary", - "tf_copts", ) -# Given a list of "ops_libs" (a list of files in the core/ops directory -# without their .cc extensions), generate Java wrapper code for all operations -# found in the ops files. -# Then, combine all those source files into a single archive (.srcjar). +# Generate Java wrapper classes for all registered core operations and package +# them into a single source archive (.srcjar). # # For example: -# tf_java_op_gen_srcjar("gen_sources", "gen_tool", "my.package", [ "array_ops", "math_ops" ]) +# tf_java_op_gen_srcjar("gen_sources", ":gen_tool", "my.package") # -# will create a genrule named "gen_sources" that first generate source files: -# ops/src/main/java/my/package/array/*.java -# ops/src/main/java/my/package/math/*.java +# will create a genrule named "gen_sources" that generates source files under +# ops/src/main/java/my/package/**/*.java # -# and then archive those source files in: +# and then archive those source files into # ops/gen_sources.srcjar # def tf_java_op_gen_srcjar(name, gen_tool, - gen_base_package, - ops_libs=[], - ops_libs_pkg="//tensorflow/core", + base_package, + api_def_srcs=[], out_dir="ops/", out_src_dir="src/main/java/", - api_def_srcs=[], visibility=["//tensorflow/java:__pkg__"]): gen_cmds = ["rm -rf $(@D)"] # Always start from fresh when generating source files @@ -48,23 +41,9 @@ def tf_java_op_gen_srcjar(name, ") | cut -d\" \" -f1))") api_def_args_str = ",".join(api_def_args) - gen_tool_deps = [":java_op_gen_lib"] - for ops_lib in ops_libs: - gen_tool_deps.append(ops_libs_pkg + ":" + ops_lib + "_op_lib") - - tf_cc_binary( - name=gen_tool, - srcs=[ - "src/gen/cc/op_gen_main.cc", - ], - copts=tf_copts(), - linkopts=["-lm"], - linkstatic=1, # Faster to link this one-time-use binary dynamically - deps = gen_tool_deps) - - gen_cmds += ["$(location :" + gen_tool + ")" + + gen_cmds += ["$(location " + gen_tool + ")" + " --output_dir=$(@D)/" + out_src_dir + - " --base_package=" + gen_base_package + + " --base_package=" + base_package + " --api_dirs=" + api_def_args_str] # Generate a source archive containing generated code for these ops. |