diff options
Diffstat (limited to 'tensorflow/java')
35 files changed, 1596 insertions, 344 deletions
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml index 0642be06fa..a160377210 100644 --- a/tensorflow/java/maven/hadoop/pom.xml +++ b/tensorflow/java/maven/hadoop/pom.xml @@ -1,12 +1,30 @@ -<project - xmlns="http://maven.apache.org/POM/4.0.0" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - <!-- Placeholder pom which is replaced by TensorFlow ecosystem Hadoop pom during build --> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelVersion>4.0.0</modelVersion> - <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description> + <groupId>org.tensorflow</groupId> <artifactId>hadoop</artifactId> <packaging>jar</packaging> + <version>1.9.0-rc2</version> + <name>tensorflow-hadoop</name> + <url>https://www.tensorflow.org</url> + <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description> + + <properties> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <maven.compiler.source>1.6</maven.compiler.source> + <maven.compiler.target>1.6</maven.compiler.target> + <hadoop.version>2.6.0</hadoop.version> + <protobuf.version>3.3.1</protobuf.version> + <junit.version>4.11</junit.version> + </properties> + + <licenses> + <license> + <name>Apache License Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> + </license> + </licenses> <scm> <url>https://github.com/tensorflow/ecosystem.git</url> @@ -14,11 +32,133 @@ <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection> </scm> - <url>https://github.com/tensorflow/ecosystem/</url> - <parent> - <groupId>org.tensorflow</groupId> - <artifactId>parentpom</artifactId> - <version>1.9.0-rc0</version> - <relativePath>../</relativePath> - </parent> -</project>
\ No newline at end of file + <build> + <pluginManagement> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + <version>1.5</version> + <executions> + <execution> + <id>sign-artifacts</id> + <phase>verify</phase> + <goals> + <goal>sign</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </pluginManagement> + </build> + + <dependencies> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>proto</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-common</artifactId> + <version>${hadoop.version}</version> + <exclusions> + <exclusion> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-core</artifactId> + <version>${hadoop.version}</version> + <exclusions> + <exclusion> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${protobuf.version}</version> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <version>${junit.version}</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-mapreduce-client-jobclient</artifactId> + <version>${hadoop.version}</version> + <type>test-jar</type> + <optional>true</optional> + <scope>test</scope> + <exclusions> + <exclusion> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + </exclusion> + </exclusions> + </dependency> + </dependencies> + + <!-- Two profiles are used: + ossrh - deploys to ossrh/maven central + bintray - deploys to bintray/jcenter. --> + <profiles> + <profile> + <id>ossrh</id> + <distributionManagement> + <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html --> + <snapshotRepository> + <id>ossrh</id> + <url>https://oss.sonatype.org/content/repositories/snapshots</url> + </snapshotRepository> + <repository> + <id>ossrh</id> + <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url> + </repository> + </distributionManagement> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>bintray</id> + <distributionManagement> + <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ --> + <repository> + <id>bintray</id> + <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url> + </repository> + </distributionManagement> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + </plugin> + </plugins> + </build> + </profile> + </profiles> + + <developers> + <developer> + <name>TensorFlowers</name> + <organization>TensorFlow</organization> + <organizationUrl>http://www.tensorflow.org</organizationUrl> + </developer> + </developers> +</project> diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index a7fa9ea5cc..489fb7bdb0 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0-rc1</version> + <version>1.9.0-rc2</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow</artifactId> diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 83aae29f1e..5bef85f75e 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0-rc1</version> + <version>1.9.0-rc2</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow_jni</artifactId> diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 50bd8ee5f9..8d93c78220 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0-rc1</version> + <version>1.9.0-rc2</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow_jni_gpu</artifactId> diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index b4746794ea..c5861102c8 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ <modelVersion>4.0.0</modelVersion> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0-rc1</version> + <version>1.9.0-rc2</version> <packaging>pom</packaging> <url>https://www.tensorflow.org</url> diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 618a2a124c..754caad900 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0-rc1</version> + <version>1.9.0-rc2</version> <relativePath>../</relativePath> </parent> <artifactId>proto</artifactId> diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh index 2e771064e4..2240d6b7b9 100644 --- a/tensorflow/java/maven/run_inside_container.sh +++ b/tensorflow/java/maven/run_inside_container.sh @@ -203,7 +203,10 @@ download_tf_ecosystem() { cd "${ECOSYSTEM_DIR}" git clone "${TF_ECOSYSTEM_URL}" cd ecosystem - git checkout r${TF_VERSION} + # TF_VERSION is a semver string (<major>.<minor>.<patch>[-suffix]) + # but the branch is just (r<major>.<minor>). + RELEASE_BRANCH=$(echo "${TF_VERSION}" | sed -e 's/\([0-9]\+\.[0-9]\+\)\.[0-9]\+.*/\1/') + git checkout r${RELEASE_BRANCH} # Copy the TensorFlow Hadoop source cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}" diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml index 19c752d08b..99237fdb98 100644 --- a/tensorflow/java/maven/spark-connector/pom.xml +++ b/tensorflow/java/maven/spark-connector/pom.xml @@ -1,12 +1,23 @@ -<project - xmlns="http://maven.apache.org/POM/4.0.0" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - <!-- Placeholder pom which is replaced by TensorFlow ecosystem Spark pom during build --> +<?xml version="1.0" encoding="UTF-8"?> +<project xmlns="http://maven.apache.org/POM/4.0.0" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> <modelVersion>4.0.0</modelVersion> - <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description> - <artifactId>spark-connector</artifactId> + <groupId>org.tensorflow</groupId> + <artifactId>spark-connector_2.11</artifactId> <packaging>jar</packaging> + <version>1.9.0-rc2</version> + <name>spark-tensorflow-connector</name> + <url>https://www.tensorflow.org</url> + <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description> + + <licenses> + <license> + <name>The Apache Software License, Version 2.0</name> + <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url> + <distribution>repo</distribution> + </license> + </licenses> <scm> <url>https://github.com/tensorflow/ecosystem.git</url> @@ -14,11 +25,293 @@ <developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection> </scm> - <url>https://github.com/tensorflow/ecosystem/</url> - <parent> - <groupId>org.tensorflow</groupId> - <artifactId>parentpom</artifactId> - <version>1.9.0-rc0</version> - <relativePath>../</relativePath> - </parent> -</project>
\ No newline at end of file + <properties> + <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> + <scala.maven.version>3.2.2</scala.maven.version> + <scala.binary.version>2.11</scala.binary.version> + <scalatest.maven.version>1.0</scalatest.maven.version> + <scala.test.version>2.2.6</scala.test.version> + <maven.compiler.version>3.0</maven.compiler.version> + <java.version>1.8</java.version> + <spark.version>2.3.0</spark.version> + <yarn.api.version>2.7.3</yarn.api.version> + <junit.version>4.11</junit.version> + </properties> + + <build> + <pluginManagement> + <plugins> + <plugin> + <inherited>true</inherited> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <version>${scala.maven.version}</version> + <executions> + <execution> + <id>compile</id> + <goals> + <goal>add-source</goal> + <goal>compile</goal> + </goals> + <configuration> + <jvmArgs> + <jvmArg>-Xms256m</jvmArg> + <jvmArg>-Xmx512m</jvmArg> + </jvmArgs> + <args> + <arg>-g:vars</arg> + <arg>-deprecation</arg> + <arg>-feature</arg> + <arg>-unchecked</arg> + <arg>-Xfatal-warnings</arg> + <arg>-language:implicitConversions</arg> + <arg>-language:existentials</arg> + </args> + </configuration> + </execution> + <execution> + <id>test</id> + <goals> + <goal>add-source</goal> + <goal>testCompile</goal> + </goals> + </execution> + </executions> + <configuration> + <recompileMode>incremental</recompileMode> + <useZincServer>true</useZincServer> + <scalaVersion>${scala.binary.version}</scalaVersion> + <checkMultipleScalaVersions>false</checkMultipleScalaVersions> + </configuration> + </plugin> + <plugin> + <inherited>true</inherited> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <version>${scalatest.maven.version}</version> + <executions> + <execution> + <id>scalaTest</id> + <phase>test</phase> + <goals> + <goal>test</goal> + </goals> + </execution> + </executions> + </plugin> + <!-- Shade protobuf dependency. --> + <plugin> + <artifactId>maven-shade-plugin</artifactId> + <version>3.1.0</version> + <executions> + <execution> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <minimizeJar>true</minimizeJar> + <artifactSet> + <includes> + <include>com.google.protobuf:protobuf-java</include> + <include>org.tensorflow:hadoop</include> + <include>org.tensorflow:proto</include> + </includes> + </artifactSet> + <filters> + <filter> + <!-- Remove the source to keep the result smaller. --> + <artifact>com.google.protobuf:protobuf-java</artifact> + <excludes> + <exclude>**/*.java</exclude> + </excludes> + </filter> + </filters> + <relocations> + <relocation> + <pattern>com.google.protobuf</pattern> + <shadedPattern> + org.tensorflow.spark.shaded.com.google.protobuf + </shadedPattern> + </relocation> + </relocations> + </configuration> + </execution> + </executions> + </plugin> + <!-- GPG signed components: http://central.sonatype.org/pages/apache-maven.html#gpg-signed-components --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + <version>1.5</version> + <executions> + <execution> + <id>sign-artifacts</id> + <phase>verify</phase> + <goals> + <goal>sign</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </pluginManagement> + <plugins> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + </plugin> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-compiler-plugin</artifactId> + <version>${maven.compiler.version}</version> + <configuration> + <source>${java.version}</source> + <target>${java.version}</target> + </configuration> + </plugin> + </plugins> + </build> + + <profiles> + <profile> + <id>test</id> + <activation> + <activeByDefault>true</activeByDefault> + <property> + <name>!NEVERSETME</name> + </property> + </activation> + <build> + <plugins> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + </plugin> + </plugins> + </build> + <dependencyManagement> + <dependencies> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <version>${scala.test.version}</version> + <scope>test</scope> + </dependency> + </dependencies> + </dependencyManagement> + <dependencies> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.binary.version}</artifactId> + <scope>test</scope> + </dependency> + </dependencies> + </profile> + + <!-- Two profiles are used: + ossrh - deploys to ossrh/maven central + bintray - deploys to bintray/jcenter. --> + <profile> + <id>ossrh</id> + <distributionManagement> + <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html --> + <snapshotRepository> + <id>ossrh</id> + <url>https://oss.sonatype.org/content/repositories/snapshots</url> + </snapshotRepository> + <repository> + <id>ossrh</id> + <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url> + </repository> + </distributionManagement> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + </plugin> + </plugins> + </build> + </profile> + <profile> + <id>bintray</id> + <distributionManagement> + <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ --> + <repository> + <id>bintray</id> + <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url> + </repository> + </distributionManagement> + <build> + <plugins> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-gpg-plugin</artifactId> + </plugin> + </plugins> + </build> + </profile> + </profiles> + + <developers> + <developer> + <name>TensorFlowers</name> + <organization>TensorFlow</organization> + <organizationUrl>http://www.tensorflow.org</organizationUrl> + </developer> + </developers> + + <dependencies> + <dependency> + <groupId>org.tensorflow</groupId> + <artifactId>hadoop</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-core_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-sql_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-mllib_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.hadoop</groupId> + <artifactId>hadoop-yarn-api</artifactId> + <version>${yarn.api.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-mllib_${scala.binary.version}</artifactId> + <version>${spark.version}</version> + <type>test-jar</type> + <scope>test</scope> + </dependency> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <version>${junit.version}</version> + <scope>test</scope> + </dependency> + </dependencies> +</project> diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 157c4b8e82..2a8e640dbc 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0-rc1</version> + <version>1.9.0-rc2</version> <relativePath>../</relativePath> </parent> <artifactId>tensorflow</artifactId> diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h index f5f54bf4d3..d9d6f8adc8 100644 --- a/tensorflow/java/src/gen/cc/java_defs.h +++ b/tensorflow/java/src/gen/cc/java_defs.h @@ -16,9 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ #define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_ -#include <string> #include <list> #include <map> +#include <string> #include <utility> namespace tensorflow { diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc index 2df69ee299..d5bd99bdd9 100644 --- a/tensorflow/java/src/gen/cc/op_generator.cc +++ b/tensorflow/java/src/gen/cc/op_generator.cc @@ -36,20 +36,21 @@ namespace java { namespace { constexpr const char kLicense[] = - "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" - "\n" - "Licensed under the Apache License, Version 2.0 (the \"License\");\n" - "you may not use this file except in compliance with the License.\n" - "You may obtain a copy of the License at\n" - "\n" - " http://www.apache.org/licenses/LICENSE-2.0\n" - "\n" - "Unless required by applicable law or agreed to in writing, software\n" - "distributed under the License is distributed on an \"AS IS\" BASIS,\n" - "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - "See the License for the specific language governing permissions and\n" - "limitations under the License.\n" - "=======================================================================*/\n"; + "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n" + "\n" + "Licensed under the Apache License, Version 2.0 (the \"License\");\n" + "you may not use this file except in compliance with the License.\n" + "You may obtain a copy of the License at\n" + "\n" + " http://www.apache.org/licenses/LICENSE-2.0\n" + "\n" + "Unless required by applicable law or agreed to in writing, software\n" + "distributed under the License is distributed on an \"AS IS\" BASIS,\n" + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + "See the License for the specific language governing permissions and\n" + "limitations under the License.\n" + "=======================================================================*/" + "\n"; // There is three different modes to render an op class, depending on the // number and type of outputs it has: diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h index 759d800ecf..05decd6b54 100644 --- a/tensorflow/java/src/gen/cc/op_generator.h +++ b/tensorflow/java/src/gen/cc/op_generator.h @@ -19,10 +19,10 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/java/src/gen/cc/op_specs.h" namespace tensorflow { diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc index 63e99fbb04..941ab2699c 100644 --- a/tensorflow/java/src/gen/cc/op_specs.cc +++ b/tensorflow/java/src/gen/cc/op_specs.cc @@ -14,9 +14,9 @@ limitations under the License. ==============================================================================*/ #include <map> -#include <vector> #include <string> #include <utility> +#include <vector> #include "re2/re2.h" #include "tensorflow/core/framework/op.h" @@ -50,7 +50,7 @@ class TypeResolver { // For example, if the argument's datatype is DT_STRING, this method will // return "java.lang.String", so the argument can become "Operand<String>" // in the Ops API - Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out); + Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out); // Returns types of an input attribute // @@ -62,7 +62,7 @@ class TypeResolver { // <java.lang.Float, float>, so the attribute can be used as a "Float" object // in the Ops API and casted to a "float" when passing through the JNI layer. std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def, - bool *iterable_out); + bool* iterable_out); // Returns true if the type of this attribute has already been resolved bool IsAttributeVisited(const string& attr_name) { @@ -89,8 +89,7 @@ class TypeResolver { } }; -Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, - bool* iterable_out) { +Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) { *iterable_out = false; if (!arg_def.number_attr().empty()) { // when number_attr is set, argument has to be a list of tensors @@ -154,13 +153,13 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, } else { LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name() - << "\" in operation \"" << op_def_.name() << "\""; + << "\" in operation \"" << op_def_.name() << "\""; } return type; } std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, - bool* iterable_out) { + bool* iterable_out) { std::pair<Type, Type> types = MakeTypePair(Type::Wildcard()); *iterable_out = false; StringPiece attr_type = attr_def.type(); @@ -185,7 +184,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, } else if (attr_type == "tensor") { types = MakeTypePair(Type::Class("Tensor", "org.tensorflow") - .add_parameter(Type::Wildcard())); + .add_parameter(Type::Wildcard())); } else if (attr_type == "type") { Type type = *iterable_out ? Type::Wildcard() : NextGeneric(); @@ -196,7 +195,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def, } else { LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type - << "\" in operation \"" << op_def_.name() << "\""; + << "\" in operation \"" << op_def_.name() << "\""; } visited_attrs_.insert(std::make_pair(attr_def.name(), types.first)); return types; @@ -219,47 +218,43 @@ string SnakeToCamelCase(const string& str, bool upper = false) { return result; } -bool FindAndCut(re2::StringPiece* input, const RE2& expr, - re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) { - re2::StringPiece match; - if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1)) { - return false; - } - before_match->set(input->data(), match.begin() - input->begin()); - input->remove_prefix(match.end() - before_match->begin()); - if (ret_match != nullptr) { - *ret_match = match; - } +bool FindAndCut(string* input, const RE2& expr, string* before_match, + string* ret_match = nullptr) { + string match; + if (!RE2::PartialMatch(*input, expr, &match)) return false; + *before_match = input->substr(0, input->find(match)); + *input = input->substr(before_match->size() + match.size()); + if (ret_match != nullptr) *ret_match = match; return true; } -string ParseDocumentation(re2::StringPiece input) { +string ParseDocumentation(const string& inp) { std::stringstream javadoc_text; // TODO(karllessard) This is a very minimalist utility method for converting // markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check // for alternatives to increase the level of support for markups. std::vector<string> markups_subexpr; - markups_subexpr.push_back("\n+\\*\\s+"); // lists - markups_subexpr.push_back("\n{2,}"); // paragraphs + markups_subexpr.push_back("\n+\\*\\s+"); // lists + markups_subexpr.push_back("\n{2,}"); // paragraphs markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks - markups_subexpr.push_back("`+"); // inlined code and code blocks + markups_subexpr.push_back("`+"); // inlined code and code blocks markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis - markups_subexpr.push_back("\\["); // hyperlinks - const RE2 markup_expr(str_util::Join(markups_subexpr, "|")); + markups_subexpr.push_back("\\["); // hyperlinks + const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")"); bool in_list = false; + string input = inp; while (true) { - re2::StringPiece text; - re2::StringPiece markup; + string text, markup; if (!FindAndCut(&input, markup_expr, &text, &markup)) { javadoc_text << input; break; // end of loop } javadoc_text << text; - if (markup.starts_with("\n")) { + if (str_util::StartsWith(markup, "\n")) { javadoc_text << "\n"; - if (markup.contains("*")) { + if (str_util::StrContains(markup, "*")) { // new list item javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n"; in_list = true; @@ -267,18 +262,18 @@ string ParseDocumentation(re2::StringPiece input) { // end of list javadoc_text << "</li>\n</ul>\n"; in_list = false; - } else if (!input.starts_with("```")) { + } else if (!str_util::StartsWith(input, "```")) { // new paragraph (not required if a <pre> block follows) javadoc_text << "<p>\n"; } - } else if (markup.starts_with("```")) { + } else if (str_util::StartsWith(markup, "```")) { // code blocks - if (FindAndCut(&input, "```\\s*\n*", &text)) { + if (FindAndCut(&input, "(```\\s*\n*)", &text)) { javadoc_text << "<pre>{@code\n" << text << "}</pre>\n"; } else { javadoc_text << markup; } - } else if (markup.starts_with("`")) { + } else if (str_util::StartsWith("(" + markup + ")", "`")) { // inlined code if (FindAndCut(&input, markup, &text)) { javadoc_text << "{@code " << text << "}"; @@ -287,26 +282,28 @@ string ParseDocumentation(re2::StringPiece input) { } } else if (markup == "**") { // text emphasis (strong) - if (FindAndCut(&input, "\\b\\*{2}", &text)) { + if (FindAndCut(&input, "(\\b\\*{2})", &text)) { javadoc_text << "<b>" << ParseDocumentation(text) << "</b>"; } else { javadoc_text << markup; } } else if (markup == "*") { // text emphasis (normal) - if (FindAndCut(&input, "\\b\\*{1}", &text)) { + if (FindAndCut(&input, "(\\b\\*{1})", &text)) { javadoc_text << "<i>" << ParseDocumentation(text) << "</i>"; } else { javadoc_text << markup; } - } else if (markup.starts_with("[")) { + } else if (str_util::StartsWith(markup, "[")) { // hyperlinks string label; string link; - if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) { + if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label, + &link) && + str_util::StartsWith(input, label + link)) { + input = input.substr(label.size() + link.size()); javadoc_text << "<a href=\"" << link << "\">" - << ParseDocumentation(label) - << "</a>"; + << ParseDocumentation(label) << "</a>"; } else { javadoc_text << markup; } @@ -319,57 +316,56 @@ string ParseDocumentation(re2::StringPiece input) { } ArgumentSpec CreateInput(const OpDef_ArgDef& input_def, - const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) { + const ApiDef::Arg& input_api_def, + TypeResolver* type_resolver) { bool iterable = false; Type type = type_resolver->TypeOf(input_def, &iterable); - Type var_type = Type::Interface("Operand", "org.tensorflow") - .add_parameter(type); + Type var_type = + Type::Interface("Operand", "org.tensorflow").add_parameter(type); if (iterable) { var_type = Type::IterableOf(var_type); } - return ArgumentSpec(input_api_def.name(), + return ArgumentSpec( + input_api_def.name(), Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type), - type, - ParseDocumentation(input_api_def.description()), - iterable); + type, ParseDocumentation(input_api_def.description()), iterable); } AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def, - const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) { + const ApiDef::Attr& attr_api_def, + TypeResolver* type_resolver) { bool iterable = false; std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable); - Type var_type = types.first.kind() == Type::GENERIC ? - Type::Class("Class").add_parameter(types.first) : types.first; + Type var_type = types.first.kind() == Type::GENERIC + ? Type::Class("Class").add_parameter(types.first) + : types.first; if (iterable) { var_type = Type::ListOf(var_type); } - return AttributeSpec(attr_api_def.name(), + return AttributeSpec( + attr_api_def.name(), Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type), - types.first, - types.second, - ParseDocumentation(attr_api_def.description()), - iterable, - attr_api_def.has_default_value()); + types.first, types.second, ParseDocumentation(attr_api_def.description()), + iterable, attr_api_def.has_default_value()); } ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def, - const ApiDef::Arg& output_api, TypeResolver* type_resolver) { + const ApiDef::Arg& output_api, + TypeResolver* type_resolver) { bool iterable = false; Type type = type_resolver->TypeOf(output_def, &iterable); - Type var_type = Type::Class("Output", "org.tensorflow") - .add_parameter(type); + Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type); if (iterable) { var_type = Type::ListOf(var_type); } - return ArgumentSpec(output_api.name(), + return ArgumentSpec( + output_api.name(), Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type), - type, - ParseDocumentation(output_api.description()), - iterable); + type, ParseDocumentation(output_api.description()), iterable); } EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, - const ApiDef_Endpoint& endpoint_def) { + const ApiDef_Endpoint& endpoint_def) { std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), "."); string package; string name; @@ -377,27 +373,25 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def, package = name_tokens.at(0); name = name_tokens.at(1); } else { - package = kDefaultEndpointPackage; + package = "core"; // generate unclassified ops in the 'core' package name = name_tokens.at(0); } - return EndpointSpec(package, - name, - Javadoc::Create(ParseDocumentation(api_def.summary())) - .details(ParseDocumentation(api_def.description()))); + return EndpointSpec(package, name, + Javadoc::Create(ParseDocumentation(api_def.summary())) + .details(ParseDocumentation(api_def.description()))); } } // namespace OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { - OpSpec op(api_def.graph_op_name(), - api_def.visibility() == ApiDef::HIDDEN, - op_def.deprecation().explanation()); + OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN, + op_def.deprecation().explanation()); TypeResolver type_resolver(op_def); for (const string& next_input_name : api_def.arg_order()) { for (int i = 0; i < op_def.input_arg().size(); ++i) { if (op_def.input_arg(i).name() == next_input_name) { op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i), - &type_resolver)); + &type_resolver)); break; } } @@ -406,8 +400,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { // do not parse attributes already visited, they have probably been inferred // before as an input argument type if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) { - AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i), - &type_resolver); + AttributeSpec attr = + CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver); // attributes with a default value are optional if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) { op.optional_attributes_.push_back(attr); @@ -417,8 +411,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) { } } for (int i = 0; i < op_def.output_arg().size(); ++i) { - op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i), - &type_resolver)); + op.outputs_.push_back( + CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver)); } for (const auto& endpoint_def : api_def.endpoint()) { op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def)); diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h index 3b53c730df..30ecb8ce53 100644 --- a/tensorflow/java/src/gen/cc/op_specs.h +++ b/tensorflow/java/src/gen/cc/op_specs.h @@ -19,9 +19,9 @@ limitations under the License. #include <string> #include <vector> -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/java/src/gen/cc/java_defs.h" namespace tensorflow { @@ -38,9 +38,8 @@ class EndpointSpec { // javadoc: the endpoint class documentation // TODO(annarev): hardcode depcreated to false until deprecated is possible EndpointSpec(const string& package, const string& name, - const Javadoc& javadoc) - : package_(package), name_(name), javadoc_(javadoc), - deprecated_(false) {} + const Javadoc& javadoc) + : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {} const string& package() const { return package_; } const string& name() const { return name_; } @@ -63,10 +62,13 @@ class ArgumentSpec { // type: the tensor type of this argument // description: a description of this argument, in javadoc // iterable: true if this argument is a list - ArgumentSpec(const string& op_def_name, const Variable& var, - const Type& type, const string& description, bool iterable) - : op_def_name_(op_def_name), var_(var), type_(type), - description_(description), iterable_(iterable) {} + ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type, + const string& description, bool iterable) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -94,11 +96,16 @@ class AttributeSpec { // iterable: true if this attribute is a list // has_default_value: true if this attribute has a default value if not set AttributeSpec(const string& op_def_name, const Variable& var, - const Type& type, const Type& jni_type, const string& description, - bool iterable, bool has_default_value) - : op_def_name_(op_def_name), var_(var), type_(type), - description_(description), iterable_(iterable), - jni_type_(jni_type), has_default_value_(has_default_value) {} + const Type& type, const Type& jni_type, + const string& description, bool iterable, + bool has_default_value) + : op_def_name_(op_def_name), + var_(var), + type_(type), + description_(description), + iterable_(iterable), + jni_type_(jni_type), + has_default_value_(has_default_value) {} const string& op_def_name() const { return op_def_name_; } const Variable& var() const { return var_; } @@ -147,9 +154,10 @@ class OpSpec { // hidden: true if this op should not be visible through the Graph Ops API // deprecation_explanation: message to show if all endpoints are deprecated explicit OpSpec(const string& graph_op_name, bool hidden, - const string& deprecation_explanation) - : graph_op_name_(graph_op_name), hidden_(hidden), - deprecation_explanation_(deprecation_explanation) {} + const string& deprecation_explanation) + : graph_op_name_(graph_op_name), + hidden_(hidden), + deprecation_explanation_(deprecation_explanation) {} const string graph_op_name_; const bool hidden_; diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 3524160d87..796d6a62dc 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -15,6 +15,18 @@ limitations under the License. package org.tensorflow.processor; +import com.google.common.base.CaseFormat; +import com.google.common.base.Strings; +import com.google.common.collect.HashMultimap; +import com.google.common.collect.Multimap; +import com.squareup.javapoet.ClassName; +import com.squareup.javapoet.FieldSpec; +import com.squareup.javapoet.JavaFile; +import com.squareup.javapoet.MethodSpec; +import com.squareup.javapoet.ParameterSpec; +import com.squareup.javapoet.TypeName; +import com.squareup.javapoet.TypeSpec; +import com.squareup.javapoet.TypeVariableName; import java.io.IOException; import java.util.Collection; import java.util.Collections; @@ -23,7 +35,6 @@ import java.util.Map; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; - import javax.annotation.processing.AbstractProcessor; import javax.annotation.processing.Filer; import javax.annotation.processing.Messager; @@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter; import javax.lang.model.util.Elements; import javax.tools.Diagnostic.Kind; -import com.google.common.base.CaseFormat; -import com.google.common.base.Strings; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.Multimap; -import com.squareup.javapoet.ClassName; -import com.squareup.javapoet.FieldSpec; -import com.squareup.javapoet.JavaFile; -import com.squareup.javapoet.MethodSpec; -import com.squareup.javapoet.ParameterSpec; -import com.squareup.javapoet.TypeName; -import com.squareup.javapoet.TypeSpec; -import com.squareup.javapoet.TypeVariableName; - /** * A compile-time Processor that aggregates classes annotated with {@link * org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please @@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor { // generated our code, flag the location of each such class. if (hasRun) { for (Element e : annotated) { - error(e, "The Operator processor has already processed @Operator annotated sources\n" + - "and written out an Ops API. It cannot process additional @Operator sources.\n" + - "One reason this can happen is if other annotation processors generate\n" + - "new @Operator source files."); + error( + e, + "The Operator processor has already processed @Operator annotated sources\n" + + "and written out an Ops API. It cannot process additional @Operator sources.\n" + + "One reason this can happen is if other annotation processors generate\n" + + "new @Operator source files."); } return true; } @@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor { return Collections.singleton("org.tensorflow.op.annotation.Operator"); } - private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); + private static final Pattern JAVADOC_TAG_PATTERN = + Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*"); private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops"); - private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator"); + private static final TypeName T_OPERATOR = + ClassName.get("org.tensorflow.op.annotation", "Operator"); private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope"); private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph"); private static final TypeName T_STRING = ClassName.get(String.class); @@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor { private void write(TypeSpec spec) { try { - JavaFile.builder("org.tensorflow.op", spec) - .skipJavaLangImports(true) - .build() - .writeTo(filer); + JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer); } catch (IOException e) { throw new AssertionError(e); } } private void writeApi(Multimap<String, MethodSpec> groupedMethods) { - Map<String, ClassName> groups = new HashMap<String, ClassName>(); - + Map<String, ClassName> groups = new HashMap<>(); + // Generate a API class for each group collected other than the default one (= empty string) - for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) { + for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) { if (!entry.getKey().isEmpty()) { TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue()); write(groupClass); @@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor { } private boolean collectOpsMethods( - RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) { + RoundEnvironment roundEnv, + Multimap<String, MethodSpec> groupedMethods, + TypeElement annotation) { boolean result = true; for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) { // @Operator can only apply to types, so e must be a TypeElement. if (!(e instanceof TypeElement)) { - error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString()); + error( + e, + "@Operator can only be applied to classes, but this is a %s", + e.getKind().toString()); result = false; continue; } @@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor { } return result; } - - private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { + + private void collectOpMethods( + Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) { AnnotationMirror am = getAnnotationMirror(opClass, annotation); String groupName = getAnnotationElementValueAsString("group", am); String methodName = getAnnotationElementValueAsString("name", am); ClassName opClassName = ClassName.get(opClass); if (Strings.isNullOrEmpty(methodName)) { - methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); + methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName()); } - // Build a method for each @Operator found in the class path. There should be one method per operation factory called + // Build a method for each @Operator found in the class path. There should be one method per + // operation factory called // "create", which takes in parameter a scope and, optionally, a list of arguments for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) { - if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) { + if (opMethod.getModifiers().contains(Modifier.STATIC) + && opMethod.getSimpleName().contentEquals("create")) { MethodSpec method = buildOpMethod(methodName, opClassName, opMethod); groupedMethods.put(groupName, method); } } } - private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) { + private MethodSpec buildOpMethod( + String methodName, ClassName opClassName, ExecutableElement factoryMethod) { MethodSpec.Builder builder = MethodSpec.methodBuilder(methodName) - .addModifiers(Modifier.PUBLIC) - .returns(TypeName.get(factoryMethod.getReturnType())) - .varargs(factoryMethod.isVarArgs()) - .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); + .addModifiers(Modifier.PUBLIC) + .returns(TypeName.get(factoryMethod.getReturnType())) + .varargs(factoryMethod.isVarArgs()) + .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod)); - for (TypeParameterElement tp: factoryMethod.getTypeParameters()) { + for (TypeParameterElement tp : factoryMethod.getTypeParameters()) { TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType()); builder.addTypeVariable(tvn); } - for (TypeMirror thrownType: factoryMethod.getThrownTypes()) { + for (TypeMirror thrownType : factoryMethod.getThrownTypes()) { builder.addException(TypeName.get(thrownType)); } StringBuilder call = new StringBuilder("return $T.create(scope"); @@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor { call.append(")"); builder.addStatement(call.toString(), opClassName); return builder.build(); - } - + } + private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) { StringBuilder javadoc = new StringBuilder(); - javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n"); + javadoc + .append("Adds an {@link ") + .append(opClassName.simpleName()) + .append("} operation to the graph\n\n"); - // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the + // Add all javadoc tags found in the operator factory method but the first one, which should be + // in all cases the // 'scope' parameter that is implicitly passed by this API Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod)); boolean firstParam = true; @@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor { } else { javadoc.append(tag).append('\n'); } - } + } javadoc.append("@see {@link ").append(opClassName).append("}\n"); return javadoc.toString(); } - + private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope"); - + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope"); + TypeSpec.Builder builder = TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + - "@see {@link $T}\n", group, T_GRAPH, T_OPS) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for adding {@code $L} operations to a {@link $T Graph}\n\n" + + "@see {@link $T}\n", + group, + T_GRAPH, + T_OPS) + .addMethods(methods) + .addMethod(ctorBuilder.build()); builder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); return builder.build(); } - private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { + private static TypeSpec buildTopClass( + Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) { MethodSpec.Builder ctorBuilder = MethodSpec.constructorBuilder() - .addModifiers(Modifier.PRIVATE) - .addParameter(T_SCOPE, "scope") - .addStatement("this.scope = scope", T_SCOPE); + .addModifiers(Modifier.PRIVATE) + .addParameter(T_SCOPE, "scope") + .addStatement("this.scope = scope", T_SCOPE); - for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue()); } TypeSpec.Builder opsBuilder = TypeSpec.classBuilder("Ops") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" + - "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" + - "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + - "try (Graph g = new Graph()) {\n" + - " Ops ops = new Ops(g);\n" + - " // Operations are typed classes with convenience\n" + - " // builders in Ops.\n" + - " Constant three = ops.constant(3);\n" + - " // Single-result operations implement the Operand\n" + - " // interface, so this works too.\n" + - " Operand four = ops.constant(4);\n" + - " // Most builders are found within a group, and accept\n" + - " // Operand types as operands\n" + - " Operand nine = ops.math().add(four, ops.constant(5));\n" + - " // Multi-result operations however offer methods to\n" + - " // select a particular result for use.\n" + - " Operand result = \n" + - " ops.math().add(ops.array().unique(s, a).y(), b);\n" + - " // Optional attributes\n" + - " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + - " // Naming operators\n" + - " ops.withName(“foo”).constant(5); // name “foo”\n" + - " // Names can exist in a hierarchy\n" + - " Ops sub = ops.withSubScope(“sub”);\n" + - " sub.withName(“bar”).constant(4); // “sub/bar”\n" + - "}\n" + - "}</pre>\n", T_GRAPH, T_OPERATOR) - .addMethods(methods) - .addMethod(ctorBuilder.build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .addJavadoc( + "An API for building a {@link $T} with operation wrappers\n<p>\n" + + "Any operation wrapper found in the classpath properly annotated as an" + + "{@link $T @Operator} is exposed\n" + + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" + + "try (Graph g = new Graph()) {\n" + + " Ops ops = new Ops(g);\n" + + " // Operations are typed classes with convenience\n" + + " // builders in Ops.\n" + + " Constant three = ops.constant(3);\n" + + " // Single-result operations implement the Operand\n" + + " // interface, so this works too.\n" + + " Operand four = ops.constant(4);\n" + + " // Most builders are found within a group, and accept\n" + + " // Operand types as operands\n" + + " Operand nine = ops.math().add(four, ops.constant(5));\n" + + " // Multi-result operations however offer methods to\n" + + " // select a particular result for use.\n" + + " Operand result = \n" + + " ops.math().add(ops.array().unique(s, a).y(), b);\n" + + " // Optional attributes\n" + + " ops.math().matMul(a, b, MatMul.transposeA(true));\n" + + " // Naming operators\n" + + " ops.withName(“foo”).constant(5); // name “foo”\n" + + " // Names can exist in a hierarchy\n" + + " Ops sub = ops.withSubScope(“sub”);\n" + + " sub.withName(“bar”).constant(4); // “sub/bar”\n" + + "}\n" + + "}</pre>\n", + T_GRAPH, + T_OPERATOR) + .addMethods(methods) + .addMethod(ctorBuilder.build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withSubScope") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "childScopeName") - .returns(T_OPS) - .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) - .addJavadoc( - "Returns an API that adds operations to the graph with the provided name prefix.\n\n" + - "@see {@link $T#withSubScope(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "childScopeName") + .returns(T_OPS) + .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS) + .addJavadoc( + "Returns an API that adds operations to the graph with the provided name prefix.\n" + + "\n@see {@link $T#withSubScope(String)}\n", + T_SCOPE) + .build()); opsBuilder.addMethod( MethodSpec.methodBuilder("withName") - .addModifiers(Modifier.PUBLIC) - .addParameter(T_STRING, "opName") - .returns(T_OPS) - .addStatement("return new Ops(scope.withName(opName))") - .addJavadoc( - "Returns an API that uses the provided name for an op.\n\n" + - "@see {@link $T#withName(String)}\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC) + .addParameter(T_STRING, "opName") + .returns(T_OPS) + .addStatement("return new Ops(scope.withName(opName))") + .addJavadoc( + "Returns an API that uses the provided name for an op.\n\n" + + "@see {@link $T#withName(String)}\n", + T_SCOPE) + .build()); opsBuilder.addField( - FieldSpec.builder(T_SCOPE, "scope") - .addModifiers(Modifier.PRIVATE, Modifier.FINAL) - .build()); + FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build()); opsBuilder.addMethod( MethodSpec.methodBuilder("scope") - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(T_SCOPE) - .addStatement("return scope") - .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(T_SCOPE) + .addStatement("return scope") + .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE) + .build()); - for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) { + for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) { opsBuilder.addField( FieldSpec.builder(entry.getValue(), entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .build()); - + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .build()); + opsBuilder.addMethod( MethodSpec.methodBuilder(entry.getKey()) - .addModifiers(Modifier.PUBLIC, Modifier.FINAL) - .returns(entry.getValue()) - .addStatement("return $L", entry.getKey()) - .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.FINAL) + .returns(entry.getValue()) + .addStatement("return $L", entry.getKey()) + .addJavadoc( + "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey()) + .build()); } opsBuilder.addMethod( MethodSpec.methodBuilder("create") - .addModifiers(Modifier.PUBLIC, Modifier.STATIC) - .addParameter(T_GRAPH, "graph") - .returns(T_OPS) - .addStatement("return new Ops(new $T(graph))", T_SCOPE) - .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") - .build()); + .addModifiers(Modifier.PUBLIC, Modifier.STATIC) + .addParameter(T_GRAPH, "graph") + .returns(T_OPS) + .addStatement("return new Ops(new $T(graph))", T_SCOPE) + .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n") + .build()); return opsBuilder.build(); } @@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor { return am; } } - throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element " - + element.getSimpleName()); + throw new IllegalArgumentException( + "Annotation " + + annotation.getSimpleName() + + " not present on element " + + element.getSimpleName()); } - + private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) { - for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) { + for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : + am.getElementValues().entrySet()) { if (entry.getKey().getSimpleName().contentEquals(elementName)) { return entry.getValue().getValue().toString(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index d4fd3db5f7..7d19696749 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable { } } + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * <p> + * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function + * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. + * <p> + * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all + * shapes in {@code y}. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) { + Output<?>[] dy = new Output<?>[x.length]; + final long[] yHandles = new long[y.length]; + final int[] yIndices = new int[y.length]; + final long[] xHandles = new long[x.length]; + final int[] xIndices = new int[x.length]; + long[] dxHandles = null; + int[] dxIndices = null; + + try (Reference ref = ref()) { + for (int i = 0; i < y.length; ++i) { + yHandles[i] = y[i].op().getUnsafeNativeHandle(); + yIndices[i] = y[i].index(); + } + for (int i = 0; i < x.length; ++i) { + xHandles[i] = x[i].op().getUnsafeNativeHandle(); + xIndices[i] = x[i].index(); + } + if (dx != null && dx.length > 0) { + dxHandles = new long[dx.length]; + dxIndices = new int[dx.length]; + + for (int i = 0; i < dx.length; ++i) { + dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); + dxIndices[i] = dx[i].index(); + } + } + // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles + // of the gradient operations while the second holds the index of their output + // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] + long[] dyHandlesAndIndices = + addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + int ndy = dyHandlesAndIndices.length >> 1; + if (ndy != dy.length) { + throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length + + " were expected"); + } + for (int i = 0, j = ndy; i < ndy; ++i, ++j) { + Operation op = new Operation(this, dyHandlesAndIndices[i]); + dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); + } + } + return dy; + } + + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code dy/dx_1, dy/dx_2...} + * <p> + * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is + * a single output and {@code dx} is null. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output<?>[] addGradients(Output<?> y, Output<?>[] x) { + return addGradients(new Output<?>[]{y}, x, null); + } + private final Object nativeHandleLock = new Object(); private long nativeHandle; private int refcount = 0; @@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); + private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, + long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); + static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java new file mode 100644 index 0000000000..13bc463e7d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow; + +/** + * Interface implemented by operands of a TensorFlow operation. + * + * <p>Example usage: + * + * <pre>{@code + * // The "decodeJpeg" operation can be used as input to the "cast" operation + * Input decodeJpeg = ops.image().decodeJpeg(...); + * ops.math().cast(decodeJpeg, DataType.FLOAT); + * + * // The output "y" of the "unique" operation can be used as input to the "cast" operation + * Output y = ops.array().unique(...).y(); + * ops.math().cast(y, DataType.FLOAT); + * + * // The "split" operation can be used as input list to the "concat" operation + * Iterable<? extends Input> split = ops.array().split(...); + * ops.array().concat(0, split); + * }</pre> + */ +public interface Input<T> { + + /** + * Returns the symbolic handle of a tensor. + * + * <p>Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is + * used to obtain a symbolic handle that represents the computation of the input. + * + * @see OperationBuilder#addInput(Output) + */ + Output<T> asOutput(); +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java new file mode 100644 index 0000000000..f4671c8af9 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * <p> + * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss + * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}. + * <p> + * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all + * shapes in {@code y}. + * <p> + * The partial derivatives are returned in output {@code dy}, with the size of {@code x}. + * <p> + * Example of usage: + * <pre>{@code + * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b)); + * + * Constant<Float> alpha = ops.constant(1.0f, Float.class); + * ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0)); + * ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1)); + * }</pre> + */ +@Operator +public class Gradients implements Op, Iterable<Operand<?>> { + + /** + * Optional attributes for {@link Gradients} + */ + public static class Options { + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return this option builder + */ + public Options dx(Iterable<Operand<?>> dx) { + this.dx = dx; + return this; + } + + private Iterable<Operand<?>> dx; + + private Options() { + } + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * @param scope current graph scope + * @param y outputs of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) { + Output<?>[] dx = null; + if (options != null) { + for (Options opts : options) { + if (opts.dx != null) { + dx = Operands.asOutputs(opts.dx); + } + } + } + Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(gradOutputs)); + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is + * a single output. + * + * @param scope current graph scope + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) { + return create(scope, (Iterable) Arrays.asList(y), x, options); + } + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return builder to add more options to this operation + */ + public Options dx(Iterable<Operand<?>> dx) { + return new Options().dx(dx); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator<Operand<?>> iterator() { + return (Iterator) dy.iterator(); + } + + /** + * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x} + */ + public List<Output<?>> dy() { + return dy; + } + + /** + * Returns a symbolic handle to one of the gradient operation output + * <p> + * Warning: Does not check that the type of the tensor matches T. It is recommended to call + * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code + * gradients.<Integer>dy(0)} + * + * @param <T> The expected element type of the tensors produced by this output. + * @param index The index of the output among the gradients added by this operation + */ + @SuppressWarnings("unchecked") + public <T> Output<T> dy(int index) { + return (Output<T>) dy.get(index); + } + + private List<Output<?>> dy; + + private Gradients(List<Output<?>> dy) { + this.dy = dy; + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java new file mode 100644 index 0000000000..ab34f6aa12 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a boolean. */ +public class TFBool implements TFType { + private TFBool() {} + static { + Types.typeCodes.put(TFBool.class, DataType.BOOL); + } + static { + Types.scalars.put(TFBool.class, false); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java new file mode 100644 index 0000000000..49e5d9f2f3 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 64-bit double precision floating point number. */ +public class TFDouble implements TFType { + private TFDouble() {} + static { + Types.typeCodes.put(TFDouble.class, DataType.DOUBLE); + } + static { + Types.scalars.put(TFDouble.class, 0.0); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java new file mode 100644 index 0000000000..8426ee41f0 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 32-bit single precision floating point number. */ +public class TFFloat implements TFType { + private TFFloat() {} + static { + Types.typeCodes.put(TFFloat.class, DataType.FLOAT); + } + static { + Types.scalars.put(TFFloat.class, 0f); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java new file mode 100644 index 0000000000..3947b6ad09 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 32-bit signed integer. */ +public class TFInt32 implements TFType { + private TFInt32() {} + static { + Types.typeCodes.put(TFInt32.class, DataType.INT32); + } + static { + Types.scalars.put(TFInt32.class, 0); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java new file mode 100644 index 0000000000..ccdded8693 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents a 64-bit signed integer. */ +public class TFInt64 implements TFType { + private TFInt64() {} + static { + Types.typeCodes.put(TFInt64.class, DataType.INT64); + } + static { + Types.scalars.put(TFInt64.class, 0L); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java new file mode 100644 index 0000000000..e7327e8c57 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java @@ -0,0 +1,27 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents an arbitrary sequence of bytes. */ +public class TFString implements TFType { + private TFString() {} + static { + Types.typeCodes.put(TFString.class, DataType.STRING); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java new file mode 100644 index 0000000000..562953ac9d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java @@ -0,0 +1,20 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.types; + +/** + * A marker interface for classes representing TensorFlow types. + */ +public interface TFType {} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java new file mode 100644 index 0000000000..d7305ca5a8 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java @@ -0,0 +1,30 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// GENERATED FILE. To update, edit tftypes.pl instead. + +package org.tensorflow.types; + +import org.tensorflow.DataType; + +/** Represents an 8-bit unsigned integer. */ +public class TFUInt8 implements TFType { + private TFUInt8() {} + static { + Types.typeCodes.put(TFUInt8.class, DataType.UINT8); + } + static { + Types.scalars.put(TFUInt8.class, (byte)0); + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java new file mode 100644 index 0000000000..976cd9fd34 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java @@ -0,0 +1,52 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.types; + +import java.util.HashMap; +import java.util.Map; +import org.tensorflow.DataType; + +/** + * Utility class for managing the representation of TensorFlow types as Java + * types. For each TensorFlow type (e.g., int32), there is a corresponding Java + * type (e.g., TFInt32) that represents it at compile time and a corresponding + * class object (e.g., TFInt32.class) that represents it at run time. There is + * also an enumeration value in DataType that can be used to represent the + * type, though that should rarely be required. + */ +public class Types { + + private Types() {} // not instantiable + + static final Map<Class<?>, DataType> typeCodes = new HashMap<>(); + + /** Returns the DataType value corresponding to a TensorFlow type class. */ + public static DataType dataType(Class<? extends TFType> c) { + DataType dtype = typeCodes.get(c); + if (dtype == null) { + throw new IllegalArgumentException("" + c + " is not a TensorFlow type."); + } + return dtype; + } + + static final Map<Class<?>, Object> scalars = new HashMap<>(); + + /** Returns the zero value of type described by {@code c}, or null if + * the type (e.g., string) is not numeric and therefore has no zero value. + */ + public static Object zeroValue(Class<? extends TFType> c) { + return scalars.get(c); + } +} diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index 0fef155275..dac6a345e9 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/java/src/main/native/graph_jni.h" #include <limits> +#include <memory> #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" namespace { @@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { TF_DeleteBuffer(buf); return ret; } + +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, + jlongArray y_handles, jintArray y_indices, + jlongArray x_handles, jintArray x_indices, + jlongArray dx_handles, jintArray dx_indices) { + + TF_Graph* g = requireHandle(env, handle); + if (g == nullptr) return nullptr; + + const jint ny = env->GetArrayLength(y_handles); + const jint nx = env->GetArrayLength(x_handles); + + std::unique_ptr<TF_Output[]> y(new TF_Output[ny]); + std::unique_ptr<TF_Output[]> x(new TF_Output[nx]); + std::unique_ptr<TF_Output[]> dx(nullptr); + std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]); + + resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny); + resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx); + if (dx_handles != nullptr) { + if (env->GetArrayLength(dx_handles) != ny) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d dx handles", ny, + env->GetArrayLength(dx_handles)); + } + dx.reset(new TF_Output[ny]); + resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny); + } + if (env->ExceptionCheck()) return nullptr; + + TF_Status* status = TF_NewStatus(); + TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); + + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + + // returned array contains both op handles and output indices, in pair + jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1); + jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr); + for (int i = 0, j = nx; i < nx; ++i, ++j) { + TF_Output dy_output = dy.get()[i]; + dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper); + dy_elems[j] = static_cast<jlong>(dy_output.index); + } + env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0); + + return dy_handles_and_indices; +} diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index dd2e038332..4f87e8d5a7 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Graph + * Method: name + * Signature: (J[J[I[J[I[J[I)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *, + jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray, + jintArray); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc index 708983fef5..8b11525785 100644 --- a/tensorflow/java/src/main/native/session_jni.cc +++ b/tensorflow/java/src/main/native/session_jni.cc @@ -17,6 +17,7 @@ limitations under the License. #include <memory> #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" #include "tensorflow/java/src/main/native/session_jni.h" @@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array, env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT); } -void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, - jintArray src_index, TF_Output* dst, jint n) { - if (env->ExceptionCheck()) return; - jint len = env->GetArrayLength(src_op); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operations", n, len, type); - return; - } - len = env->GetArrayLength(src_index); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operation output indices", n, len, - type); - return; - } - jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); - jint* indices = env->GetIntArrayElements(src_index, nullptr); - for (int i = 0; i < n; ++i) { - if (op_handles[i] == 0) { - throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, - i, n); - break; - } - dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]), - static_cast<int>(indices[i])}; - } - env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); - env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); -} - void TF_MaybeDeleteBuffer(TF_Buffer* buf) { if (buf == nullptr) return; TF_DeleteBuffer(buf); diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc new file mode 100644 index 0000000000..069ac05a1c --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/utils_jni.h" + +#include "tensorflow/java/src/main/native/exception_jni.h" + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n) { + if (env->ExceptionCheck()) return; + jint len = env->GetArrayLength(src_op); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operations", n, len, type); + return; + } + len = env->GetArrayLength(src_index); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operation output indices", n, len, + type); + return; + } + jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); + jint* indices = env->GetIntArrayElements(src_index, nullptr); + for (int i = 0; i < n; ++i) { + if (op_handles[i] == 0) { + throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, + i, n); + break; + } + dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]), + static_cast<int>(indices[i])}; + } + env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); + env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); +} + + + + diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h new file mode 100644 index 0000000000..352298e7de --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_ +#define TENSORFLOW_JAVA_UTILS_JNI_H_ + +#include <jni.h> + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */ diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c540299bdc..c2e52c22c6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -129,4 +130,106 @@ public class GraphTest { // expected exception. } } + + @Test + public void addGradientsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x1); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + Output<Float> y2 = TestUtil.addN(g, y0, x2); + + Output<?>[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); + + try (Tensor<Float> c1 = Tensors.create(3.0f); + Tensor<Float> c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientSumsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); + + try (Tensor<Float> c = Tensors.create(3.0f); + Tensor<?> output = s.runner() + .feed(x, c) + .fetch(grad[0]) + .run() + .get(0)) { + + assertEquals(114.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientsWithInitialValuesToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Output<?>[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); + + try (Tensor<Float> c = Tensors.create(3.0f); + Tensor<?> output = s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run() + .get(0)) { + + assertEquals(108.0f, output.floatValue(), 0.0f); + } + } + } + + private static Output<?>[] toArray(Output<?>... outputs) { + return outputs; + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6..7d5980bcde 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output<Integer> feed = g.operation("X").output(0); Output<Integer> fetch = g.operation("Y").output(0); try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs); + TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList<Tensor<?>> outputs = - new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> - implements AutoCloseable { - AutoCloseableList(Collection<? extends E> c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java sources for protocol buffers // and end up with something like the snippet below. However, generating diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index c973b5a3d8..4e84886416 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -16,9 +16,34 @@ limitations under the License. package org.tensorflow; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; /** Static utility functions. */ public class TestUtil { + + public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> + implements AutoCloseable { + AutoCloseableList(Collection<? extends E> c) { + super(c); + } + + @Override + public void close() { + Exception toThrow = null; + for (AutoCloseable c : this) { + try { + c.close(); + } catch (Exception e) { + toThrow = e; + } + } + if (toThrow != null) { + throw new RuntimeException(toThrow); + } + } + } + public static <T> Output<T> constant(Graph g, String name, Object value) { try (Tensor<?> t = Tensor.create(value)) { return g.opBuilder("Const", name) @@ -36,7 +61,7 @@ public class TestUtil { .<T>output(0); } - public static Output<?> addN(Graph g, Output<?>... inputs) { + public static <T> Output<T> addN(Graph g, Output<?>... inputs) { return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } @@ -58,6 +83,13 @@ public class TestUtil { .setAttr("num_split", numSplit) .build(); } + + public static <T> Output<T> square(Graph g, String name, Output<T> value) { + return g.opBuilder("Square", name) + .addInput(value) + .build() + .<T>output(0); + } public static void transpose_A_times_X(Graph g, int[][] a) { Output<Integer> aa = constant(g, "A", a); |