diff options
Diffstat (limited to 'tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java')
-rw-r--r-- | tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java | 296 |
1 files changed, 160 insertions, 136 deletions
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 3524160d87..796d6a62dc 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,6 +15,18 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; import java.util.Collection; import java.util.Collections; @@ -23,7 +35,6 @@ import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; - import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; @@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter; import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; -import com.google.common.base.CaseFormat; -import com.google.common.base.Strings; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.JavaFile; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; - /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please @@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor { // generated our code, flag the location of each such class. if (hasRun) { for (Element e : annotated) { - error(e, "The Operator processor has already processed @Operator annotated sources\n" + - "and written out an Ops API. It cannot process additional @Operator sources.\n" + - "One reason this can happen is if other annotation processors generate\n" + - "new @Operator source files."); + error( + e, + "The Operator processor has already processed @Operator annotated sources\n" + + "and written out an Ops API. It cannot process additional @Operator sources.\n" + + "One reason this can happen is if other annotation processors generate\n" + + "new @Operator source files."); } return true; } @@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor { return Collections.singleton("org.tensorflow.op.annotation.Operator"); } - private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); - private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); private static final TypeName T_STRING = ClassName.get(String.class); @@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor { private void write(TypeSpec spec) { try { - JavaFile.builder("org.tensorflow.op", spec) - .skipJavaLangImports(true) - .build() - .writeTo(filer); + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { throw new AssertionError(e); } } private void writeApi(Multimap<String, MethodSpec> groupedMethods) { - Map<String, ClassName> groups = new HashMap<String, ClassName>(); - + Map<String, ClassName> groups = new HashMap<>(); + // Generate a API class for each group collected other than the default one (= empty string) - for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) { + for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) { if (!entry.getKey().isEmpty()) { TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); write(groupClass); @@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor { } private boolean collectOpsMethods( - RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) { + RoundEnvironment roundEnv, + Multimap<String, MethodSpec> groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. if (!(e instanceof TypeElement)) { - error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString()); + error( + e, + "@Operator can only be applied to classes, but this is a %s", + e.getKind().toString()); result = false; continue; } @@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor { } return result; } - - private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { + + private void collectOpMethods( + Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { AnnotationMirror am = getAnnotationMirror(opClass, annotation); String groupName = getAnnotationElementValueAsString("group", am); String methodName = getAnnotationElementValueAsString("name", am); ClassName opClassName = ClassName.get(opClass); if (Strings.isNullOrEmpty(methodName)) { - methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); } - // Build a method for each @Operator found in the class path. There should be one method per operation factory called + // Build a method for each @Operator found in the class path. There should be one method per + // operation factory called // "create", which takes in parameter a scope and, optionally, a list of arguments for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { - if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); groupedMethods.put(groupName, method); } } } - private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName) - .addModifiers(Modifier.PUBLIC) - .returns(TypeName.get(factoryMethod.getReturnType())) - .varargs(factoryMethod.isVarArgs()) - .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); - for (TypeParameterElement tp: factoryMethod.getTypeParameters()) { + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); builder.addTypeVariable(tvn); } - for (TypeMirror thrownType: factoryMethod.getThrownTypes()) { + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { builder.addException(TypeName.get(thrownType)); } StringBuilder call = new StringBuilder("return $T.create(scope"); @@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor { call.append(")"); builder.addStatement(call.toString(), opClassName); return builder.build(); - } - + } + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { StringBuilder javadoc = new StringBuilder(); - javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n"); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); - // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the // 'scope' parameter that is implicitly passed by this API Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); boolean firstParam = true; @@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor { } else { javadoc.append(tag).append('\n'); } - } + } javadoc.append("@see {@link ").append(opClassName).append("}\n"); return javadoc.toString(); } - + private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope"); - + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + TypeSpec.Builder builder = TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + - "@see {@link $T}\n", group, T_GRAPH, T_OPS) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); builder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); return builder.build(); } - private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { + private static TypeSpec buildTopClass( + Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addModifiers(Modifier.PRIVATE) - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope", T_SCOPE); + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); - for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); } TypeSpec.Builder opsBuilder = TypeSpec.classBuilder("Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" + - "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" + - "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + - "try (Graph g = new Graph()) {\n" + - " Ops ops = new Ops(g);\n" + - " // Operations are typed classes with convenience\n" + - " // builders in Ops.\n" + - " Constant three = ops.constant(3);\n" + - " // Single-result operations implement the Operand\n" + - " // interface, so this works too.\n" + - " Operand four = ops.constant(4);\n" + - " // Most builders are found within a group, and accept\n" + - " // Operand types as operands\n" + - " Operand nine = ops.math().add(four, ops.constant(5));\n" + - " // Multi-result operations however offer methods to\n" + - " // select a particular result for use.\n" + - " Operand result = \n" + - " ops.math().add(ops.array().unique(s, a).y(), b);\n" + - " // Optional attributes\n" + - " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + - " // Naming operators\n" + - " ops.withName(“foo”).constant(5); // name “foo”\n" + - " // Names can exist in a hierarchy\n" + - " Ops sub = ops.withSubScope(“sub”);\n" + - " sub.withName(“bar”).constant(4); // “sub/bar”\n" + - "}\n" + - "}</pre>\n", T_GRAPH, T_OPERATOR) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n<p>\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + + "try (Graph g = new Graph()) {\n" + + " Ops ops = new Ops(g);\n" + + " // Operations are typed classes with convenience\n" + + " // builders in Ops.\n" + + " Constant three = ops.constant(3);\n" + + " // Single-result operations implement the Operand\n" + + " // interface, so this works too.\n" + + " Operand four = ops.constant(4);\n" + + " // Most builders are found within a group, and accept\n" + + " // Operand types as operands\n" + + " Operand nine = ops.math().add(four, ops.constant(5));\n" + + " // Multi-result operations however offer methods to\n" + + " // select a particular result for use.\n" + + " Operand result = \n" + + " ops.math().add(ops.array().unique(s, a).y(), b);\n" + + " // Optional attributes\n" + + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + + " // Naming operators\n" + + " ops.withName(“foo”).constant(5); // name “foo”\n" + + " // Names can exist in a hierarchy\n" + + " Ops sub = ops.withSubScope(“sub”);\n" + + " sub.withName(“bar”).constant(4); // “sub/bar”\n" + + "}\n" + + "}</pre>\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withSubScope") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "childScopeName") - .returns(T_OPS) - .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) - .addJavadoc( - "Returns an API that adds operations to the graph with the provided name prefix.\n\n" + - "@see {@link $T#withSubScope(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withName") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "opName") - .returns(T_OPS) - .addStatement("return new Ops(scope.withName(opName))") - .addJavadoc( - "Returns an API that uses the provided name for an op.\n\n" + - "@see {@link $T#withName(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); opsBuilder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); opsBuilder.addMethod( MethodSpec.methodBuilder("scope") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(T_SCOPE) - .addStatement("return scope") - .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); - for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { opsBuilder.addField( FieldSpec.builder(entry.getValue(), entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .build()); - + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + opsBuilder.addMethod( MethodSpec.methodBuilder(entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(entry.getValue()) - .addStatement("return $L", entry.getKey()) - .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); } opsBuilder.addMethod( MethodSpec.methodBuilder("create") - .addModifiers(Modifier.PUBLIC, Modifier.STATIC) - .addParameter(T_GRAPH, "graph") - .returns(T_OPS) - .addStatement("return new Ops(new $T(graph))", T_SCOPE) - .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); return opsBuilder.build(); } @@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor { return am; } } - throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element " - + element.getSimpleName()); + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); } - + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { - for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) { + for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : + am.getElementValues().entrySet()) { if (entry.getKey().getSimpleName().contentEquals(elementName)) { return entry.getValue().getValue().toString(); } |