diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 15:18:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-28 15:18:06 -0700 |
commit | 85631dce2b91585a3d44f7b78db85ed3eba55a48 (patch) | |
tree | 41b40de3ecfe56221ecd3abc828049e2705635ea /tensorflow/java/src | |
parent | af94082f0bb8d665cf20aaa5f86b207e176dd4e0 (diff) | |
parent | a053d7ba69ce9c42f6c854f20aa565407de8c3f7 (diff) |
Merge pull request #21616 from karllessard:java-ops-default-type-attrs
PiperOrigin-RevId: 210615110
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r-- | tensorflow/java/src/gen/cc/java_defs.h | 30 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_generator.cc | 74 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_specs.cc | 42 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_specs.h | 14 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/source_writer.cc | 1 |
5 files changed, 118 insertions, 43 deletions
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index d9d6f8adc8..d39653ef41 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -21,6 +21,8 @@ limitations under the License. #include <string> #include <utility> +#include "tensorflow/core/framework/types.h" + namespace tensorflow { namespace java { @@ -95,6 +97,34 @@ class Type { static Type IterableOf(const Type& type) { return Interface("Iterable").add_parameter(type); } + static Type ForDataType(DataType data_type) { + switch (data_type) { + case DataType::DT_BOOL: + return Class("Boolean"); + case DataType::DT_STRING: + return Class("String"); + case DataType::DT_FLOAT: + return Class("Float"); + case DataType::DT_DOUBLE: + return Class("Double"); + case DataType::DT_UINT8: + return Class("UInt8", "org.tensorflow.types"); + case DataType::DT_INT32: + return Class("Integer"); + case DataType::DT_INT64: + return Class("Long"); + case DataType::DT_RESOURCE: + // TODO(karllessard) create a Resource utility class that could be + // used to store a resource and its type (passed in a second argument). + // For now, we need to force a wildcard and we will unfortunately lose + // track of the resource type. + // Falling through... + default: + // Any other datatypes does not have a equivalent in Java and must + // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...) + return Wildcard(); + } + } const Kind& kind() const { return kind_; } const string& name() const { return name_; } const string& package() const { return package_; } diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index d5bd99bdd9..5d6387e88e 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <set> #include <string> +#include <utility> #include <vector> #include "tensorflow/core/framework/op_gen_lib.h" @@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode, for (const AttributeSpec& attribute : op.attributes()) { out->push_back(attribute.var().type()); out->push_back(attribute.jni_type()); + if (attribute.has_default_value() && + attribute.type().kind() == Type::GENERIC) { + out->push_back(Type::ForDataType(attribute.default_value()->type())); + } } for (const AttributeSpec& optional_attribute : op.optional_attributes()) { out->push_back(optional_attribute.var().type()); @@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional, } } +void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class, + std::map<string, Type> default_types, + SourceWriter* writer) { + // Build the return type for the secondary factory, replacing generic + // parameters with their default value if any + Type return_type = Type::Class(op_class.name(), op_class.package()); + for (const Type& parameter : op_class.parameters()) { + if (parameter.kind() == Type::GENERIC && + default_types.find(parameter.name()) != default_types.end()) { + return_type.add_parameter(default_types.at(parameter.name())); + } else { + return_type.add_parameter(parameter); + } + } + Method factory = Method::Create("create", return_type); + Javadoc factory_doc = Javadoc::Create( + "Factory method to create a class to wrap a new " + op_class.name() + + " operation to the graph, using " + "default output types."); + Variable scope = + Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op")); + AddArgument(scope, "current graph scope", &factory, &factory_doc); + std::stringstream factory_statement; + factory_statement << "return create(scope"; + for (const ArgumentSpec& input : op.inputs()) { + AddArgument(input.var(), input.description(), &factory, &factory_doc); + factory_statement << ", " << input.var().name(); + } + for (const AttributeSpec& attr : op.attributes()) { + // Only add attributes that are not types or have no default value to the + // signature of the secondary factory + factory_statement << ", "; + if (attr.type().kind() == Type::GENERIC && + default_types.find(attr.type().name()) != default_types.end()) { + factory_statement << default_types.at(attr.type().name()).name() + << ".class"; + } else { + AddArgument(attr.var(), attr.description(), &factory, &factory_doc); + factory_statement << attr.var().name(); + } + } + if (!op.optional_attributes().empty()) { + Variable options_var = Variable::Varargs("options", Type::Class("Options")); + AddArgument(options_var, "carries optional attributes values", &factory, + &factory_doc); + factory_statement << ", " << options_var.name(); + } + factory_doc.add_tag("return", "a new instance of " + op_class.name()); + + writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc); + writer->Append(factory_statement.str().c_str()).Append(");").EndLine(); + writer->EndMethod(); +} + void RenderFactoryMethods(const OpSpec& op, const Type& op_class, SourceWriter* writer) { Method factory = Method::Create("create", op_class); @@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, for (const ArgumentSpec& input : op.inputs()) { AddArgument(input.var(), input.description(), &factory, &factory_doc); } + std::map<string, Type> default_types; for (const AttributeSpec& attr : op.attributes()) { AddArgument(attr.var(), attr.description(), &factory, &factory_doc); + // If this attribute is a type with a default value, save its value + // for passing it implicitly in a secondary factory method + if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) { + Type default_type = Type::ForDataType(attr.default_value()->type()); + if (!default_type.wildcard()) { + default_types.insert(std::make_pair(attr.type().name(), default_type)); + } + } } if (!op.optional_attributes().empty()) { AddArgument(Variable::Varargs("options", Type::Class("Options")), @@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class, .Append("(opBuilder.build());") .EndLine(); writer->EndMethod(); + + // If this operation has type attributes with a default value, create a + // second factory method that infers those values implicitly + if (!default_types.empty()) { + RenderSecondaryFactoryMethod(op, op_class, default_types, writer); + } } void RenderConstructor(const OpSpec& op, const Type& op_class, diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 941ab2699c..4f5a491d25 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -96,43 +96,10 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) { *iterable_out = true; visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int())); } - Type type = Type::Wildcard(); if (arg_def.type() != DataType::DT_INVALID) { - // resolve type from DataType - switch (arg_def.type()) { - case DataType::DT_BOOL: - type = Type::Class("Boolean"); - break; - case DataType::DT_STRING: - type = Type::Class("String"); - break; - case DataType::DT_FLOAT: - type = Type::Class("Float"); - break; - case DataType::DT_DOUBLE: - type = Type::Class("Double"); - break; - case DataType::DT_UINT8: - type = Type::Class("UInt8", "org.tensorflow.types"); - break; - case DataType::DT_INT32: - type = Type::Class("Integer"); - break; - case DataType::DT_INT64: - type = Type::Class("Long"); - break; - case DataType::DT_RESOURCE: - // TODO(karllessard) create a Resource utility class that could be - // used to store a resource and its type (passed in a second argument). - // For now, we need to force a wildcard and we will unfortunately lose - // track of the resource type. - break; - default: - // Any other datatypes does not have a equivalent in Java and must - // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...) - break; - } + type = Type::ForDataType(arg_def.type()); + } else if (!arg_def.type_attr().empty()) { // resolve type from attribute (if already visited, retrieve its type) if (IsAttributeVisited(arg_def.type_attr())) { @@ -337,7 +304,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, bool iterable = false; std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable); Type var_type = types.first.kind() == Type::GENERIC - ? Type::Class("Class").add_parameter(types.first) + ? Type::ClassOf(types.first) : types.first; if (iterable) { var_type = Type::ListOf(var_type); @@ -346,7 +313,8 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, attr_api_def.name(), Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), types.first, types.second, ParseDocumentation(attr_api_def.description()), - iterable, attr_api_def.has_default_value()); + iterable, + attr_def.has_default_value() ? &attr_def.default_value() : nullptr); } ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 30ecb8ce53..4adcfca96a 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -94,18 +94,21 @@ class AttributeSpec { // jni_type: the type of this attribute in JNI layer (see OperationBuilder) // description: a description of this attribute, in javadoc // iterable: true if this attribute is a list - // has_default_value: true if this attribute has a default value if not set + // default_value: default value for this attribute or nullptr if none. Any + // value referenced by this pointer must outlive the lifetime + // of the AttributeSpec. This is guaranteed if the value is + // issued by an OpDef of the global OpRegistry. AttributeSpec(const string& op_def_name, const Variable& var, const Type& type, const Type& jni_type, const string& description, bool iterable, - bool has_default_value) + const AttrValue* default_value) : op_def_name_(op_def_name), var_(var), type_(type), description_(description), iterable_(iterable), jni_type_(jni_type), - has_default_value_(has_default_value) {} + default_value_(default_value) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -113,7 +116,8 @@ class AttributeSpec { const string& description() const { return description_; } bool iterable() const { return iterable_; } const Type& jni_type() const { return jni_type_; } - bool has_default_value() const { return has_default_value_; } + bool has_default_value() const { return default_value_ != nullptr; } + const AttrValue* default_value() const { return default_value_; } private: const string op_def_name_; @@ -122,7 +126,7 @@ class AttributeSpec { const string description_; const bool iterable_; const Type jni_type_; - const bool has_default_value_; + const AttrValue* default_value_; }; class OpSpec { diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 8e5fba7e32..a71b367691 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -16,7 +16,6 @@ limitations under the License. #include <string> #include <algorithm> #include <list> -#include <string> #include "tensorflow/java/src/gen/cc/source_writer.h" |