From 1e7b0e4ad6d0f57f3241fe0b80a65f2c2a7f11b0 Mon Sep 17 00:00:00 2001 From: Mingxing Tan Date: Thu, 28 Jun 2018 19:13:20 -0700 Subject: Merge changes from github. PiperOrigin-RevId: 202585094 --- tensorflow/java/BUILD | 5 + tensorflow/java/maven/.gitignore | 6 + tensorflow/java/maven/README.md | 6 + tensorflow/java/maven/hadoop/pom.xml | 24 ++ tensorflow/java/maven/pom.xml | 2 + tensorflow/java/maven/run_inside_container.sh | 47 ++- tensorflow/java/maven/spark-connector/pom.xml | 24 ++ tensorflow/java/src/gen/cc/op_generator.cc | 11 +- tensorflow/java/src/gen/cc/op_specs.h | 2 + .../tensorflow/processor/OperatorProcessor.java | 348 +++++++++++++++++++-- 10 files changed, 441 insertions(+), 34 deletions(-) create mode 100644 tensorflow/java/maven/hadoop/pom.xml create mode 100644 tensorflow/java/maven/spark-connector/pom.xml (limited to 'tensorflow/java') diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 19d2133a55..73e210fae0 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -56,6 +56,10 @@ java_library( srcs = glob(["src/gen/java/org/tensorflow/processor/**/*.java"]), javacopts = JAVACOPTS, resources = glob(["src/gen/resources/META-INF/services/javax.annotation.processing.Processor"]), + deps = [ + "@com_google_guava", + "@com_squareup_javapoet", + ], ) filegroup( @@ -70,6 +74,7 @@ tf_java_op_gen_srcjar( name = "java_op_gen_sources", api_def_srcs = [ "//tensorflow/core/api_def:base_api_def", + "//tensorflow/core/api_def:java_api_def", ], base_package = "org.tensorflow.op", gen_tool = ":java_op_gen_tool", diff --git a/tensorflow/java/maven/.gitignore b/tensorflow/java/maven/.gitignore index ff080515d5..657e2a60bc 100644 --- a/tensorflow/java/maven/.gitignore +++ b/tensorflow/java/maven/.gitignore @@ -11,4 +11,10 @@ tensorflow/src tensorflow/target proto/src proto/target +hadoop/src +hadoop/target +spark-connector/src +spark-connector/target +spark-connector/dependency-reduced-pom.xml +spark-connector/spark-warehouse pom.xml.versionsBackup diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md index c7e8f03806..3e030dcd09 100644 --- a/tensorflow/java/maven/README.md +++ b/tensorflow/java/maven/README.md @@ -53,6 +53,12 @@ There are seven artifacts and thus `pom.xml`s involved in this release: 7. [`parentpom`](https://maven.apache.org/pom/index.html): Common settings shared by all of the above. +8. `hadoop`: The TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop. + The source code for this package is available in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/hadoop) + +9. `spark-connector`: A Scala library for loading and storing TensorFlow TFRecord + using Apache Spark DataFrames. The source code for this package is available + in the [TensorFlow Ecosystem](https://github.com/tensorflow/ecosystem/tree/master/spark/spark-tensorflow-connector) ## Updating the release diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml new file mode 100644 index 0000000000..0642be06fa --- /dev/null +++ b/tensorflow/java/maven/hadoop/pom.xml @@ -0,0 +1,24 @@ + + + 4.0.0 + TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop + hadoop + jar + + + https://github.com/tensorflow/ecosystem.git + git@github.com:tensorflow/ecosystem.git + scm:git:https://github.com/tensorflow/ecosystem.git + + + https://github.com/tensorflow/ecosystem/ + + org.tensorflow + parentpom + 1.9.0-rc0 + ../ + + \ No newline at end of file diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 3890f3fcaa..b4746794ea 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -32,6 +32,8 @@ libtensorflow_jni_gpu tensorflow proto + hadoop + spark-connector + 4.0.0 + TensorFlow TFRecord connector for Apache Spark DataFrames + spark-connector + jar + + + https://github.com/tensorflow/ecosystem.git + git@github.com:tensorflow/ecosystem.git + scm:git:https://github.com/tensorflow/ecosystem.git + + + https://github.com/tensorflow/ecosystem/ + + org.tensorflow + parentpom + 1.9.0-rc0 + ../ + + \ No newline at end of file diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 9b171f66ec..d5bd99bdd9 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -35,7 +35,7 @@ namespace tensorflow { namespace java { namespace { -const char* kLicense = +constexpr const char kLicense[] = "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" "\n" "Licensed under the Apache License, Version 2.0 (the \"License\");\n" @@ -391,9 +391,12 @@ void GenerateOp(const OpSpec& op, const EndpointSpec& endpoint, } if (!op.hidden()) { // expose the op in the Ops Graph API only if it is visible - op_class.add_annotation( - Annotation::Create("Operator", "org.tensorflow.op.annotation") - .attributes("group = \"" + endpoint.package() + "\"")); + Annotation oper_annot = + Annotation::Create("Operator", "org.tensorflow.op.annotation"); + if (endpoint.package() != kDefaultEndpointPackage) { + oper_annot.attributes("group = \"" + endpoint.package() + "\""); + } + op_class.add_annotation(oper_annot); } // create op class file const string op_dir_name = io::JoinPath( diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index ca0ba16745..30ecb8ce53 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -27,6 +27,8 @@ limitations under the License. namespace tensorflow { namespace java { +constexpr const char kDefaultEndpointPackage[] = "core"; + class EndpointSpec { public: // A specification for an operation endpoint diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 11fda4fc22..796d6a62dc 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,19 +15,44 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; -import java.io.PrintWriter; +import java.util.Collection; import java.util.Collections; -import java.util.HashSet; +import java.util.HashMap; +import java.util.Map; import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; import javax.annotation.processing.ProcessingEnvironment; import javax.annotation.processing.RoundEnvironment; import javax.lang.model.SourceVersion; +import javax.lang.model.element.AnnotationMirror; +import javax.lang.model.element.AnnotationValue; import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; import javax.lang.model.element.TypeElement; +import javax.lang.model.element.TypeParameterElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeVariable; +import javax.lang.model.util.ElementFilter; +import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; /** @@ -55,6 +80,7 @@ public final class OperatorProcessor extends AbstractProcessor { super.init(processingEnv); messager = processingEnv.getMessager(); filer = processingEnv.getFiler(); + elements = processingEnv.getElementUtils(); } @Override @@ -98,42 +124,77 @@ public final class OperatorProcessor extends AbstractProcessor { } // Collect all classes tagged with our annotation. - Set opClasses = new HashSet(); - if (!collectOpClasses(roundEnv, opClasses, annotation)) { + Multimap groupedMethods = HashMultimap.create(); + if (!collectOpsMethods(roundEnv, groupedMethods, annotation)) { return true; } // Nothing to do when there are no tagged classes. - if (opClasses.isEmpty()) { + if (groupedMethods.isEmpty()) { return true; } - // TODO:(kbsriram) validate operator classes and generate Op API. - writeApi(); + // Validate operator classes and generate Op API. + writeApi(groupedMethods); + hasRun = true; return true; } @Override public Set getSupportedAnnotationTypes() { - return Collections.singleton(String.format("%s.annotation.Operator", OP_PACKAGE)); + return Collections.singleton("org.tensorflow.op.annotation.Operator"); + } + + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); + private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); + private static final TypeName T_STRING = ClassName.get(String.class); + + private Filer filer; + private Messager messager; + private Elements elements; + private boolean hasRun = false; + + private void error(Element e, String message, Object... args) { + if (args != null && args.length > 0) { + message = String.format(message, args); + } + messager.printMessage(Kind.ERROR, message, e); } - private void writeApi() { - // Generate an empty class for now and get the build working correctly. This will be changed to - // generate the actual API once we've done with build-related changes. - // TODO:(kbsriram) - try (PrintWriter writer = - new PrintWriter(filer.createSourceFile(String.format("%s.Ops", OP_PACKAGE)).openWriter())) { - writer.println(String.format("package %s;", OP_PACKAGE)); - writer.println("public class Ops{}"); + private void write(TypeSpec spec) { + try { + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { - error(null, "Unexpected failure generating API: %s", e.getMessage()); + throw new AssertionError(e); + } + } + + private void writeApi(Multimap groupedMethods) { + Map groups = new HashMap<>(); + + // Generate a API class for each group collected other than the default one (= empty string) + for (Map.Entry> entry : groupedMethods.asMap().entrySet()) { + if (!entry.getKey().isEmpty()) { + TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); + write(groupClass); + groups.put(entry.getKey(), ClassName.get("org.tensorflow.op", groupClass.name)); + } } + // Generate the top API class, adding any methods added to the default group + TypeSpec topClass = buildTopClass(groups, groupedMethods.get("")); + write(topClass); } - private boolean collectOpClasses( - RoundEnvironment roundEnv, Set opClasses, TypeElement annotation) { + private boolean collectOpsMethods( + RoundEnvironment roundEnv, + Multimap groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. @@ -145,20 +206,251 @@ public final class OperatorProcessor extends AbstractProcessor { result = false; continue; } - opClasses.add((TypeElement) e); + TypeElement opClass = (TypeElement) e; + // Skip deprecated operations for now, as we do not guarantee API stability yet + if (opClass.getAnnotation(Deprecated.class) == null) { + collectOpMethods(groupedMethods, opClass, annotation); + } } return result; } - private void error(Element e, String message, Object... args) { - if (args != null && args.length > 0) { - message = String.format(message, args); + private void collectOpMethods( + Multimap groupedMethods, TypeElement opClass, TypeElement annotation) { + AnnotationMirror am = getAnnotationMirror(opClass, annotation); + String groupName = getAnnotationElementValueAsString("group", am); + String methodName = getAnnotationElementValueAsString("name", am); + ClassName opClassName = ClassName.get(opClass); + if (Strings.isNullOrEmpty(methodName)) { + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + } + // Build a method for each @Operator found in the class path. There should be one method per + // operation factory called + // "create", which takes in parameter a scope and, optionally, a list of arguments + for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { + MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); + groupedMethods.put(groupName, method); + } } - messager.printMessage(Kind.ERROR, message, e); } - private Filer filer; - private Messager messager; - private boolean hasRun = false; - private static final String OP_PACKAGE = "org.tensorflow.op"; + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + MethodSpec.Builder builder = + MethodSpec.methodBuilder(methodName) + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { + TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); + builder.addTypeVariable(tvn); + } + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { + builder.addException(TypeName.get(thrownType)); + } + StringBuilder call = new StringBuilder("return $T.create(scope"); + boolean first = true; + for (VariableElement param : factoryMethod.getParameters()) { + ParameterSpec p = ParameterSpec.get(param); + if (first) { + first = false; + continue; + } + call.append(", "); + call.append(p.name); + builder.addParameter(p); + } + call.append(")"); + builder.addStatement(call.toString(), opClassName); + return builder.build(); + } + + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { + StringBuilder javadoc = new StringBuilder(); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); + + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the + // 'scope' parameter that is implicitly passed by this API + Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); + boolean firstParam = true; + + while (tagMatcher.find()) { + String tag = tagMatcher.group(); + if (tag.startsWith("@param") && firstParam) { + firstParam = false; + } else { + javadoc.append(tag).append('\n'); + } + } + javadoc.append("@see {@link ").append(opClassName).append("}\n"); + + return javadoc.toString(); + } + + private static TypeSpec buildGroupClass(String group, Collection methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + + TypeSpec.Builder builder = + TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + builder.addField( + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + + return builder.build(); + } + + private static TypeSpec buildTopClass( + Map groupToClass, Collection methods) { + MethodSpec.Builder ctorBuilder = + MethodSpec.constructorBuilder() + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); + + for (Map.Entry entry : groupToClass.entrySet()) { + ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); + } + + TypeSpec.Builder opsBuilder = + TypeSpec.classBuilder("Ops") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n

\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n

Example usage:\n

{@code\n"
+                    + "try (Graph g = new Graph()) {\n"
+                    + "  Ops ops = new Ops(g);\n"
+                    + "  // Operations are typed classes with convenience\n"
+                    + "  // builders in Ops.\n"
+                    + "  Constant three = ops.constant(3);\n"
+                    + "  // Single-result operations implement the Operand\n"
+                    + "  // interface, so this works too.\n"
+                    + "  Operand four = ops.constant(4);\n"
+                    + "  // Most builders are found within a group, and accept\n"
+                    + "  // Operand types as operands\n"
+                    + "  Operand nine = ops.math().add(four, ops.constant(5));\n"
+                    + "  // Multi-result operations however offer methods to\n"
+                    + "  // select a particular result for use.\n"
+                    + "  Operand result = \n"
+                    + "      ops.math().add(ops.array().unique(s, a).y(), b);\n"
+                    + "  // Optional attributes\n"
+                    + "  ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+                    + "  // Naming operators\n"
+                    + "  ops.withName(“foo”).constant(5); // name “foo”\n"
+                    + "  // Names can exist in a hierarchy\n"
+                    + "  Ops sub = ops.withSubScope(“sub”);\n"
+                    + "  sub.withName(“bar”).constant(4); // “sub/bar”\n"
+                    + "}\n"
+                    + "}
\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withSubScope") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("withName") + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); + + opsBuilder.addField( + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder("scope") + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); + + for (Map.Entry entry : groupToClass.entrySet()) { + opsBuilder.addField( + FieldSpec.builder(entry.getValue(), entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + + opsBuilder.addMethod( + MethodSpec.methodBuilder(entry.getKey()) + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); + } + + opsBuilder.addMethod( + MethodSpec.methodBuilder("create") + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); + + return opsBuilder.build(); + } + + private static AnnotationMirror getAnnotationMirror(Element element, TypeElement annotation) { + for (AnnotationMirror am : element.getAnnotationMirrors()) { + if (am.getAnnotationType().asElement().equals(annotation)) { + return am; + } + } + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); + } + + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { + for (Map.Entry entry : + am.getElementValues().entrySet()) { + if (entry.getKey().getSimpleName().contentEquals(elementName)) { + return entry.getValue().getValue().toString(); + } + } + return ""; + } } -- cgit v1.2.3