aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 15:18:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-28 15:18:06 -0700
commit85631dce2b91585a3d44f7b78db85ed3eba55a48 (patch)
tree41b40de3ecfe56221ecd3abc828049e2705635ea /tensorflow/java/src
parentaf94082f0bb8d665cf20aaa5f86b207e176dd4e0 (diff)
parenta053d7ba69ce9c42f6c854f20aa565407de8c3f7 (diff)
Merge pull request #21616 from karllessard:java-ops-default-type-attrs
PiperOrigin-RevId: 210615110
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h30
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc74
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc42
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h14
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc1
5 files changed, 118 insertions, 43 deletions
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index d9d6f8adc8..d39653ef41 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <string>
#include <utility>
+#include "tensorflow/core/framework/types.h"
+
namespace tensorflow {
namespace java {
@@ -95,6 +97,34 @@ class Type {
static Type IterableOf(const Type& type) {
return Interface("Iterable").add_parameter(type);
}
+ static Type ForDataType(DataType data_type) {
+ switch (data_type) {
+ case DataType::DT_BOOL:
+ return Class("Boolean");
+ case DataType::DT_STRING:
+ return Class("String");
+ case DataType::DT_FLOAT:
+ return Class("Float");
+ case DataType::DT_DOUBLE:
+ return Class("Double");
+ case DataType::DT_UINT8:
+ return Class("UInt8", "org.tensorflow.types");
+ case DataType::DT_INT32:
+ return Class("Integer");
+ case DataType::DT_INT64:
+ return Class("Long");
+ case DataType::DT_RESOURCE:
+ // TODO(karllessard) create a Resource utility class that could be
+ // used to store a resource and its type (passed in a second argument).
+ // For now, we need to force a wildcard and we will unfortunately lose
+ // track of the resource type.
+ // Falling through...
+ default:
+ // Any other datatypes does not have a equivalent in Java and must
+ // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
+ return Wildcard();
+ }
+ }
const Kind& kind() const { return kind_; }
const string& name() const { return name_; }
const string& package() const { return package_; }
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index d5bd99bdd9..5d6387e88e 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <set>
#include <string>
+#include <utility>
#include <vector>
#include "tensorflow/core/framework/op_gen_lib.h"
@@ -100,6 +101,10 @@ void CollectOpDependencies(const OpSpec& op, RenderMode mode,
for (const AttributeSpec& attribute : op.attributes()) {
out->push_back(attribute.var().type());
out->push_back(attribute.jni_type());
+ if (attribute.has_default_value() &&
+ attribute.type().kind() == Type::GENERIC) {
+ out->push_back(Type::ForDataType(attribute.default_value()->type()));
+ }
}
for (const AttributeSpec& optional_attribute : op.optional_attributes()) {
out->push_back(optional_attribute.var().type());
@@ -139,6 +144,60 @@ void WriteSetAttrDirective(const AttributeSpec& attr, bool optional,
}
}
+void RenderSecondaryFactoryMethod(const OpSpec& op, const Type& op_class,
+ std::map<string, Type> default_types,
+ SourceWriter* writer) {
+ // Build the return type for the secondary factory, replacing generic
+ // parameters with their default value if any
+ Type return_type = Type::Class(op_class.name(), op_class.package());
+ for (const Type& parameter : op_class.parameters()) {
+ if (parameter.kind() == Type::GENERIC &&
+ default_types.find(parameter.name()) != default_types.end()) {
+ return_type.add_parameter(default_types.at(parameter.name()));
+ } else {
+ return_type.add_parameter(parameter);
+ }
+ }
+ Method factory = Method::Create("create", return_type);
+ Javadoc factory_doc = Javadoc::Create(
+ "Factory method to create a class to wrap a new " + op_class.name() +
+ " operation to the graph, using "
+ "default output types.");
+ Variable scope =
+ Variable::Create("scope", Type::Class("Scope", "org.tensorflow.op"));
+ AddArgument(scope, "current graph scope", &factory, &factory_doc);
+ std::stringstream factory_statement;
+ factory_statement << "return create(scope";
+ for (const ArgumentSpec& input : op.inputs()) {
+ AddArgument(input.var(), input.description(), &factory, &factory_doc);
+ factory_statement << ", " << input.var().name();
+ }
+ for (const AttributeSpec& attr : op.attributes()) {
+ // Only add attributes that are not types or have no default value to the
+ // signature of the secondary factory
+ factory_statement << ", ";
+ if (attr.type().kind() == Type::GENERIC &&
+ default_types.find(attr.type().name()) != default_types.end()) {
+ factory_statement << default_types.at(attr.type().name()).name()
+ << ".class";
+ } else {
+ AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
+ factory_statement << attr.var().name();
+ }
+ }
+ if (!op.optional_attributes().empty()) {
+ Variable options_var = Variable::Varargs("options", Type::Class("Options"));
+ AddArgument(options_var, "carries optional attributes values", &factory,
+ &factory_doc);
+ factory_statement << ", " << options_var.name();
+ }
+ factory_doc.add_tag("return", "a new instance of " + op_class.name());
+
+ writer->BeginMethod(factory, PUBLIC | STATIC, &factory_doc);
+ writer->Append(factory_statement.str().c_str()).Append(");").EndLine();
+ writer->EndMethod();
+}
+
void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
SourceWriter* writer) {
Method factory = Method::Create("create", op_class);
@@ -151,8 +210,17 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
for (const ArgumentSpec& input : op.inputs()) {
AddArgument(input.var(), input.description(), &factory, &factory_doc);
}
+ std::map<string, Type> default_types;
for (const AttributeSpec& attr : op.attributes()) {
AddArgument(attr.var(), attr.description(), &factory, &factory_doc);
+ // If this attribute is a type with a default value, save its value
+ // for passing it implicitly in a secondary factory method
+ if (attr.has_default_value() && attr.type().kind() == Type::GENERIC) {
+ Type default_type = Type::ForDataType(attr.default_value()->type());
+ if (!default_type.wildcard()) {
+ default_types.insert(std::make_pair(attr.type().name(), default_type));
+ }
+ }
}
if (!op.optional_attributes().empty()) {
AddArgument(Variable::Varargs("options", Type::Class("Options")),
@@ -194,6 +262,12 @@ void RenderFactoryMethods(const OpSpec& op, const Type& op_class,
.Append("(opBuilder.build());")
.EndLine();
writer->EndMethod();
+
+ // If this operation has type attributes with a default value, create a
+ // second factory method that infers those values implicitly
+ if (!default_types.empty()) {
+ RenderSecondaryFactoryMethod(op, op_class, default_types, writer);
+ }
}
void RenderConstructor(const OpSpec& op, const Type& op_class,
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
index 941ab2699c..4f5a491d25 100644
--- a/tensorflow/java/src/gen/cc/op_specs.cc
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -96,43 +96,10 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = true;
visited_attrs_.insert(std::make_pair(arg_def.number_attr(), Type::Int()));
}
-
Type type = Type::Wildcard();
if (arg_def.type() != DataType::DT_INVALID) {
- // resolve type from DataType
- switch (arg_def.type()) {
- case DataType::DT_BOOL:
- type = Type::Class("Boolean");
- break;
- case DataType::DT_STRING:
- type = Type::Class("String");
- break;
- case DataType::DT_FLOAT:
- type = Type::Class("Float");
- break;
- case DataType::DT_DOUBLE:
- type = Type::Class("Double");
- break;
- case DataType::DT_UINT8:
- type = Type::Class("UInt8", "org.tensorflow.types");
- break;
- case DataType::DT_INT32:
- type = Type::Class("Integer");
- break;
- case DataType::DT_INT64:
- type = Type::Class("Long");
- break;
- case DataType::DT_RESOURCE:
- // TODO(karllessard) create a Resource utility class that could be
- // used to store a resource and its type (passed in a second argument).
- // For now, we need to force a wildcard and we will unfortunately lose
- // track of the resource type.
- break;
- default:
- // Any other datatypes does not have a equivalent in Java and must
- // remain a wildcard (e.g. DT_COMPLEX64, DT_QINT8, ...)
- break;
- }
+ type = Type::ForDataType(arg_def.type());
+
} else if (!arg_def.type_attr().empty()) {
// resolve type from attribute (if already visited, retrieve its type)
if (IsAttributeVisited(arg_def.type_attr())) {
@@ -337,7 +304,7 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
Type var_type = types.first.kind() == Type::GENERIC
- ? Type::Class("Class").add_parameter(types.first)
+ ? Type::ClassOf(types.first)
: types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
@@ -346,7 +313,8 @@ AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
types.first, types.second, ParseDocumentation(attr_api_def.description()),
- iterable, attr_api_def.has_default_value());
+ iterable,
+ attr_def.has_default_value() ? &attr_def.default_value() : nullptr);
}
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
index 30ecb8ce53..4adcfca96a 100644
--- a/tensorflow/java/src/gen/cc/op_specs.h
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -94,18 +94,21 @@ class AttributeSpec {
// jni_type: the type of this attribute in JNI layer (see OperationBuilder)
// description: a description of this attribute, in javadoc
// iterable: true if this attribute is a list
- // has_default_value: true if this attribute has a default value if not set
+ // default_value: default value for this attribute or nullptr if none. Any
+ // value referenced by this pointer must outlive the lifetime
+ // of the AttributeSpec. This is guaranteed if the value is
+ // issued by an OpDef of the global OpRegistry.
AttributeSpec(const string& op_def_name, const Variable& var,
const Type& type, const Type& jni_type,
const string& description, bool iterable,
- bool has_default_value)
+ const AttrValue* default_value)
: op_def_name_(op_def_name),
var_(var),
type_(type),
description_(description),
iterable_(iterable),
jni_type_(jni_type),
- has_default_value_(has_default_value) {}
+ default_value_(default_value) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -113,7 +116,8 @@ class AttributeSpec {
const string& description() const { return description_; }
bool iterable() const { return iterable_; }
const Type& jni_type() const { return jni_type_; }
- bool has_default_value() const { return has_default_value_; }
+ bool has_default_value() const { return default_value_ != nullptr; }
+ const AttrValue* default_value() const { return default_value_; }
private:
const string op_def_name_;
@@ -122,7 +126,7 @@ class AttributeSpec {
const string description_;
const bool iterable_;
const Type jni_type_;
- const bool has_default_value_;
+ const AttrValue* default_value_;
};
class OpSpec {
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index 8e5fba7e32..a71b367691 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -16,7 +16,6 @@ limitations under the License.
#include <string>
#include <algorithm>
#include <list>
-#include <string>
#include "tensorflow/java/src/gen/cc/source_writer.h"