aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-06-29 13:50:37 -0700
committerGravatar Michael Case <mikecase@google.com>2018-06-29 13:50:37 -0700
commit7be4245d629510ed3d1c2edd7a2598167017f33b (patch)
tree6585d143160249ae282044790d4589145a79efa2 /tensorflow/java
parent01c36c3d7b3e230c865e71d67e138a8dc765e7a6 (diff)
parent79dab9ced650d69bdf3f312bd902bd52de5bdad8 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc29
-rw-r--r--tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java296
2 files changed, 175 insertions, 150 deletions
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index 2df69ee299..d5bd99bdd9 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -36,20 +36,21 @@ namespace java {
namespace {
constexpr const char kLicense[] =
- "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
- "\n"
- "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
- "you may not use this file except in compliance with the License.\n"
- "You may obtain a copy of the License at\n"
- "\n"
- " http://www.apache.org/licenses/LICENSE-2.0\n"
- "\n"
- "Unless required by applicable law or agreed to in writing, software\n"
- "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
- "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
- "See the License for the specific language governing permissions and\n"
- "limitations under the License.\n"
- "=======================================================================*/\n";
+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+ "\n"
+ "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ "you may not use this file except in compliance with the License.\n"
+ "You may obtain a copy of the License at\n"
+ "\n"
+ " http://www.apache.org/licenses/LICENSE-2.0\n"
+ "\n"
+ "Unless required by applicable law or agreed to in writing, software\n"
+ "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ "See the License for the specific language governing permissions and\n"
+ "limitations under the License.\n"
+ "=======================================================================*/"
+ "\n";
// There is three different modes to render an op class, depending on the
// number and type of outputs it has:
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
index 3524160d87..796d6a62dc 100644
--- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
+++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
@@ -15,6 +15,18 @@ limitations under the License.
package org.tensorflow.processor;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import com.squareup.javapoet.ClassName;
+import com.squareup.javapoet.FieldSpec;
+import com.squareup.javapoet.JavaFile;
+import com.squareup.javapoet.MethodSpec;
+import com.squareup.javapoet.ParameterSpec;
+import com.squareup.javapoet.TypeName;
+import com.squareup.javapoet.TypeSpec;
+import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
@@ -23,7 +35,6 @@ import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
-
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
@@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter;
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;
-import com.google.common.base.CaseFormat;
-import com.google.common.base.Strings;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.Multimap;
-import com.squareup.javapoet.ClassName;
-import com.squareup.javapoet.FieldSpec;
-import com.squareup.javapoet.JavaFile;
-import com.squareup.javapoet.MethodSpec;
-import com.squareup.javapoet.ParameterSpec;
-import com.squareup.javapoet.TypeName;
-import com.squareup.javapoet.TypeSpec;
-import com.squareup.javapoet.TypeVariableName;
-
/**
* A compile-time Processor that aggregates classes annotated with {@link
* org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
@@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor {
// generated our code, flag the location of each such class.
if (hasRun) {
for (Element e : annotated) {
- error(e, "The Operator processor has already processed @Operator annotated sources\n" +
- "and written out an Ops API. It cannot process additional @Operator sources.\n" +
- "One reason this can happen is if other annotation processors generate\n" +
- "new @Operator source files.");
+ error(
+ e,
+ "The Operator processor has already processed @Operator annotated sources\n"
+ + "and written out an Ops API. It cannot process additional @Operator sources.\n"
+ + "One reason this can happen is if other annotation processors generate\n"
+ + "new @Operator source files.");
}
return true;
}
@@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor {
return Collections.singleton("org.tensorflow.op.annotation.Operator");
}
- private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
+ private static final Pattern JAVADOC_TAG_PATTERN =
+ Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
- private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator");
+ private static final TypeName T_OPERATOR =
+ ClassName.get("org.tensorflow.op.annotation", "Operator");
private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
private static final TypeName T_STRING = ClassName.get(String.class);
@@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor {
private void write(TypeSpec spec) {
try {
- JavaFile.builder("org.tensorflow.op", spec)
- .skipJavaLangImports(true)
- .build()
- .writeTo(filer);
+ JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
} catch (IOException e) {
throw new AssertionError(e);
}
}
private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
- Map<String, ClassName> groups = new HashMap<String, ClassName>();
-
+ Map<String, ClassName> groups = new HashMap<>();
+
// Generate a API class for each group collected other than the default one (= empty string)
- for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) {
+ for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
if (!entry.getKey().isEmpty()) {
TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
write(groupClass);
@@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor {
}
private boolean collectOpsMethods(
- RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) {
+ RoundEnvironment roundEnv,
+ Multimap<String, MethodSpec> groupedMethods,
+ TypeElement annotation) {
boolean result = true;
for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
// @Operator can only apply to types, so e must be a TypeElement.
if (!(e instanceof TypeElement)) {
- error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString());
+ error(
+ e,
+ "@Operator can only be applied to classes, but this is a %s",
+ e.getKind().toString());
result = false;
continue;
}
@@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor {
}
return result;
}
-
- private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
+
+ private void collectOpMethods(
+ Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
AnnotationMirror am = getAnnotationMirror(opClass, annotation);
String groupName = getAnnotationElementValueAsString("group", am);
String methodName = getAnnotationElementValueAsString("name", am);
ClassName opClassName = ClassName.get(opClass);
if (Strings.isNullOrEmpty(methodName)) {
- methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
+ methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
}
- // Build a method for each @Operator found in the class path. There should be one method per operation factory called
+ // Build a method for each @Operator found in the class path. There should be one method per
+ // operation factory called
// "create", which takes in parameter a scope and, optionally, a list of arguments
for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
- if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) {
+ if (opMethod.getModifiers().contains(Modifier.STATIC)
+ && opMethod.getSimpleName().contentEquals("create")) {
MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
groupedMethods.put(groupName, method);
}
}
}
- private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
+ private MethodSpec buildOpMethod(
+ String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
MethodSpec.Builder builder =
MethodSpec.methodBuilder(methodName)
- .addModifiers(Modifier.PUBLIC)
- .returns(TypeName.get(factoryMethod.getReturnType()))
- .varargs(factoryMethod.isVarArgs())
- .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
+ .addModifiers(Modifier.PUBLIC)
+ .returns(TypeName.get(factoryMethod.getReturnType()))
+ .varargs(factoryMethod.isVarArgs())
+ .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
- for (TypeParameterElement tp: factoryMethod.getTypeParameters()) {
+ for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
builder.addTypeVariable(tvn);
}
- for (TypeMirror thrownType: factoryMethod.getThrownTypes()) {
+ for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
builder.addException(TypeName.get(thrownType));
}
StringBuilder call = new StringBuilder("return $T.create(scope");
@@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor {
call.append(")");
builder.addStatement(call.toString(), opClassName);
return builder.build();
- }
-
+ }
+
private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
StringBuilder javadoc = new StringBuilder();
- javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n");
+ javadoc
+ .append("Adds an {@link ")
+ .append(opClassName.simpleName())
+ .append("} operation to the graph\n\n");
- // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the
+ // Add all javadoc tags found in the operator factory method but the first one, which should be
+ // in all cases the
// 'scope' parameter that is implicitly passed by this API
Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
boolean firstParam = true;
@@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor {
} else {
javadoc.append(tag).append('\n');
}
- }
+ }
javadoc.append("@see {@link ").append(opClassName).append("}\n");
return javadoc.toString();
}
-
+
private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addParameter(T_SCOPE, "scope")
- .addStatement("this.scope = scope");
-
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope");
+
TypeSpec.Builder builder =
TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" +
- "@see {@link $T}\n", group, T_GRAPH, T_OPS)
- .addMethods(methods)
- .addMethod(ctorBuilder.build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
+ + "@see {@link $T}\n",
+ group,
+ T_GRAPH,
+ T_OPS)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
builder.addField(
- FieldSpec.builder(T_SCOPE, "scope")
- .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
- .build());
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
return builder.build();
}
- private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
+ private static TypeSpec buildTopClass(
+ Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addModifiers(Modifier.PRIVATE)
- .addParameter(T_SCOPE, "scope")
- .addStatement("this.scope = scope", T_SCOPE);
+ .addModifiers(Modifier.PRIVATE)
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope", T_SCOPE);
- for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) {
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
}
TypeSpec.Builder opsBuilder =
TypeSpec.classBuilder("Ops")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" +
- "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" +
- "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" +
- "try (Graph g = new Graph()) {\n" +
- " Ops ops = new Ops(g);\n" +
- " // Operations are typed classes with convenience\n" +
- " // builders in Ops.\n" +
- " Constant three = ops.constant(3);\n" +
- " // Single-result operations implement the Operand\n" +
- " // interface, so this works too.\n" +
- " Operand four = ops.constant(4);\n" +
- " // Most builders are found within a group, and accept\n" +
- " // Operand types as operands\n" +
- " Operand nine = ops.math().add(four, ops.constant(5));\n" +
- " // Multi-result operations however offer methods to\n" +
- " // select a particular result for use.\n" +
- " Operand result = \n" +
- " ops.math().add(ops.array().unique(s, a).y(), b);\n" +
- " // Optional attributes\n" +
- " ops.math().matMul(a, b, MatMul.transposeA(true));\n" +
- " // Naming operators\n" +
- " ops.withName(“foo”).constant(5); // name “foo”\n" +
- " // Names can exist in a hierarchy\n" +
- " Ops sub = ops.withSubScope(“sub”);\n" +
- " sub.withName(“bar”).constant(4); // “sub/bar”\n" +
- "}\n" +
- "}</pre>\n", T_GRAPH, T_OPERATOR)
- .addMethods(methods)
- .addMethod(ctorBuilder.build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for building a {@link $T} with operation wrappers\n<p>\n"
+ + "Any operation wrapper found in the classpath properly annotated as an"
+ + "{@link $T @Operator} is exposed\n"
+ + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
+ + "try (Graph g = new Graph()) {\n"
+ + " Ops ops = new Ops(g);\n"
+ + " // Operations are typed classes with convenience\n"
+ + " // builders in Ops.\n"
+ + " Constant three = ops.constant(3);\n"
+ + " // Single-result operations implement the Operand\n"
+ + " // interface, so this works too.\n"
+ + " Operand four = ops.constant(4);\n"
+ + " // Most builders are found within a group, and accept\n"
+ + " // Operand types as operands\n"
+ + " Operand nine = ops.math().add(four, ops.constant(5));\n"
+ + " // Multi-result operations however offer methods to\n"
+ + " // select a particular result for use.\n"
+ + " Operand result = \n"
+ + " ops.math().add(ops.array().unique(s, a).y(), b);\n"
+ + " // Optional attributes\n"
+ + " ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+ + " // Naming operators\n"
+ + " ops.withName(“foo”).constant(5); // name “foo”\n"
+ + " // Names can exist in a hierarchy\n"
+ + " Ops sub = ops.withSubScope(“sub”);\n"
+ + " sub.withName(“bar”).constant(4); // “sub/bar”\n"
+ + "}\n"
+ + "}</pre>\n",
+ T_GRAPH,
+ T_OPERATOR)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withSubScope")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(T_STRING, "childScopeName")
- .returns(T_OPS)
- .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
- .addJavadoc(
- "Returns an API that adds operations to the graph with the provided name prefix.\n\n" +
- "@see {@link $T#withSubScope(String)}\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "childScopeName")
+ .returns(T_OPS)
+ .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
+ .addJavadoc(
+ "Returns an API that adds operations to the graph with the provided name prefix.\n"
+ + "\n@see {@link $T#withSubScope(String)}\n",
+ T_SCOPE)
+ .build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withName")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(T_STRING, "opName")
- .returns(T_OPS)
- .addStatement("return new Ops(scope.withName(opName))")
- .addJavadoc(
- "Returns an API that uses the provided name for an op.\n\n" +
- "@see {@link $T#withName(String)}\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "opName")
+ .returns(T_OPS)
+ .addStatement("return new Ops(scope.withName(opName))")
+ .addJavadoc(
+ "Returns an API that uses the provided name for an op.\n\n"
+ + "@see {@link $T#withName(String)}\n",
+ T_SCOPE)
+ .build());
opsBuilder.addField(
- FieldSpec.builder(T_SCOPE, "scope")
- .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
- .build());
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("scope")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .returns(T_SCOPE)
- .addStatement("return scope")
- .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(T_SCOPE)
+ .addStatement("return scope")
+ .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
+ .build());
- for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) {
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
opsBuilder.addField(
FieldSpec.builder(entry.getValue(), entry.getKey())
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .build());
-
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .build());
+
opsBuilder.addMethod(
MethodSpec.methodBuilder(entry.getKey())
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .returns(entry.getValue())
- .addStatement("return $L", entry.getKey())
- .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(entry.getValue())
+ .addStatement("return $L", entry.getKey())
+ .addJavadoc(
+ "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
+ .build());
}
opsBuilder.addMethod(
MethodSpec.methodBuilder("create")
- .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
- .addParameter(T_GRAPH, "graph")
- .returns(T_OPS)
- .addStatement("return new Ops(new $T(graph))", T_SCOPE)
- .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
+ .addParameter(T_GRAPH, "graph")
+ .returns(T_OPS)
+ .addStatement("return new Ops(new $T(graph))", T_SCOPE)
+ .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
+ .build());
return opsBuilder.build();
}
@@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor {
return am;
}
}
- throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element "
- + element.getSimpleName());
+ throw new IllegalArgumentException(
+ "Annotation "
+ + annotation.getSimpleName()
+ + " not present on element "
+ + element.getSimpleName());
}
-
+
private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
- for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) {
+ for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
+ am.getElementValues().entrySet()) {
if (entry.getKey().getSimpleName().contentEquals(elementName)) {
return entry.getValue().getValue().toString();
}