aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar Mingxing Tan <tanmingxing@google.com>2018-06-28 19:13:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 19:16:41 -0700
commit1e7b0e4ad6d0f57f3241fe0b80a65f2c2a7f11b0 (patch)
treeaf92d172cedfc41e544c01a349c1d3b30bc3ff85 /tensorflow/java
parent3cee10e61c1c90734317c62ea3388ec44acc8d08 (diff)
Merge changes from github.
PiperOrigin-RevId: 202585094
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD5
-rw-r--r--tensorflow/java/maven/.gitignore6
-rw-r--r--tensorflow/java/maven/README.md6
-rw-r--r--tensorflow/java/maven/hadoop/pom.xml24
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/run_inside_container.sh47
-rw-r--r--tensorflow/java/maven/spark-connector/pom.xml24
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc11
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h2
-rw-r--r--tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java348
10 files changed, 441 insertions, 34 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 19d2133a55..73e210fae0 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -56,6 +56,10 @@ java_library(
srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]),
javacopts = JAVACOPTS,
resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]),
+ deps = [
+ "@com_google_guava",
+ "@com_squareup_javapoet",
+ ],
)
filegroup(
@@ -70,6 +74,7 @@ tf_java_op_gen_srcjar(
name = "java_op_gen_sources",
api_def_srcs = [
"//tensorflow/core/api_def:base_api_def",
+ "//tensorflow/core/api_def:java_api_def",
],
base_package = "org.tensorflow.op",
gen_tool = ":java_op_gen_tool",
diff --git a/tensorflow/java/maven/.gitignore b/tensorflow/java/maven/.gitignore
index ff080515d5..657e2a60bc 100644
--- a/tensorflow/java/maven/.gitignore
+++ b/tensorflow/java/maven/.gitignore
@@ -11,4 +11,10 @@ tensorflow/src
tensorflow/target
proto/src
proto/target
+hadoop/src
+hadoop/target
+spark-connector/src
+spark-connector/target
+spark-connector/dependency-reduced-pom.xml
+spark-connector/spark-warehouse
pom.xml.versionsBackup
diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md
index c7e8f03806..3e030dcd09 100644
--- a/tensorflow/java/maven/README.md
+++ b/tensorflow/java/maven/README.md
@@ -53,6 +53,12 @@ There are seven artifacts and thus `pom.xml`s involved in this release:
7. [`parentpom`](https://maven.apache.org/pom/index.html): Common settings
shared by all of the above.
+8. `hadoop`: The TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop.
+ The source code for this package is available in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/hadoop)
+
+9. `spark-connector`: A Scala library for loading and storing TensorFlow TFRecord
+ using Apache Spark DataFrames. The source code for this package is available
+ in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector)
## Updating the release
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml
new file mode 100644
index 0000000000..0642be06fa
--- /dev/null
+++ b/tensorflow/java/maven/hadoop/pom.xml
@@ -0,0 +1,24 @@
+<project
+ xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <!-- Placeholder pom which is replaced by TensorFlow ecosystem Hadoop pom during build -->
+ <modelVersion>4.0.0</modelVersion>
+ <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
+ <artifactId>hadoop</artifactId>
+ <packaging>jar</packaging>
+
+ <scm>
+ <url>https://github.com/tensorflow/ecosystem.git</url>
+ <connection>git@github.com:tensorflow/ecosystem.git</connection>
+ <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
+ </scm>
+
+ <url>https://github.com/tensorflow/ecosystem/</url>
+ <parent>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>parentpom</artifactId>
+ <version>1.9.0-rc0</version>
+ <relativePath>../</relativePath>
+ </parent>
+</project> \ No newline at end of file
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 3890f3fcaa..b4746794ea 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -32,6 +32,8 @@
<module>libtensorflow_jni_gpu</module>
<module>tensorflow</module>
<module>proto</module>
+ <module>hadoop</module>
+ <module>spark-connector</module>
</modules>
<!-- Two profiles are used:
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index bf19c09b1d..2e771064e4 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -19,6 +19,7 @@
RELEASE_URL_PREFIX="https://storage.googleapis.com/tensorflow/libtensorflow"
+TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git"
# By default we deploy to both ossrh and bintray. These two
# environment variables can be set to skip either repository.
@@ -44,7 +45,9 @@ clean() {
# (though if run inside a clean docker container, there won't be any dirty
# artifacts lying around)
mvn -q clean
- rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target libtensorflow/src libtensorflow/target tensorflow-android/target
+ rm -rf libtensorflow_jni/src libtensorflow_jni/target libtensorflow_jni_gpu/src libtensorflow_jni_gpu/target \
+ libtensorflow/src libtensorflow/target tensorflow-android/target proto/src proto/target \
+ hadoop/src hadoop/target spark-connector/src spark-connector/target
}
update_version_in_pom() {
@@ -183,6 +186,43 @@ generate_java_protos() {
rm -rf "${DIR}/proto/tmp"
}
+
+# Download the TensorFlow ecosystem source from git.
+# The pom files from this repo do not inherit from the parent pom so the maven version
+# is updated for each module.
+download_tf_ecosystem() {
+ ECOSYSTEM_DIR="/tmp/tensorflow-ecosystem"
+ HADOOP_DIR="${DIR}/hadoop"
+ SPARK_DIR="${DIR}/spark-connector"
+
+ # Clean any previous attempts
+ rm -rf "${ECOSYSTEM_DIR}"
+
+ # Clone the TensorFlow ecosystem project
+ mkdir -p "${ECOSYSTEM_DIR}"
+ cd "${ECOSYSTEM_DIR}"
+ git clone "${TF_ECOSYSTEM_URL}"
+ cd ecosystem
+ git checkout r${TF_VERSION}
+
+ # Copy the TensorFlow Hadoop source
+ cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}"
+ cp "${ECOSYSTEM_DIR}/ecosystem/hadoop/pom.xml" "${HADOOP_DIR}"
+ cd "${HADOOP_DIR}"
+ update_version_in_pom
+
+ # Copy the TensorFlow Spark connector source
+ cp -r "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/src" "${SPARK_DIR}"
+ cp "${ECOSYSTEM_DIR}/ecosystem/spark/spark-tensorflow-connector/pom.xml" "${SPARK_DIR}"
+ cd "${SPARK_DIR}"
+ update_version_in_pom
+
+ # Cleanup
+ rm -rf "${ECOSYSTEM_DIR}"
+
+ cd "${DIR}"
+}
+
# Deploy artifacts using a specific profile.
# Arguments:
# profile - name of selected profile.
@@ -240,7 +280,8 @@ cd "${DIR}"
# Comment lines out appropriately if debugging/tinkering with the release
# process.
# gnupg2 is required for signing
-apt-get -qq update && apt-get -qqq install -y gnupg2
+apt-get -qq update && apt-get -qqq install -y gnupg2 git
+
clean
update_version_in_pom
download_libtensorflow
@@ -248,6 +289,8 @@ download_libtensorflow_jni
download_libtensorflow_jni_gpu
update_tensorflow_android
generate_java_protos
+download_tf_ecosystem
+
# Build the release artifacts
mvn verify
# Push artifacts to repository
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml
new file mode 100644
index 0000000000..19c752d08b
--- /dev/null
+++ b/tensorflow/java/maven/spark-connector/pom.xml
@@ -0,0 +1,24 @@
+<project
+ xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <!-- Placeholder pom which is replaced by TensorFlow ecosystem Spark pom during build -->
+ <modelVersion>4.0.0</modelVersion>
+ <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
+ <artifactId>spark-connector</artifactId>
+ <packaging>jar</packaging>
+
+ <scm>
+ <url>https://github.com/tensorflow/ecosystem.git</url>
+ <connection>git@github.com:tensorflow/ecosystem.git</connection>
+ <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
+ </scm>
+
+ <url>https://github.com/tensorflow/ecosystem/</url>
+ <parent>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>parentpom</artifactId>
+ <version>1.9.0-rc0</version>
+ <relativePath>../</relativePath>
+ </parent>
+</project> \ No newline at end of file
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index 9b171f66ec..d5bd99bdd9 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -35,7 +35,7 @@ namespace tensorflow {
namespace java {
namespace {
-const char* kLicense =
+constexpr const char kLicense[] =
"/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
"\n"
"Licensed under the Apache License, Version 2.0 (the \"License\");\n"
@@ -391,9 +391,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint,
}
if (!op.hidden()) {
// expose the op in the Ops Graph API only if it is visible
- op_class.add_annotation(
- Annotation::Create("Operator", "org.tensorflow.op.annotation")
- .attributes("group = \"" + endpoint.package() + "\""));
+ Annotation oper_annot =
+ Annotation::Create("Operator", "org.tensorflow.op.annotation");
+ if (endpoint.package() != kDefaultEndpointPackage) {
+ oper_annot.attributes("group = \"" + endpoint.package() + "\"");
+ }
+ op_class.add_annotation(oper_annot);
}
// create op class file
const string op_dir_name = io::JoinPath(
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
index ca0ba16745..30ecb8ce53 100644
--- a/tensorflow/java/src/gen/cc/op_specs.h
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -27,6 +27,8 @@ limitations under the License.
namespace tensorflow {
namespace java {
+constexpr const char kDefaultEndpointPackage[] = "core";
+
class EndpointSpec {
public:
// A specification for an operation endpoint
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
index 11fda4fc22..796d6a62dc 100644
--- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
+++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
@@ -15,19 +15,44 @@ limitations under the License.
package org.tensorflow.processor;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import com.squareup.javapoet.ClassName;
+import com.squareup.javapoet.FieldSpec;
+import com.squareup.javapoet.JavaFile;
+import com.squareup.javapoet.MethodSpec;
+import com.squareup.javapoet.ParameterSpec;
+import com.squareup.javapoet.TypeName;
+import com.squareup.javapoet.TypeSpec;
+import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
-import java.io.PrintWriter;
+import java.util.Collection;
import java.util.Collections;
-import java.util.HashSet;
+import java.util.HashMap;
+import java.util.Map;
import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
import javax.annotation.processing.ProcessingEnvironment;
import javax.annotation.processing.RoundEnvironment;
import javax.lang.model.SourceVersion;
+import javax.lang.model.element.AnnotationMirror;
+import javax.lang.model.element.AnnotationValue;
import javax.lang.model.element.Element;
+import javax.lang.model.element.ExecutableElement;
+import javax.lang.model.element.Modifier;
import javax.lang.model.element.TypeElement;
+import javax.lang.model.element.TypeParameterElement;
+import javax.lang.model.element.VariableElement;
+import javax.lang.model.type.TypeMirror;
+import javax.lang.model.type.TypeVariable;
+import javax.lang.model.util.ElementFilter;
+import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;
/**
@@ -55,6 +80,7 @@ public final class OperatorProcessor extends AbstractProcessor {
super.init(processingEnv);
messager = processingEnv.getMessager();
filer = processingEnv.getFiler();
+ elements = processingEnv.getElementUtils();
}
@Override
@@ -98,42 +124,77 @@ public final class OperatorProcessor extends AbstractProcessor {
}
// Collect all classes tagged with our annotation.
- Set<TypeElement> opClasses = new HashSet<TypeElement>();
- if (!collectOpClasses(roundEnv, opClasses, annotation)) {
+ Multimap<String, MethodSpec> groupedMethods = HashMultimap.create();
+ if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) {
return true;
}
// Nothing to do when there are no tagged classes.
- if (opClasses.isEmpty()) {
+ if (groupedMethods.isEmpty()) {
return true;
}
- // TODO:(kbsriram) validate operator classes and generate Op API.
- writeApi();
+ // Validate operator classes and generate Op API.
+ writeApi(groupedMethods);
+
hasRun = true;
return true;
}
@Override
public Set<String> getSupportedAnnotationTypes() {
- return Collections.singleton(String.format("%s.annotation.Operator", OP_PACKAGE));
+ return Collections.singleton("org.tensorflow.op.annotation.Operator");
+ }
+
+ private static final Pattern JAVADOC_TAG_PATTERN =
+ Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
+ private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
+ private static final TypeName T_OPERATOR =
+ ClassName.get("org.tensorflow.op.annotation", "Operator");
+ private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
+ private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
+ private static final TypeName T_STRING = ClassName.get(String.class);
+
+ private Filer filer;
+ private Messager messager;
+ private Elements elements;
+ private boolean hasRun = false;
+
+ private void error(Element e, String message, Object... args) {
+ if (args != null && args.length > 0) {
+ message = String.format(message, args);
+ }
+ messager.printMessage(Kind.ERROR, message, e);
}
- private void writeApi() {
- // Generate an empty class for now and get the build working correctly. This will be changed to
- // generate the actual API once we've done with build-related changes.
- // TODO:(kbsriram)
- try (PrintWriter writer =
- new PrintWriter(filer.createSourceFile(String.format("%s.Ops", OP_PACKAGE)).openWriter())) {
- writer.println(String.format("package %s;", OP_PACKAGE));
- writer.println("public class Ops{}");
+ private void write(TypeSpec spec) {
+ try {
+ JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
} catch (IOException e) {
- error(null, "Unexpected failure generating API: %s", e.getMessage());
+ throw new AssertionError(e);
+ }
+ }
+
+ private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
+ Map<String, ClassName> groups = new HashMap<>();
+
+ // Generate a API class for each group collected other than the default one (= empty string)
+ for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
+ if (!entry.getKey().isEmpty()) {
+ TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
+ write(groupClass);
+ groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name));
+ }
}
+ // Generate the top API class, adding any methods added to the default group
+ TypeSpec topClass = buildTopClass(groups, groupedMethods.get(""));
+ write(topClass);
}
- private boolean collectOpClasses(
- RoundEnvironment roundEnv, Set<TypeElement> opClasses, TypeElement annotation) {
+ private boolean collectOpsMethods(
+ RoundEnvironment roundEnv,
+ Multimap<String, MethodSpec> groupedMethods,
+ TypeElement annotation) {
boolean result = true;
for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
// @Operator can only apply to types, so e must be a TypeElement.
@@ -145,20 +206,251 @@ public final class OperatorProcessor extends AbstractProcessor {
result = false;
continue;
}
- opClasses.add((TypeElement) e);
+ TypeElement opClass = (TypeElement) e;
+ // Skip deprecated operations for now, as we do not guarantee API stability yet
+ if (opClass.getAnnotation(Deprecated.class) == null) {
+ collectOpMethods(groupedMethods, opClass, annotation);
+ }
}
return result;
}
- private void error(Element e, String message, Object... args) {
- if (args != null && args.length > 0) {
- message = String.format(message, args);
+ private void collectOpMethods(
+ Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
+ AnnotationMirror am = getAnnotationMirror(opClass, annotation);
+ String groupName = getAnnotationElementValueAsString("group", am);
+ String methodName = getAnnotationElementValueAsString("name", am);
+ ClassName opClassName = ClassName.get(opClass);
+ if (Strings.isNullOrEmpty(methodName)) {
+ methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
+ }
+ // Build a method for each @Operator found in the class path. There should be one method per
+ // operation factory called
+ // "create", which takes in parameter a scope and, optionally, a list of arguments
+ for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
+ if (opMethod.getModifiers().contains(Modifier.STATIC)
+ && opMethod.getSimpleName().contentEquals("create")) {
+ MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
+ groupedMethods.put(groupName, method);
+ }
}
- messager.printMessage(Kind.ERROR, message, e);
}
- private Filer filer;
- private Messager messager;
- private boolean hasRun = false;
- private static final String OP_PACKAGE = "org.tensorflow.op";
+ private MethodSpec buildOpMethod(
+ String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
+ MethodSpec.Builder builder =
+ MethodSpec.methodBuilder(methodName)
+ .addModifiers(Modifier.PUBLIC)
+ .returns(TypeName.get(factoryMethod.getReturnType()))
+ .varargs(factoryMethod.isVarArgs())
+ .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
+
+ for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
+ TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
+ builder.addTypeVariable(tvn);
+ }
+ for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
+ builder.addException(TypeName.get(thrownType));
+ }
+ StringBuilder call = new StringBuilder("return $T.create(scope");
+ boolean first = true;
+ for (VariableElement param : factoryMethod.getParameters()) {
+ ParameterSpec p = ParameterSpec.get(param);
+ if (first) {
+ first = false;
+ continue;
+ }
+ call.append(", ");
+ call.append(p.name);
+ builder.addParameter(p);
+ }
+ call.append(")");
+ builder.addStatement(call.toString(), opClassName);
+ return builder.build();
+ }
+
+ private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
+ StringBuilder javadoc = new StringBuilder();
+ javadoc
+ .append("Adds an {@link ")
+ .append(opClassName.simpleName())
+ .append("} operation to the graph\n\n");
+
+ // Add all javadoc tags found in the operator factory method but the first one, which should be
+ // in all cases the
+ // 'scope' parameter that is implicitly passed by this API
+ Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
+ boolean firstParam = true;
+
+ while (tagMatcher.find()) {
+ String tag = tagMatcher.group();
+ if (tag.startsWith("@param") && firstParam) {
+ firstParam = false;
+ } else {
+ javadoc.append(tag).append('\n');
+ }
+ }
+ javadoc.append("@see {@link ").append(opClassName).append("}\n");
+
+ return javadoc.toString();
+ }
+
+ private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
+ MethodSpec.Builder ctorBuilder =
+ MethodSpec.constructorBuilder()
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope");
+
+ TypeSpec.Builder builder =
+ TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
+ + "@see {@link $T}\n",
+ group,
+ T_GRAPH,
+ T_OPS)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
+
+ builder.addField(
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+
+ return builder.build();
+ }
+
+ private static TypeSpec buildTopClass(
+ Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
+ MethodSpec.Builder ctorBuilder =
+ MethodSpec.constructorBuilder()
+ .addModifiers(Modifier.PRIVATE)
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope", T_SCOPE);
+
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
+ ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
+ }
+
+ TypeSpec.Builder opsBuilder =
+ TypeSpec.classBuilder("Ops")
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for building a {@link $T} with operation wrappers\n<p>\n"
+ + "Any operation wrapper found in the classpath properly annotated as an"
+ + "{@link $T @Operator} is exposed\n"
+ + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
+ + "try (Graph g = new Graph()) {\n"
+ + " Ops ops = new Ops(g);\n"
+ + " // Operations are typed classes with convenience\n"
+ + " // builders in Ops.\n"
+ + " Constant three = ops.constant(3);\n"
+ + " // Single-result operations implement the Operand\n"
+ + " // interface, so this works too.\n"
+ + " Operand four = ops.constant(4);\n"
+ + " // Most builders are found within a group, and accept\n"
+ + " // Operand types as operands\n"
+ + " Operand nine = ops.math().add(four, ops.constant(5));\n"
+ + " // Multi-result operations however offer methods to\n"
+ + " // select a particular result for use.\n"
+ + " Operand result = \n"
+ + " ops.math().add(ops.array().unique(s, a).y(), b);\n"
+ + " // Optional attributes\n"
+ + " ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+ + " // Naming operators\n"
+ + " ops.withName(“foo”).constant(5); // name “foo”\n"
+ + " // Names can exist in a hierarchy\n"
+ + " Ops sub = ops.withSubScope(“sub”);\n"
+ + " sub.withName(“bar”).constant(4); // “sub/bar”\n"
+ + "}\n"
+ + "}</pre>\n",
+ T_GRAPH,
+ T_OPERATOR)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("withSubScope")
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "childScopeName")
+ .returns(T_OPS)
+ .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
+ .addJavadoc(
+ "Returns an API that adds operations to the graph with the provided name prefix.\n"
+ + "\n@see {@link $T#withSubScope(String)}\n",
+ T_SCOPE)
+ .build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("withName")
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "opName")
+ .returns(T_OPS)
+ .addStatement("return new Ops(scope.withName(opName))")
+ .addJavadoc(
+ "Returns an API that uses the provided name for an op.\n\n"
+ + "@see {@link $T#withName(String)}\n",
+ T_SCOPE)
+ .build());
+
+ opsBuilder.addField(
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("scope")
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(T_SCOPE)
+ .addStatement("return scope")
+ .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
+ .build());
+
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
+ opsBuilder.addField(
+ FieldSpec.builder(entry.getValue(), entry.getKey())
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .build());
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder(entry.getKey())
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(entry.getValue())
+ .addStatement("return $L", entry.getKey())
+ .addJavadoc(
+ "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
+ .build());
+ }
+
+ opsBuilder.addMethod(
+ MethodSpec.methodBuilder("create")
+ .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
+ .addParameter(T_GRAPH, "graph")
+ .returns(T_OPS)
+ .addStatement("return new Ops(new $T(graph))", T_SCOPE)
+ .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
+ .build());
+
+ return opsBuilder.build();
+ }
+
+ private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) {
+ for (AnnotationMirror am : element.getAnnotationMirrors()) {
+ if (am.getAnnotationType().asElement().equals(annotation)) {
+ return am;
+ }
+ }
+ throw new IllegalArgumentException(
+ "Annotation "
+ + annotation.getSimpleName()
+ + " not present on element "
+ + element.getSimpleName());
+ }
+
+ private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
+ for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
+ am.getElementValues().entrySet()) {
+ if (entry.getKey().getSimpleName().contentEquals(elementName)) {
+ return entry.getValue().getValue().toString();
+ }
+ }
+ return "";
+ }
}