diff options
author | karl@kubx.ca <karl@kubx.ca> | 2018-05-13 23:46:37 -0400 |
---|---|---|
committer | karl@kubx.ca <karl@kubx.ca> | 2018-06-12 23:21:38 -0400 |
commit | 60552388401b3e70a21d4c01d3d374c9d85aea2b (patch) | |
tree | abd9805d5f5fbac4e67f1764013a43ba8733a8d2 /tensorflow/java/src | |
parent | 1aea422ca658d0ac2121245d31c3aa78a73c0efb (diff) |
Complete operator processor for generating Ops API classes
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r-- | tensorflow/java/src/gen/cc/op_generator.cc | 40 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_specs.cc | 2 | ||||
-rw-r--r-- | tensorflow/java/src/gen/cc/op_specs.h | 2 | ||||
-rw-r--r-- | tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java | 347 |
4 files changed, 334 insertions, 57 deletions
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 9b171f66ec..2df69ee299 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -35,22 +35,21 @@ namespace tensorflow { namespace java { namespace { -const char* kLicense = - "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" - "\n" - "Licensed under the Apache License, Version 2.0 (the \"License\");\n" - "you may not use this file except in compliance with the License.\n" - "You may obtain a copy of the License at\n" - "\n" - " http://www.apache.org/licenses/LICENSE-2.0\n" - "\n" - "Unless required by applicable law or agreed to in writing, software\n" - "distributed under the License is distributed on an \"AS IS\" BASIS,\n" - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "See the License for the specific language governing permissions and\n" - "limitations under the License.\n" - "=======================================================================*/" - "\n"; +constexpr const char kLicense[] = + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + "\n" + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + "you may not use this file except in compliance with the License.\n" + "You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + "Unless required by applicable law or agreed to in writing, software\n" + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "See the License for the specific language governing permissions and\n" + "limitations under the License.\n" + "=======================================================================*/\n"; // There is three different modes to render an op class, depending on the // number and type of outputs it has: @@ -391,9 +390,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } if (!op.hidden()) { // expose the op in the Ops Graph API only if it is visible - op_class.add_annotation( - Annotation::Create("Operator", "org.tensorflow.op.annotation") - .attributes("group = \"" + endpoint.package() + "\"")); + Annotation oper_annot = + Annotation::Create("Operator", "org.tensorflow.op.annotation"); + if (endpoint.package() != kDefaultEndpointPackage) { + oper_annot.attributes("group = \"" + endpoint.package() + "\""); + } + op_class.add_annotation(oper_annot); } // create op class file const string op_dir_name = io::JoinPath( diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 4bcfc7fe01..f0e4bcca82 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -376,7 +376,7 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, package = name_tokens.at(0); name = name_tokens.at(1); } else { - package = "core"; // generate unclassified ops in the 'core' package + package = kDefaultEndpointPackage; name = name_tokens.at(0); } return EndpointSpec(package, diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 034cf636ed..3b53c730df 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -27,6 +27,8 @@ limitations under the License. namespace tensorflow { namespace java { +constexpr const char kDefaultEndpointPackage[] = "core"; + class EndpointSpec { public: // A specification for an operation endpoint diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 11fda4fc22..0f59754004 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -16,20 +16,48 @@ limitations under the License. package org.tensorflow.processor; import java.io.IOException; -import java.io.PrintWriter; +import java.util.Collection; import java.util.Collections; -import java.util.HashSet; +import java.util.HashMap; +import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; import javax.annotation.processing.ProcessingEnvironment; import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; +import javax.lang.model.element.AnnotationMirror; +import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; +import javax.lang.model.element.TypeParameterElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeVariable; +import javax.lang.model.util.ElementFilter; +import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.AnnotationSpec; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; + /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please @@ -55,6 +83,7 @@ public final class OperatorProcessor extends AbstractProcessor { super.init(processingEnv); messager = processingEnv.getMessager(); filer = processingEnv.getFiler(); + elements = processingEnv.getElementUtils(); } @Override @@ -87,29 +116,28 @@ public final class OperatorProcessor extends AbstractProcessor { // generated our code, flag the location of each such class. if (hasRun) { for (Element e : annotated) { - error( - e, - "The Operator processor has already processed @Operator annotated sources\n" - + "and written out an Ops API. It cannot process additional @Operator sources.\n" - + "One reason this can happen is if other annotation processors generate\n" - + "new @Operator source files."); + error(e, "The Operator processor has already processed @Operator annotated sources\n" + + "and written out an Ops API. It cannot process additional @Operator sources.\n" + + "One reason this can happen is if other annotation processors generate\n" + + "new @Operator source files."); } return true; } // Collect all classes tagged with our annotation. - Set<TypeElement> opClasses = new HashSet<TypeElement>(); - if (!collectOpClasses(roundEnv, opClasses, annotation)) { + Multimap<String, MethodSpec> groupedMethods = HashMultimap.create(); + if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { return true; } // Nothing to do when there are no tagged classes. - if (opClasses.isEmpty()) { + if (groupedMethods.isEmpty()) { return true; } - // TODO:(kbsriram) validate operator classes and generate Op API. - writeApi(); + // Validate operator classes and generate Op API. + writeApi(groupedMethods); + hasRun = true; return true; } @@ -119,46 +147,291 @@ public final class OperatorProcessor extends AbstractProcessor { return Collections.singleton(String.format("%s.annotation.Operator", OP_PACKAGE)); } - private void writeApi() { - // Generate an empty class for now and get the build working correctly. This will be changed to - // generate the actual API once we've done with build-related changes. - // TODO:(kbsriram) - try (PrintWriter writer = - new PrintWriter(filer.createSourceFile(String.format("%s.Ops", OP_PACKAGE)).openWriter())) { - writer.println(String.format("package %s;", OP_PACKAGE)); - writer.println("public class Ops{}"); + private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); + private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); + private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); + private static final TypeName T_STRING = ClassName.get(String.class); + private static final String OP_PACKAGE = "org.tensorflow.op"; + + private Filer filer; + private Messager messager; + private Elements elements; + private boolean hasRun = false; + + private void error(Element e, String message, Object... args) { + if (args != null && args.length > 0) { + message = String.format(message, args); + } + messager.printMessage(Kind.ERROR, message, e); + } + + private void write(TypeSpec spec) { + try { + JavaFile.builder("org.tensorflow.op", spec) + .skipJavaLangImports(true) + .build() + .writeTo(filer); } catch (IOException e) { - error(null, "Unexpected failure generating API: %s", e.getMessage()); + throw new AssertionError(e); } } - private boolean collectOpClasses( - RoundEnvironment roundEnv, Set<TypeElement> opClasses, TypeElement annotation) { + private void writeApi(Multimap<String, MethodSpec> groupedMethods) { + Map<String, ClassName> groups = new HashMap<String, ClassName>(); + + // Generate a API class for each group collected other than the default one (= empty string) + for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) { + if (!entry.getKey().isEmpty()) { + TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); + write(groupClass); + groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name)); + } + } + // Generate the top API class, adding any methods added to the default group + TypeSpec topClass = buildTopClass(groups, groupedMethods.get("")); + write(topClass); + } + + private boolean collectOpsMethods( + RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. if (!(e instanceof TypeElement)) { - error( - e, - "@Operator can only be applied to classes, but this is a %s", - e.getKind().toString()); + error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString()); result = false; continue; } - opClasses.add((TypeElement) e); + collectOpMethods(groupedMethods, (TypeElement) e, annotation); } return result; } + + private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { + AnnotationMirror am = getAnnotationMirror(opClass, annotation); + String groupName = getAnnotationElementValueAsString("group", am); + String methodName = getAnnotationElementValueAsString("name", am); + if (Strings.isNullOrEmpty(methodName)) { + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, ClassName.get(opClass).simpleName()); + } + // Build a method for each @Operator found in the class path. There should be one method per operation factory called + // "create", which takes in parameter a scope and, optionally, a list of arguments + for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { + if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) { + MethodSpec method = buildOpMethod(methodName, opClass, opMethod); + groupedMethods.put(groupName, method); + } + } + } - private void error(Element e, String message, Object... args) { - if (args != null && args.length > 0) { - message = String.format(message, args); + private MethodSpec buildOpMethod(String methodName, TypeElement opClass, ExecutableElement factoryMethod) { + boolean deprecated = opClass.getAnnotation(Deprecated.class) != null; + ClassName opClassName = ClassName.get(opClass); + MethodSpec.Builder builder = + MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod, deprecated)); + + for (TypeParameterElement tp: factoryMethod.getTypeParameters()) { + TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); + builder.addTypeVariable(tvn); } - messager.printMessage(Kind.ERROR, message, e); + for (TypeMirror thrownType: factoryMethod.getThrownTypes()) { + builder.addException(TypeName.get(thrownType)); + } + if (deprecated) { + builder.addAnnotation(AnnotationSpec.builder(Deprecated.class).build()); + } + StringBuilder call = new StringBuilder("return $T.create(scope"); + boolean first = true; + for (VariableElement param : factoryMethod.getParameters()) { + ParameterSpec p = ParameterSpec.get(param); + if (first) { + first = false; + continue; + } + call.append(", "); + call.append(p.name); + builder.addParameter(p); + } + call.append(")"); + builder.addStatement(call.toString(), opClassName); + return builder.build(); + } + + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod, boolean deprecated) { + StringBuilder javadoc = new StringBuilder(); + javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n"); + + // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the + // 'scope' parameter that is implicitly passed by this API + Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); + boolean firstParam = true; + + while (tagMatcher.find()) { + String tag = tagMatcher.group(); + if (tag.startsWith("@param") && firstParam) { + firstParam = false; + } else { + javadoc.append(tag).append('\n'); + } + } + if (deprecated) { + javadoc.append("@deprecated\n"); + } + javadoc.append("@see {@link ").append(opClassName).append("}\n"); + + return javadoc.toString(); } + + private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + + TypeSpec.Builder builder = + TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", group, T_GRAPH, T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); - private Filer filer; - private Messager messager; - private boolean hasRun = false; - private static final String OP_PACKAGE = "org.tensorflow.op"; + builder.addField( + FieldSpec.builder(T_SCOPE, "scope") + .addModifiers(Modifier.PRIVATE, Modifier.FINAL) + .build()); + + return builder.build(); + } + + private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); + + for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); + } + + TypeSpec.Builder opsBuilder = + TypeSpec.classBuilder("Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" + + "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + + "try (Graph g = new Graph()) {\n" + + " Ops ops = new Ops(g);\n" + + " // Operations are typed classes with convenience\n" + + " // builders in Ops.\n" + + " Constant three = ops.constant(3);\n" + + " // Single-result operations implement the Input\n" + + " // interface, so this works too.\n" + + " Input four = ops.constant(4);\n" + + " // Most builders are found within a group, and accept\n" + + " // Input types as operands\n" + + " Input nine = ops.math().add(four, ops.constant(5));\n" + + " // Multi-result operations however offer methods to\n" + + " // select a particular result for use.\n" + + " Input result = \n" + + " ops.math().add(ops.array().unique(s, a).y(), b);\n" + + " // Optional attributes\n" + + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + + " // Naming operators\n" + + " ops.withOpName(“foo”).constant(5); // name “foo”\n" + + " // Names can exist in a hierarchy\n" + + " Ops sub = ops.withSubscope(“sub”);\n" + + " sub.withOpName(“bar”).constant(4); // “sub/bar”\n" + + "}\n" + + "}</pre>\n", T_GRAPH, T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withSubScope") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n\n" + + "@see {@link $T#withSubScope(String)}\n", T_SCOPE) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withName") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", T_SCOPE) + .build()); + + opsBuilder.addField( + FieldSpec.builder(T_SCOPE, "scope") + .addModifiers(Modifier.PRIVATE, Modifier.FINAL) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("scope") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); + + for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + opsBuilder.addField( + FieldSpec.builder(entry.getValue(), entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder(entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); + } + + opsBuilder.addMethod( + MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); + + return opsBuilder.build(); + } + + private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) { + for (AnnotationMirror am : element.getAnnotationMirrors()) { + if (am.getAnnotationType().asElement().equals(annotation)) { + return am; + } + } + throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element " + + element.getSimpleName()); + } + + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { + for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) { + if (entry.getKey().getSimpleName().contentEquals(elementName)) { + return entry.getValue().getValue().toString(); + } + } + return ""; + } } |