aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/maven/hadoop/pom.xml168
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml2
-rw-r--r--tensorflow/java/maven/run_inside_container.sh5
-rw-r--r--tensorflow/java/maven/spark-connector/pom.xml323
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/java/src/gen/cc/java_defs.h2
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.cc29
-rw-r--r--tensorflow/java/src/gen/cc/op_generator.h2
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.cc148
-rw-r--r--tensorflow/java/src/gen/cc/op_specs.h40
-rw-r--r--tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java296
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java79
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Input.java48
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java153
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFString.java27
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFType.java20
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/Types.java52
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc54
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h9
-rw-r--r--tensorflow/java/src/main/native/session_jni.cc32
-rw-r--r--tensorflow/java/src/main/native/utils_jni.cc53
-rw-r--r--tensorflow/java/src/main/native/utils_jni.h33
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java103
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SessionTest.java38
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java34
35 files changed, 1596 insertions, 344 deletions
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml
index 0642be06fa..a160377210 100644
--- a/tensorflow/java/maven/hadoop/pom.xml
+++ b/tensorflow/java/maven/hadoop/pom.xml
@@ -1,12 +1,30 @@
-<project
- xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <!-- Placeholder pom which is replaced by TensorFlow ecosystem Hadoop pom during build -->
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
- <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
+ <groupId>org.tensorflow</groupId>
<artifactId>hadoop</artifactId>
<packaging>jar</packaging>
+ <version>1.9.0-rc2</version>
+ <name>tensorflow-hadoop</name>
+ <url>https://www.tensorflow.org</url>
+ <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
+
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <maven.compiler.source>1.6</maven.compiler.source>
+ <maven.compiler.target>1.6</maven.compiler.target>
+ <hadoop.version>2.6.0</hadoop.version>
+ <protobuf.version>3.3.1</protobuf.version>
+ <junit.version>4.11</junit.version>
+ </properties>
+
+ <licenses>
+ <license>
+ <name>Apache License Version 2.0</name>
+ <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ </license>
+ </licenses>
<scm>
<url>https://github.com/tensorflow/ecosystem.git</url>
@@ -14,11 +32,133 @@
<developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
</scm>
- <url>https://github.com/tensorflow/ecosystem/</url>
- <parent>
- <groupId>org.tensorflow</groupId>
- <artifactId>parentpom</artifactId>
- <version>1.9.0-rc0</version>
- <relativePath>../</relativePath>
- </parent>
-</project> \ No newline at end of file
+ <build>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ </build>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>proto</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-common</artifactId>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-core</artifactId>
+ <version>${hadoop.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>${junit.version}</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-mapreduce-client-jobclient</artifactId>
+ <version>${hadoop.version}</version>
+ <type>test-jar</type>
+ <optional>true</optional>
+ <scope>test</scope>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ </dependencies>
+
+ <!-- Two profiles are used:
+ ossrh - deploys to ossrh/maven central
+ bintray - deploys to bintray/jcenter. -->
+ <profiles>
+ <profile>
+ <id>ossrh</id>
+ <distributionManagement>
+ <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html -->
+ <snapshotRepository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/content/repositories/snapshots</url>
+ </snapshotRepository>
+ <repository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>bintray</id>
+ <distributionManagement>
+ <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ -->
+ <repository>
+ <id>bintray</id>
+ <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <developers>
+ <developer>
+ <name>TensorFlowers</name>
+ <organization>TensorFlow</organization>
+ <organizationUrl>http://www.tensorflow.org</organizationUrl>
+ </developer>
+ </developers>
+</project>
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index a7fa9ea5cc..489fb7bdb0 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 83aae29f1e..5bef85f75e 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 50bd8ee5f9..8d93c78220 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index b4746794ea..c5861102c8 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0-rc2</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 618a2a124c..754caad900 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 2e771064e4..2240d6b7b9 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -203,7 +203,10 @@ download_tf_ecosystem() {
cd "${ECOSYSTEM_DIR}"
git clone "${TF_ECOSYSTEM_URL}"
cd ecosystem
- git checkout r${TF_VERSION}
+ # TF_VERSION is a semver string (<major>.<minor>.<patch>[-suffix])
+ # but the branch is just (r<major>.<minor>).
+ RELEASE_BRANCH=$(echo "${TF_VERSION}" | sed -e 's/\([0-9]\+\.[0-9]\+\)\.[0-9]\+.*/\1/')
+ git checkout r${RELEASE_BRANCH}
# Copy the TensorFlow Hadoop source
cp -r "${ECOSYSTEM_DIR}/ecosystem/hadoop/src" "${HADOOP_DIR}"
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml
index 19c752d08b..99237fdb98 100644
--- a/tensorflow/java/maven/spark-connector/pom.xml
+++ b/tensorflow/java/maven/spark-connector/pom.xml
@@ -1,12 +1,23 @@
-<project
- xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <!-- Placeholder pom which is replaced by TensorFlow ecosystem Spark pom during build -->
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
- <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
- <artifactId>spark-connector</artifactId>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>spark-connector_2.11</artifactId>
<packaging>jar</packaging>
+ <version>1.9.0-rc2</version>
+ <name>spark-tensorflow-connector</name>
+ <url>https://www.tensorflow.org</url>
+ <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
+
+ <licenses>
+ <license>
+ <name>The Apache Software License, Version 2.0</name>
+ <url>http://www.apache.org/licenses/LICENSE-2.0.txt</url>
+ <distribution>repo</distribution>
+ </license>
+ </licenses>
<scm>
<url>https://github.com/tensorflow/ecosystem.git</url>
@@ -14,11 +25,293 @@
<developerConnection>scm:git:https://github.com/tensorflow/ecosystem.git</developerConnection>
</scm>
- <url>https://github.com/tensorflow/ecosystem/</url>
- <parent>
- <groupId>org.tensorflow</groupId>
- <artifactId>parentpom</artifactId>
- <version>1.9.0-rc0</version>
- <relativePath>../</relativePath>
- </parent>
-</project> \ No newline at end of file
+ <properties>
+ <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+ <scala.maven.version>3.2.2</scala.maven.version>
+ <scala.binary.version>2.11</scala.binary.version>
+ <scalatest.maven.version>1.0</scalatest.maven.version>
+ <scala.test.version>2.2.6</scala.test.version>
+ <maven.compiler.version>3.0</maven.compiler.version>
+ <java.version>1.8</java.version>
+ <spark.version>2.3.0</spark.version>
+ <yarn.api.version>2.7.3</yarn.api.version>
+ <junit.version>4.11</junit.version>
+ </properties>
+
+ <build>
+ <pluginManagement>
+ <plugins>
+ <plugin>
+ <inherited>true</inherited>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ <version>${scala.maven.version}</version>
+ <executions>
+ <execution>
+ <id>compile</id>
+ <goals>
+ <goal>add-source</goal>
+ <goal>compile</goal>
+ </goals>
+ <configuration>
+ <jvmArgs>
+ <jvmArg>-Xms256m</jvmArg>
+ <jvmArg>-Xmx512m</jvmArg>
+ </jvmArgs>
+ <args>
+ <arg>-g:vars</arg>
+ <arg>-deprecation</arg>
+ <arg>-feature</arg>
+ <arg>-unchecked</arg>
+ <arg>-Xfatal-warnings</arg>
+ <arg>-language:implicitConversions</arg>
+ <arg>-language:existentials</arg>
+ </args>
+ </configuration>
+ </execution>
+ <execution>
+ <id>test</id>
+ <goals>
+ <goal>add-source</goal>
+ <goal>testCompile</goal>
+ </goals>
+ </execution>
+ </executions>
+ <configuration>
+ <recompileMode>incremental</recompileMode>
+ <useZincServer>true</useZincServer>
+ <scalaVersion>${scala.binary.version}</scalaVersion>
+ <checkMultipleScalaVersions>false</checkMultipleScalaVersions>
+ </configuration>
+ </plugin>
+ <plugin>
+ <inherited>true</inherited>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ <version>${scalatest.maven.version}</version>
+ <executions>
+ <execution>
+ <id>scalaTest</id>
+ <phase>test</phase>
+ <goals>
+ <goal>test</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- Shade protobuf dependency. -->
+ <plugin>
+ <artifactId>maven-shade-plugin</artifactId>
+ <version>3.1.0</version>
+ <executions>
+ <execution>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <minimizeJar>true</minimizeJar>
+ <artifactSet>
+ <includes>
+ <include>com.google.protobuf:protobuf-java</include>
+ <include>org.tensorflow:hadoop</include>
+ <include>org.tensorflow:proto</include>
+ </includes>
+ </artifactSet>
+ <filters>
+ <filter>
+ <!-- Remove the source to keep the result smaller. -->
+ <artifact>com.google.protobuf:protobuf-java</artifact>
+ <excludes>
+ <exclude>**/*.java</exclude>
+ </excludes>
+ </filter>
+ </filters>
+ <relocations>
+ <relocation>
+ <pattern>com.google.protobuf</pattern>
+ <shadedPattern>
+ org.tensorflow.spark.shaded.com.google.protobuf
+ </shadedPattern>
+ </relocation>
+ </relocations>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <!-- GPG signed components: http://central.sonatype.org/pages/apache-maven.html#gpg-signed-components -->
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ <version>1.5</version>
+ <executions>
+ <execution>
+ <id>sign-artifacts</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>sign</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </pluginManagement>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest-maven-plugin</artifactId>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-compiler-plugin</artifactId>
+ <version>${maven.compiler.version}</version>
+ <configuration>
+ <source>${java.version}</source>
+ <target>${java.version}</target>
+ </configuration>
+ </plugin>
+ </plugins>
+ </build>
+
+ <profiles>
+ <profile>
+ <id>test</id>
+ <activation>
+ <activeByDefault>true</activeByDefault>
+ <property>
+ <name>!NEVERSETME</name>
+ </property>
+ </activation>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>net.alchim31.maven</groupId>
+ <artifactId>scala-maven-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ <dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <version>${scala.test.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ </dependencyManagement>
+ <dependencies>
+ <dependency>
+ <groupId>org.scalatest</groupId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ </profile>
+
+ <!-- Two profiles are used:
+ ossrh - deploys to ossrh/maven central
+ bintray - deploys to bintray/jcenter. -->
+ <profile>
+ <id>ossrh</id>
+ <distributionManagement>
+ <!-- Sonatype requirements from http://central.sonatype.org/pages/apache-maven.html -->
+ <snapshotRepository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/content/repositories/snapshots</url>
+ </snapshotRepository>
+ <repository>
+ <id>ossrh</id>
+ <url>https://oss.sonatype.org/service/local/staging/deploy/maven2/</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ <profile>
+ <id>bintray</id>
+ <distributionManagement>
+ <!-- https://blog.bintray.com/2015/09/17/publishing-your-maven-project-to-bintray/ -->
+ <repository>
+ <id>bintray</id>
+ <url>https://api.bintray.com/maven/google/tensorflow/tensorflow/;publish=0</url>
+ </repository>
+ </distributionManagement>
+ <build>
+ <plugins>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-gpg-plugin</artifactId>
+ </plugin>
+ </plugins>
+ </build>
+ </profile>
+ </profiles>
+
+ <developers>
+ <developer>
+ <name>TensorFlowers</name>
+ <organization>TensorFlow</organization>
+ <organizationUrl>http://www.tensorflow.org</organizationUrl>
+ </developer>
+ </developers>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.tensorflow</groupId>
+ <artifactId>hadoop</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-core_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-sql_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.hadoop</groupId>
+ <artifactId>hadoop-yarn-api</artifactId>
+ <version>${yarn.api.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-mllib_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <version>${junit.version}</version>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+</project>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 157c4b8e82..2a8e640dbc 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0-rc1</version>
+ <version>1.9.0-rc2</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/java/src/gen/cc/java_defs.h b/tensorflow/java/src/gen/cc/java_defs.h
index f5f54bf4d3..d9d6f8adc8 100644
--- a/tensorflow/java/src/gen/cc/java_defs.h
+++ b/tensorflow/java/src/gen/cc/java_defs.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
#define TENSORFLOW_JAVA_SRC_GEN_CC_JAVA_DEFS_H_
-#include <string>
#include <list>
#include <map>
+#include <string>
#include <utility>
namespace tensorflow {
diff --git a/tensorflow/java/src/gen/cc/op_generator.cc b/tensorflow/java/src/gen/cc/op_generator.cc
index 2df69ee299..d5bd99bdd9 100644
--- a/tensorflow/java/src/gen/cc/op_generator.cc
+++ b/tensorflow/java/src/gen/cc/op_generator.cc
@@ -36,20 +36,21 @@ namespace java {
namespace {
constexpr const char kLicense[] =
- "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
- "\n"
- "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
- "you may not use this file except in compliance with the License.\n"
- "You may obtain a copy of the License at\n"
- "\n"
- " http://www.apache.org/licenses/LICENSE-2.0\n"
- "\n"
- "Unless required by applicable law or agreed to in writing, software\n"
- "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
- "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
- "See the License for the specific language governing permissions and\n"
- "limitations under the License.\n"
- "=======================================================================*/\n";
+ "/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.\n"
+ "\n"
+ "Licensed under the Apache License, Version 2.0 (the \"License\");\n"
+ "you may not use this file except in compliance with the License.\n"
+ "You may obtain a copy of the License at\n"
+ "\n"
+ " http://www.apache.org/licenses/LICENSE-2.0\n"
+ "\n"
+ "Unless required by applicable law or agreed to in writing, software\n"
+ "distributed under the License is distributed on an \"AS IS\" BASIS,\n"
+ "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n"
+ "See the License for the specific language governing permissions and\n"
+ "limitations under the License.\n"
+ "=======================================================================*/"
+ "\n";
// There is three different modes to render an op class, depending on the
// number and type of outputs it has:
diff --git a/tensorflow/java/src/gen/cc/op_generator.h b/tensorflow/java/src/gen/cc/op_generator.h
index 759d800ecf..05decd6b54 100644
--- a/tensorflow/java/src/gen/cc/op_generator.h
+++ b/tensorflow/java/src/gen/cc/op_generator.h
@@ -19,10 +19,10 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/java/src/gen/cc/op_specs.h"
namespace tensorflow {
diff --git a/tensorflow/java/src/gen/cc/op_specs.cc b/tensorflow/java/src/gen/cc/op_specs.cc
index 63e99fbb04..941ab2699c 100644
--- a/tensorflow/java/src/gen/cc/op_specs.cc
+++ b/tensorflow/java/src/gen/cc/op_specs.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include <map>
-#include <vector>
#include <string>
#include <utility>
+#include <vector>
#include "re2/re2.h"
#include "tensorflow/core/framework/op.h"
@@ -50,7 +50,7 @@ class TypeResolver {
// For example, if the argument's datatype is DT_STRING, this method will
// return "java.lang.String", so the argument can become "Operand<String>"
// in the Ops API
- Type TypeOf(const OpDef_ArgDef& arg_def, bool *iterable_out);
+ Type TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out);
// Returns types of an input attribute
//
@@ -62,7 +62,7 @@ class TypeResolver {
// <java.lang.Float, float>, so the attribute can be used as a "Float" object
// in the Ops API and casted to a "float" when passing through the JNI layer.
std::pair<Type, Type> TypesOf(const OpDef_AttrDef& attr_def,
- bool *iterable_out);
+ bool* iterable_out);
// Returns true if the type of this attribute has already been resolved
bool IsAttributeVisited(const string& attr_name) {
@@ -89,8 +89,7 @@ class TypeResolver {
}
};
-Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
- bool* iterable_out) {
+Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def, bool* iterable_out) {
*iterable_out = false;
if (!arg_def.number_attr().empty()) {
// when number_attr is set, argument has to be a list of tensors
@@ -154,13 +153,13 @@ Type TypeResolver::TypeOf(const OpDef_ArgDef& arg_def,
} else {
LOG(FATAL) << "Cannot resolve data type of argument \"" << arg_def.name()
- << "\" in operation \"" << op_def_.name() << "\"";
+ << "\" in operation \"" << op_def_.name() << "\"";
}
return type;
}
std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
- bool* iterable_out) {
+ bool* iterable_out) {
std::pair<Type, Type> types = MakeTypePair(Type::Wildcard());
*iterable_out = false;
StringPiece attr_type = attr_def.type();
@@ -185,7 +184,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
} else if (attr_type == "tensor") {
types = MakeTypePair(Type::Class("Tensor", "org.tensorflow")
- .add_parameter(Type::Wildcard()));
+ .add_parameter(Type::Wildcard()));
} else if (attr_type == "type") {
Type type = *iterable_out ? Type::Wildcard() : NextGeneric();
@@ -196,7 +195,7 @@ std::pair<Type, Type> TypeResolver::TypesOf(const OpDef_AttrDef& attr_def,
} else {
LOG(FATAL) << "Cannot resolve data type for attribute \"" << attr_type
- << "\" in operation \"" << op_def_.name() << "\"";
+ << "\" in operation \"" << op_def_.name() << "\"";
}
visited_attrs_.insert(std::make_pair(attr_def.name(), types.first));
return types;
@@ -219,47 +218,43 @@ string SnakeToCamelCase(const string& str, bool upper = false) {
return result;
}
-bool FindAndCut(re2::StringPiece* input, const RE2& expr,
- re2::StringPiece* before_match, re2::StringPiece* ret_match = nullptr) {
- re2::StringPiece match;
- if (!expr.Match(*input, 0, input->size(), RE2::UNANCHORED, &match, 1)) {
- return false;
- }
- before_match->set(input->data(), match.begin() - input->begin());
- input->remove_prefix(match.end() - before_match->begin());
- if (ret_match != nullptr) {
- *ret_match = match;
- }
+bool FindAndCut(string* input, const RE2& expr, string* before_match,
+ string* ret_match = nullptr) {
+ string match;
+ if (!RE2::PartialMatch(*input, expr, &match)) return false;
+ *before_match = input->substr(0, input->find(match));
+ *input = input->substr(before_match->size() + match.size());
+ if (ret_match != nullptr) *ret_match = match;
return true;
}
-string ParseDocumentation(re2::StringPiece input) {
+string ParseDocumentation(const string& inp) {
std::stringstream javadoc_text;
// TODO(karllessard) This is a very minimalist utility method for converting
// markdown syntax, as found in ops descriptions, to Javadoc/html tags. Check
// for alternatives to increase the level of support for markups.
std::vector<string> markups_subexpr;
- markups_subexpr.push_back("\n+\\*\\s+"); // lists
- markups_subexpr.push_back("\n{2,}"); // paragraphs
+ markups_subexpr.push_back("\n+\\*\\s+"); // lists
+ markups_subexpr.push_back("\n{2,}"); // paragraphs
markups_subexpr.push_back("`{3,}\\s*[^\\s\n]*\\s*\n"); // code blocks
- markups_subexpr.push_back("`+"); // inlined code and code blocks
+ markups_subexpr.push_back("`+"); // inlined code and code blocks
markups_subexpr.push_back("\\*{1,2}\\b"); // text emphasis
- markups_subexpr.push_back("\\["); // hyperlinks
- const RE2 markup_expr(str_util::Join(markups_subexpr, "|"));
+ markups_subexpr.push_back("\\["); // hyperlinks
+ const RE2 markup_expr("(" + str_util::Join(markups_subexpr, "|") + ")");
bool in_list = false;
+ string input = inp;
while (true) {
- re2::StringPiece text;
- re2::StringPiece markup;
+ string text, markup;
if (!FindAndCut(&input, markup_expr, &text, &markup)) {
javadoc_text << input;
break; // end of loop
}
javadoc_text << text;
- if (markup.starts_with("\n")) {
+ if (str_util::StartsWith(markup, "\n")) {
javadoc_text << "\n";
- if (markup.contains("*")) {
+ if (str_util::StrContains(markup, "*")) {
// new list item
javadoc_text << (in_list ? "</li>\n" : "<ul>\n") << "<li>\n";
in_list = true;
@@ -267,18 +262,18 @@ string ParseDocumentation(re2::StringPiece input) {
// end of list
javadoc_text << "</li>\n</ul>\n";
in_list = false;
- } else if (!input.starts_with("```")) {
+ } else if (!str_util::StartsWith(input, "```")) {
// new paragraph (not required if a <pre> block follows)
javadoc_text << "<p>\n";
}
- } else if (markup.starts_with("```")) {
+ } else if (str_util::StartsWith(markup, "```")) {
// code blocks
- if (FindAndCut(&input, "```\\s*\n*", &text)) {
+ if (FindAndCut(&input, "(```\\s*\n*)", &text)) {
javadoc_text << "<pre>{@code\n" << text << "}</pre>\n";
} else {
javadoc_text << markup;
}
- } else if (markup.starts_with("`")) {
+ } else if (str_util::StartsWith("(" + markup + ")", "`")) {
// inlined code
if (FindAndCut(&input, markup, &text)) {
javadoc_text << "{@code " << text << "}";
@@ -287,26 +282,28 @@ string ParseDocumentation(re2::StringPiece input) {
}
} else if (markup == "**") {
// text emphasis (strong)
- if (FindAndCut(&input, "\\b\\*{2}", &text)) {
+ if (FindAndCut(&input, "(\\b\\*{2})", &text)) {
javadoc_text << "<b>" << ParseDocumentation(text) << "</b>";
} else {
javadoc_text << markup;
}
} else if (markup == "*") {
// text emphasis (normal)
- if (FindAndCut(&input, "\\b\\*{1}", &text)) {
+ if (FindAndCut(&input, "(\\b\\*{1})", &text)) {
javadoc_text << "<i>" << ParseDocumentation(text) << "</i>";
} else {
javadoc_text << markup;
}
- } else if (markup.starts_with("[")) {
+ } else if (str_util::StartsWith(markup, "[")) {
// hyperlinks
string label;
string link;
- if (RE2::Consume(&input, "([^\\[]+)\\]\\((http.+)\\)", &label, &link)) {
+ if (RE2::PartialMatch(input, "([^\\[]+)\\]\\((http.+)\\)", &label,
+ &link) &&
+ str_util::StartsWith(input, label + link)) {
+ input = input.substr(label.size() + link.size());
javadoc_text << "<a href=\"" << link << "\">"
- << ParseDocumentation(label)
- << "</a>";
+ << ParseDocumentation(label) << "</a>";
} else {
javadoc_text << markup;
}
@@ -319,57 +316,56 @@ string ParseDocumentation(re2::StringPiece input) {
}
ArgumentSpec CreateInput(const OpDef_ArgDef& input_def,
- const ApiDef::Arg& input_api_def, TypeResolver* type_resolver) {
+ const ApiDef::Arg& input_api_def,
+ TypeResolver* type_resolver) {
bool iterable = false;
Type type = type_resolver->TypeOf(input_def, &iterable);
- Type var_type = Type::Interface("Operand", "org.tensorflow")
- .add_parameter(type);
+ Type var_type =
+ Type::Interface("Operand", "org.tensorflow").add_parameter(type);
if (iterable) {
var_type = Type::IterableOf(var_type);
}
- return ArgumentSpec(input_api_def.name(),
+ return ArgumentSpec(
+ input_api_def.name(),
Variable::Create(SnakeToCamelCase(input_api_def.rename_to()), var_type),
- type,
- ParseDocumentation(input_api_def.description()),
- iterable);
+ type, ParseDocumentation(input_api_def.description()), iterable);
}
AttributeSpec CreateAttribute(const OpDef_AttrDef& attr_def,
- const ApiDef::Attr& attr_api_def, TypeResolver* type_resolver) {
+ const ApiDef::Attr& attr_api_def,
+ TypeResolver* type_resolver) {
bool iterable = false;
std::pair<Type, Type> types = type_resolver->TypesOf(attr_def, &iterable);
- Type var_type = types.first.kind() == Type::GENERIC ?
- Type::Class("Class").add_parameter(types.first) : types.first;
+ Type var_type = types.first.kind() == Type::GENERIC
+ ? Type::Class("Class").add_parameter(types.first)
+ : types.first;
if (iterable) {
var_type = Type::ListOf(var_type);
}
- return AttributeSpec(attr_api_def.name(),
+ return AttributeSpec(
+ attr_api_def.name(),
Variable::Create(SnakeToCamelCase(attr_api_def.rename_to()), var_type),
- types.first,
- types.second,
- ParseDocumentation(attr_api_def.description()),
- iterable,
- attr_api_def.has_default_value());
+ types.first, types.second, ParseDocumentation(attr_api_def.description()),
+ iterable, attr_api_def.has_default_value());
}
ArgumentSpec CreateOutput(const OpDef_ArgDef& output_def,
- const ApiDef::Arg& output_api, TypeResolver* type_resolver) {
+ const ApiDef::Arg& output_api,
+ TypeResolver* type_resolver) {
bool iterable = false;
Type type = type_resolver->TypeOf(output_def, &iterable);
- Type var_type = Type::Class("Output", "org.tensorflow")
- .add_parameter(type);
+ Type var_type = Type::Class("Output", "org.tensorflow").add_parameter(type);
if (iterable) {
var_type = Type::ListOf(var_type);
}
- return ArgumentSpec(output_api.name(),
+ return ArgumentSpec(
+ output_api.name(),
Variable::Create(SnakeToCamelCase(output_api.rename_to()), var_type),
- type,
- ParseDocumentation(output_api.description()),
- iterable);
+ type, ParseDocumentation(output_api.description()), iterable);
}
EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
- const ApiDef_Endpoint& endpoint_def) {
+ const ApiDef_Endpoint& endpoint_def) {
std::vector<string> name_tokens = str_util::Split(endpoint_def.name(), ".");
string package;
string name;
@@ -377,27 +373,25 @@ EndpointSpec CreateEndpoint(const OpDef& op_def, const ApiDef& api_def,
package = name_tokens.at(0);
name = name_tokens.at(1);
} else {
- package = kDefaultEndpointPackage;
+ package = "core"; // generate unclassified ops in the 'core' package
name = name_tokens.at(0);
}
- return EndpointSpec(package,
- name,
- Javadoc::Create(ParseDocumentation(api_def.summary()))
- .details(ParseDocumentation(api_def.description())));
+ return EndpointSpec(package, name,
+ Javadoc::Create(ParseDocumentation(api_def.summary()))
+ .details(ParseDocumentation(api_def.description())));
}
} // namespace
OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
- OpSpec op(api_def.graph_op_name(),
- api_def.visibility() == ApiDef::HIDDEN,
- op_def.deprecation().explanation());
+ OpSpec op(api_def.graph_op_name(), api_def.visibility() == ApiDef::HIDDEN,
+ op_def.deprecation().explanation());
TypeResolver type_resolver(op_def);
for (const string& next_input_name : api_def.arg_order()) {
for (int i = 0; i < op_def.input_arg().size(); ++i) {
if (op_def.input_arg(i).name() == next_input_name) {
op.inputs_.push_back(CreateInput(op_def.input_arg(i), api_def.in_arg(i),
- &type_resolver));
+ &type_resolver));
break;
}
}
@@ -406,8 +400,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
// do not parse attributes already visited, they have probably been inferred
// before as an input argument type
if (!type_resolver.IsAttributeVisited(op_def.attr(i).name())) {
- AttributeSpec attr = CreateAttribute(op_def.attr(i), api_def.attr(i),
- &type_resolver);
+ AttributeSpec attr =
+ CreateAttribute(op_def.attr(i), api_def.attr(i), &type_resolver);
// attributes with a default value are optional
if (attr.has_default_value() && attr.type().kind() != Type::GENERIC) {
op.optional_attributes_.push_back(attr);
@@ -417,8 +411,8 @@ OpSpec OpSpec::Create(const OpDef& op_def, const ApiDef& api_def) {
}
}
for (int i = 0; i < op_def.output_arg().size(); ++i) {
- op.outputs_.push_back(CreateOutput(op_def.output_arg(i), api_def.out_arg(i),
- &type_resolver));
+ op.outputs_.push_back(
+ CreateOutput(op_def.output_arg(i), api_def.out_arg(i), &type_resolver));
}
for (const auto& endpoint_def : api_def.endpoint()) {
op.endpoints_.push_back(CreateEndpoint(op_def, api_def, endpoint_def));
diff --git a/tensorflow/java/src/gen/cc/op_specs.h b/tensorflow/java/src/gen/cc/op_specs.h
index 3b53c730df..30ecb8ce53 100644
--- a/tensorflow/java/src/gen/cc/op_specs.h
+++ b/tensorflow/java/src/gen/cc/op_specs.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <string>
#include <vector>
-#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/java/src/gen/cc/java_defs.h"
namespace tensorflow {
@@ -38,9 +38,8 @@ class EndpointSpec {
// javadoc: the endpoint class documentation
// TODO(annarev): hardcode depcreated to false until deprecated is possible
EndpointSpec(const string& package, const string& name,
- const Javadoc& javadoc)
- : package_(package), name_(name), javadoc_(javadoc),
- deprecated_(false) {}
+ const Javadoc& javadoc)
+ : package_(package), name_(name), javadoc_(javadoc), deprecated_(false) {}
const string& package() const { return package_; }
const string& name() const { return name_; }
@@ -63,10 +62,13 @@ class ArgumentSpec {
// type: the tensor type of this argument
// description: a description of this argument, in javadoc
// iterable: true if this argument is a list
- ArgumentSpec(const string& op_def_name, const Variable& var,
- const Type& type, const string& description, bool iterable)
- : op_def_name_(op_def_name), var_(var), type_(type),
- description_(description), iterable_(iterable) {}
+ ArgumentSpec(const string& op_def_name, const Variable& var, const Type& type,
+ const string& description, bool iterable)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -94,11 +96,16 @@ class AttributeSpec {
// iterable: true if this attribute is a list
// has_default_value: true if this attribute has a default value if not set
AttributeSpec(const string& op_def_name, const Variable& var,
- const Type& type, const Type& jni_type, const string& description,
- bool iterable, bool has_default_value)
- : op_def_name_(op_def_name), var_(var), type_(type),
- description_(description), iterable_(iterable),
- jni_type_(jni_type), has_default_value_(has_default_value) {}
+ const Type& type, const Type& jni_type,
+ const string& description, bool iterable,
+ bool has_default_value)
+ : op_def_name_(op_def_name),
+ var_(var),
+ type_(type),
+ description_(description),
+ iterable_(iterable),
+ jni_type_(jni_type),
+ has_default_value_(has_default_value) {}
const string& op_def_name() const { return op_def_name_; }
const Variable& var() const { return var_; }
@@ -147,9 +154,10 @@ class OpSpec {
// hidden: true if this op should not be visible through the Graph Ops API
// deprecation_explanation: message to show if all endpoints are deprecated
explicit OpSpec(const string& graph_op_name, bool hidden,
- const string& deprecation_explanation)
- : graph_op_name_(graph_op_name), hidden_(hidden),
- deprecation_explanation_(deprecation_explanation) {}
+ const string& deprecation_explanation)
+ : graph_op_name_(graph_op_name),
+ hidden_(hidden),
+ deprecation_explanation_(deprecation_explanation) {}
const string graph_op_name_;
const bool hidden_;
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
index 3524160d87..796d6a62dc 100644
--- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
+++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
@@ -15,6 +15,18 @@ limitations under the License.
package org.tensorflow.processor;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Strings;
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.Multimap;
+import com.squareup.javapoet.ClassName;
+import com.squareup.javapoet.FieldSpec;
+import com.squareup.javapoet.JavaFile;
+import com.squareup.javapoet.MethodSpec;
+import com.squareup.javapoet.ParameterSpec;
+import com.squareup.javapoet.TypeName;
+import com.squareup.javapoet.TypeSpec;
+import com.squareup.javapoet.TypeVariableName;
import java.io.IOException;
import java.util.Collection;
import java.util.Collections;
@@ -23,7 +35,6 @@ import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
-
import javax.annotation.processing.AbstractProcessor;
import javax.annotation.processing.Filer;
import javax.annotation.processing.Messager;
@@ -44,19 +55,6 @@ import javax.lang.model.util.ElementFilter;
import javax.lang.model.util.Elements;
import javax.tools.Diagnostic.Kind;
-import com.google.common.base.CaseFormat;
-import com.google.common.base.Strings;
-import com.google.common.collect.HashMultimap;
-import com.google.common.collect.Multimap;
-import com.squareup.javapoet.ClassName;
-import com.squareup.javapoet.FieldSpec;
-import com.squareup.javapoet.JavaFile;
-import com.squareup.javapoet.MethodSpec;
-import com.squareup.javapoet.ParameterSpec;
-import com.squareup.javapoet.TypeName;
-import com.squareup.javapoet.TypeSpec;
-import com.squareup.javapoet.TypeVariableName;
-
/**
* A compile-time Processor that aggregates classes annotated with {@link
* org.tensorflow.op.annotation.Operator} and generates the {@code Ops} convenience API. Please
@@ -115,10 +113,12 @@ public final class OperatorProcessor extends AbstractProcessor {
// generated our code, flag the location of each such class.
if (hasRun) {
for (Element e : annotated) {
- error(e, "The Operator processor has already processed @Operator annotated sources\n" +
- "and written out an Ops API. It cannot process additional @Operator sources.\n" +
- "One reason this can happen is if other annotation processors generate\n" +
- "new @Operator source files.");
+ error(
+ e,
+ "The Operator processor has already processed @Operator annotated sources\n"
+ + "and written out an Ops API. It cannot process additional @Operator sources.\n"
+ + "One reason this can happen is if other annotation processors generate\n"
+ + "new @Operator source files.");
}
return true;
}
@@ -146,9 +146,11 @@ public final class OperatorProcessor extends AbstractProcessor {
return Collections.singleton("org.tensorflow.op.annotation.Operator");
}
- private static final Pattern JAVADOC_TAG_PATTERN = Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
+ private static final Pattern JAVADOC_TAG_PATTERN =
+ Pattern.compile("@(?:param|return|throws|exception|see)\\s+.*");
private static final TypeName T_OPS = ClassName.get("org.tensorflow.op", "Ops");
- private static final TypeName T_OPERATOR = ClassName.get("org.tensorflow.op.annotation", "Operator");
+ private static final TypeName T_OPERATOR =
+ ClassName.get("org.tensorflow.op.annotation", "Operator");
private static final TypeName T_SCOPE = ClassName.get("org.tensorflow.op", "Scope");
private static final TypeName T_GRAPH = ClassName.get("org.tensorflow", "Graph");
private static final TypeName T_STRING = ClassName.get(String.class);
@@ -167,20 +169,17 @@ public final class OperatorProcessor extends AbstractProcessor {
private void write(TypeSpec spec) {
try {
- JavaFile.builder("org.tensorflow.op", spec)
- .skipJavaLangImports(true)
- .build()
- .writeTo(filer);
+ JavaFile.builder("org.tensorflow.op", spec).skipJavaLangImports(true).build().writeTo(filer);
} catch (IOException e) {
throw new AssertionError(e);
}
}
private void writeApi(Multimap<String, MethodSpec> groupedMethods) {
- Map<String, ClassName> groups = new HashMap<String, ClassName>();
-
+ Map<String, ClassName> groups = new HashMap<>();
+
// Generate a API class for each group collected other than the default one (= empty string)
- for (Map.Entry<String, Collection<MethodSpec>> entry: groupedMethods.asMap().entrySet()) {
+ for (Map.Entry<String, Collection<MethodSpec>> entry : groupedMethods.asMap().entrySet()) {
if (!entry.getKey().isEmpty()) {
TypeSpec groupClass = buildGroupClass(entry.getKey(), entry.getValue());
write(groupClass);
@@ -193,12 +192,17 @@ public final class OperatorProcessor extends AbstractProcessor {
}
private boolean collectOpsMethods(
- RoundEnvironment roundEnv, Multimap<String, MethodSpec> groupedMethods, TypeElement annotation) {
+ RoundEnvironment roundEnv,
+ Multimap<String, MethodSpec> groupedMethods,
+ TypeElement annotation) {
boolean result = true;
for (Element e : roundEnv.getElementsAnnotatedWith(annotation)) {
// @Operator can only apply to types, so e must be a TypeElement.
if (!(e instanceof TypeElement)) {
- error(e, "@Operator can only be applied to classes, but this is a %s", e.getKind().toString());
+ error(
+ e,
+ "@Operator can only be applied to classes, but this is a %s",
+ e.getKind().toString());
result = false;
continue;
}
@@ -210,38 +214,42 @@ public final class OperatorProcessor extends AbstractProcessor {
}
return result;
}
-
- private void collectOpMethods(Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
+
+ private void collectOpMethods(
+ Multimap<String, MethodSpec> groupedMethods, TypeElement opClass, TypeElement annotation) {
AnnotationMirror am = getAnnotationMirror(opClass, annotation);
String groupName = getAnnotationElementValueAsString("group", am);
String methodName = getAnnotationElementValueAsString("name", am);
ClassName opClassName = ClassName.get(opClass);
if (Strings.isNullOrEmpty(methodName)) {
- methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
+ methodName = CaseFormat.UPPER_CAMEL.to(CaseFormat.LOWER_CAMEL, opClassName.simpleName());
}
- // Build a method for each @Operator found in the class path. There should be one method per operation factory called
+ // Build a method for each @Operator found in the class path. There should be one method per
+ // operation factory called
// "create", which takes in parameter a scope and, optionally, a list of arguments
for (ExecutableElement opMethod : ElementFilter.methodsIn(opClass.getEnclosedElements())) {
- if (opMethod.getModifiers().contains(Modifier.STATIC) && opMethod.getSimpleName().contentEquals("create")) {
+ if (opMethod.getModifiers().contains(Modifier.STATIC)
+ && opMethod.getSimpleName().contentEquals("create")) {
MethodSpec method = buildOpMethod(methodName, opClassName, opMethod);
groupedMethods.put(groupName, method);
}
}
}
- private MethodSpec buildOpMethod(String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
+ private MethodSpec buildOpMethod(
+ String methodName, ClassName opClassName, ExecutableElement factoryMethod) {
MethodSpec.Builder builder =
MethodSpec.methodBuilder(methodName)
- .addModifiers(Modifier.PUBLIC)
- .returns(TypeName.get(factoryMethod.getReturnType()))
- .varargs(factoryMethod.isVarArgs())
- .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
+ .addModifiers(Modifier.PUBLIC)
+ .returns(TypeName.get(factoryMethod.getReturnType()))
+ .varargs(factoryMethod.isVarArgs())
+ .addJavadoc("$L", buildOpMethodJavadoc(opClassName, factoryMethod));
- for (TypeParameterElement tp: factoryMethod.getTypeParameters()) {
+ for (TypeParameterElement tp : factoryMethod.getTypeParameters()) {
TypeVariableName tvn = TypeVariableName.get((TypeVariable) tp.asType());
builder.addTypeVariable(tvn);
}
- for (TypeMirror thrownType: factoryMethod.getThrownTypes()) {
+ for (TypeMirror thrownType : factoryMethod.getThrownTypes()) {
builder.addException(TypeName.get(thrownType));
}
StringBuilder call = new StringBuilder("return $T.create(scope");
@@ -259,13 +267,17 @@ public final class OperatorProcessor extends AbstractProcessor {
call.append(")");
builder.addStatement(call.toString(), opClassName);
return builder.build();
- }
-
+ }
+
private String buildOpMethodJavadoc(ClassName opClassName, ExecutableElement factoryMethod) {
StringBuilder javadoc = new StringBuilder();
- javadoc.append("Adds an {@link ").append(opClassName.simpleName()).append("} operation to the graph\n\n");
+ javadoc
+ .append("Adds an {@link ")
+ .append(opClassName.simpleName())
+ .append("} operation to the graph\n\n");
- // Add all javadoc tags found in the operator factory method but the first one, which should be in all cases the
+ // Add all javadoc tags found in the operator factory method but the first one, which should be
+ // in all cases the
// 'scope' parameter that is implicitly passed by this API
Matcher tagMatcher = JAVADOC_TAG_PATTERN.matcher(elements.getDocComment(factoryMethod));
boolean firstParam = true;
@@ -277,136 +289,144 @@ public final class OperatorProcessor extends AbstractProcessor {
} else {
javadoc.append(tag).append('\n');
}
- }
+ }
javadoc.append("@see {@link ").append(opClassName).append("}\n");
return javadoc.toString();
}
-
+
private static TypeSpec buildGroupClass(String group, Collection<MethodSpec> methods) {
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addParameter(T_SCOPE, "scope")
- .addStatement("this.scope = scope");
-
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope");
+
TypeSpec.Builder builder =
TypeSpec.classBuilder(CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, group) + "Ops")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .addJavadoc("An API for adding {@code $L} operations to a {@link $T Graph}\n\n" +
- "@see {@link $T}\n", group, T_GRAPH, T_OPS)
- .addMethods(methods)
- .addMethod(ctorBuilder.build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for adding {@code $L} operations to a {@link $T Graph}\n\n"
+ + "@see {@link $T}\n",
+ group,
+ T_GRAPH,
+ T_OPS)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
builder.addField(
- FieldSpec.builder(T_SCOPE, "scope")
- .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
- .build());
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
return builder.build();
}
- private static TypeSpec buildTopClass(Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
+ private static TypeSpec buildTopClass(
+ Map<String, ClassName> groupToClass, Collection<MethodSpec> methods) {
MethodSpec.Builder ctorBuilder =
MethodSpec.constructorBuilder()
- .addModifiers(Modifier.PRIVATE)
- .addParameter(T_SCOPE, "scope")
- .addStatement("this.scope = scope", T_SCOPE);
+ .addModifiers(Modifier.PRIVATE)
+ .addParameter(T_SCOPE, "scope")
+ .addStatement("this.scope = scope", T_SCOPE);
- for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) {
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
ctorBuilder.addStatement("$L = new $T(scope)", entry.getKey(), entry.getValue());
}
TypeSpec.Builder opsBuilder =
TypeSpec.classBuilder("Ops")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .addJavadoc("An API for building a {@link $T} with operation wrappers\n<p>\n" +
- "Any operation wrapper found in the classpath properly annotated as an {@link $T @Operator} is exposed\n" +
- "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n" +
- "try (Graph g = new Graph()) {\n" +
- " Ops ops = new Ops(g);\n" +
- " // Operations are typed classes with convenience\n" +
- " // builders in Ops.\n" +
- " Constant three = ops.constant(3);\n" +
- " // Single-result operations implement the Operand\n" +
- " // interface, so this works too.\n" +
- " Operand four = ops.constant(4);\n" +
- " // Most builders are found within a group, and accept\n" +
- " // Operand types as operands\n" +
- " Operand nine = ops.math().add(four, ops.constant(5));\n" +
- " // Multi-result operations however offer methods to\n" +
- " // select a particular result for use.\n" +
- " Operand result = \n" +
- " ops.math().add(ops.array().unique(s, a).y(), b);\n" +
- " // Optional attributes\n" +
- " ops.math().matMul(a, b, MatMul.transposeA(true));\n" +
- " // Naming operators\n" +
- " ops.withName(“foo”).constant(5); // name “foo”\n" +
- " // Names can exist in a hierarchy\n" +
- " Ops sub = ops.withSubScope(“sub”);\n" +
- " sub.withName(“bar”).constant(4); // “sub/bar”\n" +
- "}\n" +
- "}</pre>\n", T_GRAPH, T_OPERATOR)
- .addMethods(methods)
- .addMethod(ctorBuilder.build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .addJavadoc(
+ "An API for building a {@link $T} with operation wrappers\n<p>\n"
+ + "Any operation wrapper found in the classpath properly annotated as an"
+ + "{@link $T @Operator} is exposed\n"
+ + "by this API or one of its subgroup.\n<p>Example usage:\n<pre>{@code\n"
+ + "try (Graph g = new Graph()) {\n"
+ + " Ops ops = new Ops(g);\n"
+ + " // Operations are typed classes with convenience\n"
+ + " // builders in Ops.\n"
+ + " Constant three = ops.constant(3);\n"
+ + " // Single-result operations implement the Operand\n"
+ + " // interface, so this works too.\n"
+ + " Operand four = ops.constant(4);\n"
+ + " // Most builders are found within a group, and accept\n"
+ + " // Operand types as operands\n"
+ + " Operand nine = ops.math().add(four, ops.constant(5));\n"
+ + " // Multi-result operations however offer methods to\n"
+ + " // select a particular result for use.\n"
+ + " Operand result = \n"
+ + " ops.math().add(ops.array().unique(s, a).y(), b);\n"
+ + " // Optional attributes\n"
+ + " ops.math().matMul(a, b, MatMul.transposeA(true));\n"
+ + " // Naming operators\n"
+ + " ops.withName(“foo”).constant(5); // name “foo”\n"
+ + " // Names can exist in a hierarchy\n"
+ + " Ops sub = ops.withSubScope(“sub”);\n"
+ + " sub.withName(“bar”).constant(4); // “sub/bar”\n"
+ + "}\n"
+ + "}</pre>\n",
+ T_GRAPH,
+ T_OPERATOR)
+ .addMethods(methods)
+ .addMethod(ctorBuilder.build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withSubScope")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(T_STRING, "childScopeName")
- .returns(T_OPS)
- .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
- .addJavadoc(
- "Returns an API that adds operations to the graph with the provided name prefix.\n\n" +
- "@see {@link $T#withSubScope(String)}\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "childScopeName")
+ .returns(T_OPS)
+ .addStatement("return new $T(scope.withSubScope(childScopeName))", T_OPS)
+ .addJavadoc(
+ "Returns an API that adds operations to the graph with the provided name prefix.\n"
+ + "\n@see {@link $T#withSubScope(String)}\n",
+ T_SCOPE)
+ .build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("withName")
- .addModifiers(Modifier.PUBLIC)
- .addParameter(T_STRING, "opName")
- .returns(T_OPS)
- .addStatement("return new Ops(scope.withName(opName))")
- .addJavadoc(
- "Returns an API that uses the provided name for an op.\n\n" +
- "@see {@link $T#withName(String)}\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC)
+ .addParameter(T_STRING, "opName")
+ .returns(T_OPS)
+ .addStatement("return new Ops(scope.withName(opName))")
+ .addJavadoc(
+ "Returns an API that uses the provided name for an op.\n\n"
+ + "@see {@link $T#withName(String)}\n",
+ T_SCOPE)
+ .build());
opsBuilder.addField(
- FieldSpec.builder(T_SCOPE, "scope")
- .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
- .build());
+ FieldSpec.builder(T_SCOPE, "scope").addModifiers(Modifier.PRIVATE, Modifier.FINAL).build());
opsBuilder.addMethod(
MethodSpec.methodBuilder("scope")
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .returns(T_SCOPE)
- .addStatement("return scope")
- .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(T_SCOPE)
+ .addStatement("return scope")
+ .addJavadoc("Returns the current {@link $T scope} of this API\n", T_SCOPE)
+ .build());
- for (Map.Entry<String, ClassName> entry: groupToClass.entrySet()) {
+ for (Map.Entry<String, ClassName> entry : groupToClass.entrySet()) {
opsBuilder.addField(
FieldSpec.builder(entry.getValue(), entry.getKey())
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .build());
-
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .build());
+
opsBuilder.addMethod(
MethodSpec.methodBuilder(entry.getKey())
- .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
- .returns(entry.getValue())
- .addStatement("return $L", entry.getKey())
- .addJavadoc("Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.FINAL)
+ .returns(entry.getValue())
+ .addStatement("return $L", entry.getKey())
+ .addJavadoc(
+ "Returns an API for adding {@code $L} operations to the graph\n", entry.getKey())
+ .build());
}
opsBuilder.addMethod(
MethodSpec.methodBuilder("create")
- .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
- .addParameter(T_GRAPH, "graph")
- .returns(T_OPS)
- .addStatement("return new Ops(new $T(graph))", T_SCOPE)
- .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
- .build());
+ .addModifiers(Modifier.PUBLIC, Modifier.STATIC)
+ .addParameter(T_GRAPH, "graph")
+ .returns(T_OPS)
+ .addStatement("return new Ops(new $T(graph))", T_SCOPE)
+ .addJavadoc("Creates an API for adding operations to the provided {@code graph}\n")
+ .build());
return opsBuilder.build();
}
@@ -417,12 +437,16 @@ public final class OperatorProcessor extends AbstractProcessor {
return am;
}
}
- throw new IllegalArgumentException("Annotation " + annotation.getSimpleName() + " not present on element "
- + element.getSimpleName());
+ throw new IllegalArgumentException(
+ "Annotation "
+ + annotation.getSimpleName()
+ + " not present on element "
+ + element.getSimpleName());
}
-
+
private static String getAnnotationElementValueAsString(String elementName, AnnotationMirror am) {
- for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry : am.getElementValues().entrySet()) {
+ for (Map.Entry<? extends ExecutableElement, ? extends AnnotationValue> entry :
+ am.getElementValues().entrySet()) {
if (entry.getKey().getSimpleName().contentEquals(elementName)) {
return entry.getValue().getValue().toString();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index d4fd3db5f7..7d19696749 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable {
}
}
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
+ * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
+ * <p>
+ * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
+ * shapes in {@code y}.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ Output<?>[] dy = new Output<?>[x.length];
+ final long[] yHandles = new long[y.length];
+ final int[] yIndices = new int[y.length];
+ final long[] xHandles = new long[x.length];
+ final int[] xIndices = new int[x.length];
+ long[] dxHandles = null;
+ int[] dxIndices = null;
+
+ try (Reference ref = ref()) {
+ for (int i = 0; i < y.length; ++i) {
+ yHandles[i] = y[i].op().getUnsafeNativeHandle();
+ yIndices[i] = y[i].index();
+ }
+ for (int i = 0; i < x.length; ++i) {
+ xHandles[i] = x[i].op().getUnsafeNativeHandle();
+ xIndices[i] = x[i].index();
+ }
+ if (dx != null && dx.length > 0) {
+ dxHandles = new long[dx.length];
+ dxIndices = new int[dx.length];
+
+ for (int i = 0; i < dx.length; ++i) {
+ dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
+ dxIndices[i] = dx[i].index();
+ }
+ }
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
+ // of the gradient operations while the second holds the index of their output
+ // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
+ long[] dyHandlesAndIndices =
+ addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ int ndy = dyHandlesAndIndices.length >> 1;
+ if (ndy != dy.length) {
+ throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
+ + " were expected");
+ }
+ for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
+ Operation op = new Operation(this, dyHandlesAndIndices[i]);
+ dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
+ }
+ }
+ return dy;
+ }
+
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code dy/dx_1, dy/dx_2...}
+ * <p>
+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
+ * a single output and {@code dx} is null.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
+ return addGradients(new Output<?>[]{y}, x, null);
+ }
+
private final Object nativeHandleLock = new Object();
private long nativeHandle;
private int refcount = 0;
@@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
+ private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+ long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+
static {
TensorFlow.init();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Input.java b/tensorflow/java/src/main/java/org/tensorflow/Input.java
new file mode 100644
index 0000000000..13bc463e7d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/Input.java
@@ -0,0 +1,48 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow;
+
+/**
+ * Interface implemented by operands of a TensorFlow operation.
+ *
+ * <p>Example usage:
+ *
+ * <pre>{@code
+ * // The "decodeJpeg" operation can be used as input to the "cast" operation
+ * Input decodeJpeg = ops.image().decodeJpeg(...);
+ * ops.math().cast(decodeJpeg, DataType.FLOAT);
+ *
+ * // The output "y" of the "unique" operation can be used as input to the "cast" operation
+ * Output y = ops.array().unique(...).y();
+ * ops.math().cast(y, DataType.FLOAT);
+ *
+ * // The "split" operation can be used as input list to the "concat" operation
+ * Iterable<? extends Input> split = ops.array().split(...);
+ * ops.array().concat(0, split);
+ * }</pre>
+ */
+public interface Input<T> {
+
+ /**
+ * Returns the symbolic handle of a tensor.
+ *
+ * <p>Inputs to TensorFlow operations are outputs of another TensorFlow operation. This method is
+ * used to obtain a symbolic handle that represents the computation of the input.
+ *
+ * @see OperationBuilder#addInput(Output)
+ */
+ Output<T> asOutput();
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
new file mode 100644
index 0000000000..f4671c8af9
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Operands;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss
+ * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}.
+ * <p>
+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all
+ * shapes in {@code y}.
+ * <p>
+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}.
+ * <p>
+ * Example of usage:
+ * <pre>{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ *
+ * Constant<Float> alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));
+ * }</pre>
+ */
+@Operator
+public class Gradients implements Op, Iterable<Operand<?>> {
+
+ /**
+ * Optional attributes for {@link Gradients}
+ */
+ public static class Options {
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return this option builder
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ this.dx = dx;
+ return this;
+ }
+
+ private Iterable<Operand<?>> dx;
+
+ private Options() {
+ }
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * @param scope current graph scope
+ * @param y outputs of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ Output<?>[] dx = null;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.dx != null) {
+ dx = Operands.asOutputs(opts.dx);
+ }
+ }
+ }
+ Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(gradOutputs));
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
+ * a single output.
+ *
+ * @param scope current graph scope
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ return create(scope, (Iterable) Arrays.asList(y), x, options);
+ }
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return builder to add more options to this operation
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ return new Options().dx(dx);
+ }
+
+ @Override
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public Iterator<Operand<?>> iterator() {
+ return (Iterator) dy.iterator();
+ }
+
+ /**
+ * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x}
+ */
+ public List<Output<?>> dy() {
+ return dy;
+ }
+
+ /**
+ * Returns a symbolic handle to one of the gradient operation output
+ * <p>
+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
+ * gradients.<Integer>dy(0)}
+ *
+ * @param <T> The expected element type of the tensors produced by this output.
+ * @param index The index of the output among the gradients added by this operation
+ */
+ @SuppressWarnings("unchecked")
+ public <T> Output<T> dy(int index) {
+ return (Output<T>) dy.get(index);
+ }
+
+ private List<Output<?>> dy;
+
+ private Gradients(List<Output<?>> dy) {
+ this.dy = dy;
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
new file mode 100644
index 0000000000..ab34f6aa12
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a boolean. */
+public class TFBool implements TFType {
+ private TFBool() {}
+ static {
+ Types.typeCodes.put(TFBool.class, DataType.BOOL);
+ }
+ static {
+ Types.scalars.put(TFBool.class, false);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
new file mode 100644
index 0000000000..49e5d9f2f3
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 64-bit double precision floating point number. */
+public class TFDouble implements TFType {
+ private TFDouble() {}
+ static {
+ Types.typeCodes.put(TFDouble.class, DataType.DOUBLE);
+ }
+ static {
+ Types.scalars.put(TFDouble.class, 0.0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
new file mode 100644
index 0000000000..8426ee41f0
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 32-bit single precision floating point number. */
+public class TFFloat implements TFType {
+ private TFFloat() {}
+ static {
+ Types.typeCodes.put(TFFloat.class, DataType.FLOAT);
+ }
+ static {
+ Types.scalars.put(TFFloat.class, 0f);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
new file mode 100644
index 0000000000..3947b6ad09
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 32-bit signed integer. */
+public class TFInt32 implements TFType {
+ private TFInt32() {}
+ static {
+ Types.typeCodes.put(TFInt32.class, DataType.INT32);
+ }
+ static {
+ Types.scalars.put(TFInt32.class, 0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
new file mode 100644
index 0000000000..ccdded8693
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents a 64-bit signed integer. */
+public class TFInt64 implements TFType {
+ private TFInt64() {}
+ static {
+ Types.typeCodes.put(TFInt64.class, DataType.INT64);
+ }
+ static {
+ Types.scalars.put(TFInt64.class, 0L);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
new file mode 100644
index 0000000000..e7327e8c57
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
@@ -0,0 +1,27 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents an arbitrary sequence of bytes. */
+public class TFString implements TFType {
+ private TFString() {}
+ static {
+ Types.typeCodes.put(TFString.class, DataType.STRING);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
new file mode 100644
index 0000000000..562953ac9d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
@@ -0,0 +1,20 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.types;
+
+/**
+ * A marker interface for classes representing TensorFlow types.
+ */
+public interface TFType {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
new file mode 100644
index 0000000000..d7305ca5a8
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
@@ -0,0 +1,30 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// GENERATED FILE. To update, edit tftypes.pl instead.
+
+package org.tensorflow.types;
+
+import org.tensorflow.DataType;
+
+/** Represents an 8-bit unsigned integer. */
+public class TFUInt8 implements TFType {
+ private TFUInt8() {}
+ static {
+ Types.typeCodes.put(TFUInt8.class, DataType.UINT8);
+ }
+ static {
+ Types.scalars.put(TFUInt8.class, (byte)0);
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
new file mode 100644
index 0000000000..976cd9fd34
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.types;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.tensorflow.DataType;
+
+/**
+ * Utility class for managing the representation of TensorFlow types as Java
+ * types. For each TensorFlow type (e.g., int32), there is a corresponding Java
+ * type (e.g., TFInt32) that represents it at compile time and a corresponding
+ * class object (e.g., TFInt32.class) that represents it at run time. There is
+ * also an enumeration value in DataType that can be used to represent the
+ * type, though that should rarely be required.
+ */
+public class Types {
+
+ private Types() {} // not instantiable
+
+ static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
+
+ /** Returns the DataType value corresponding to a TensorFlow type class. */
+ public static DataType dataType(Class<? extends TFType> c) {
+ DataType dtype = typeCodes.get(c);
+ if (dtype == null) {
+ throw new IllegalArgumentException("" + c + " is not a TensorFlow type.");
+ }
+ return dtype;
+ }
+
+ static final Map<Class<?>, Object> scalars = new HashMap<>();
+
+ /** Returns the zero value of type described by {@code c}, or null if
+ * the type (e.g., string) is not numeric and therefore has no zero value.
+ */
+ public static Object zeroValue(Class<? extends TFType> c) {
+ return scalars.get(c);
+ }
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 0fef155275..dac6a345e9 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/java/src/main/native/graph_jni.h"
#include <limits>
+#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
@@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
TF_DeleteBuffer(buf);
return ret;
}
+
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
+ jlongArray y_handles, jintArray y_indices,
+ jlongArray x_handles, jintArray x_indices,
+ jlongArray dx_handles, jintArray dx_indices) {
+
+ TF_Graph* g = requireHandle(env, handle);
+ if (g == nullptr) return nullptr;
+
+ const jint ny = env->GetArrayLength(y_handles);
+ const jint nx = env->GetArrayLength(x_handles);
+
+ std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
+ std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
+ std::unique_ptr<TF_Output[]> dx(nullptr);
+ std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
+
+ resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
+ resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
+ if (dx_handles != nullptr) {
+ if (env->GetArrayLength(dx_handles) != ny) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d dx handles", ny,
+ env->GetArrayLength(dx_handles));
+ }
+ dx.reset(new TF_Output[ny]);
+ resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
+ }
+ if (env->ExceptionCheck()) return nullptr;
+
+ TF_Status* status = TF_NewStatus();
+ TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
+
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteStatus(status);
+ return nullptr;
+ }
+ TF_DeleteStatus(status);
+
+ // returned array contains both op handles and output indices, in pair
+ jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
+ jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
+ for (int i = 0, j = nx; i < nx; ++i, ++j) {
+ TF_Output dy_output = dy.get()[i];
+ dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
+ dy_elems[j] = static_cast<jlong>(dy_output.index);
+ }
+ env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
+
+ return dy_handles_and_indices;
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index dd2e038332..4f87e8d5a7 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
jclass,
jlong);
+/*
+ * Class: org_tensorflow_Graph
+ * Method: name
+ * Signature: (J[J[I[J[I[J[I)[J
+ */
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
+ jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
+ jintArray);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc
index 708983fef5..8b11525785 100644
--- a/tensorflow/java/src/main/native/session_jni.cc
+++ b/tensorflow/java/src/main/native/session_jni.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include "tensorflow/java/src/main/native/session_jni.h"
@@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
}
-void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
- jintArray src_index, TF_Output* dst, jint n) {
- if (env->ExceptionCheck()) return;
- jint len = env->GetArrayLength(src_op);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operations", n, len, type);
- return;
- }
- len = env->GetArrayLength(src_index);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operation output indices", n, len,
- type);
- return;
- }
- jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
- jint* indices = env->GetIntArrayElements(src_index, nullptr);
- for (int i = 0; i < n; ++i) {
- if (op_handles[i] == 0) {
- throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
- i, n);
- break;
- }
- dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
- static_cast<int>(indices[i])};
- }
- env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
- env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
-}
-
void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
if (buf == nullptr) return;
TF_DeleteBuffer(buf);
diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc
new file mode 100644
index 0000000000..069ac05a1c
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/java/src/main/native/utils_jni.h"
+
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n) {
+ if (env->ExceptionCheck()) return;
+ jint len = env->GetArrayLength(src_op);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operations", n, len, type);
+ return;
+ }
+ len = env->GetArrayLength(src_index);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operation output indices", n, len,
+ type);
+ return;
+ }
+ jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
+ jint* indices = env->GetIntArrayElements(src_index, nullptr);
+ for (int i = 0; i < n; ++i) {
+ if (op_handles[i] == 0) {
+ throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
+ i, n);
+ break;
+ }
+ dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
+ static_cast<int>(indices[i])};
+ }
+ env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
+ env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
+}
+
+
+
+
diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h
new file mode 100644
index 0000000000..352298e7de
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_
+#define TENSORFLOW_JAVA_UTILS_JNI_H_
+
+#include <jni.h>
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c540299bdc..c2e52c22c6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -129,4 +130,106 @@ public class GraphTest {
// expected exception.
}
}
+
+ @Test
+ public void addGradientsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x1);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+ Output<Float> y2 = TestUtil.addN(g, y0, x2);
+
+ Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
+ assertNotNull(grads0);
+ assertEquals(1, grads0.length);
+ assertEquals(DataType.FLOAT, grads0[0].dataType());
+
+ Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
+ assertNotNull(grads1);
+ assertEquals(2, grads1.length);
+ assertEquals(DataType.FLOAT, grads1[0].dataType());
+ assertEquals(DataType.FLOAT, grads1[1].dataType());
+
+ try (Tensor<Float> c1 = Tensors.create(3.0f);
+ Tensor<Float> c2 = Tensors.create(2.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ s.runner()
+ .feed(x1, c1)
+ .feed(x2, c2)
+ .fetch(grads0[0])
+ .fetch(grads1[0])
+ .fetch(grads1[1])
+ .run())) {
+
+ assertEquals(3, outputs.size());
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientSumsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ assertNotNull(grad);
+ assertEquals(1, grad.length);
+ assertEquals(DataType.FLOAT, grad[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(114.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientsWithInitialValuesToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
+ assertNotNull(grad0);
+ assertEquals(1, grad0.length);
+ assertEquals(DataType.FLOAT, grad0[0].dataType());
+
+ Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ assertNotNull(grad1);
+ assertEquals(1, grad1.length);
+ assertEquals(DataType.FLOAT, grad1[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad1[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(108.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ private static Output<?>[] toArray(Output<?>... outputs) {
+ return outputs;
+ }
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
index e8cc76c2a6..7d5980bcde 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
@@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-import java.util.ArrayList;
-import java.util.Collection;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -36,8 +34,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -53,8 +51,8 @@ public class SessionTest {
Output<Integer> feed = g.operation("X").output(0);
Output<Integer> fetch = g.operation("Y").output(0);
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -112,7 +110,7 @@ public class SessionTest {
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
- AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs);
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -135,8 +133,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.constant(g, "c1", 2718);
TestUtil.constant(g, "c2", 31415);
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, outputs.get(0).intValue());
assertEquals(2718, outputs.get(1).intValue());
@@ -164,28 +162,6 @@ public class SessionTest {
Session s = new Session(g, singleThreadConfigProto())) {}
}
- private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
- implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
- super(c);
- }
-
- @Override
- public void close() {
- Exception toThrow = null;
- for (AutoCloseable c : this) {
- try {
- c.close();
- } catch (Exception e) {
- toThrow = e;
- }
- }
- if (toThrow != null) {
- throw new RuntimeException(toThrow);
- }
- }
- }
-
private static byte[] fullTraceRunOptions() {
// Ideally this would use the generated Java sources for protocol buffers
// and end up with something like the snippet below. However, generating
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c973b5a3d8..4e84886416 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -16,9 +16,34 @@ limitations under the License.
package org.tensorflow;
import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Collection;
/** Static utility functions. */
public class TestUtil {
+
+ public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
+ implements AutoCloseable {
+ AutoCloseableList(Collection<? extends E> c) {
+ super(c);
+ }
+
+ @Override
+ public void close() {
+ Exception toThrow = null;
+ for (AutoCloseable c : this) {
+ try {
+ c.close();
+ } catch (Exception e) {
+ toThrow = e;
+ }
+ }
+ if (toThrow != null) {
+ throw new RuntimeException(toThrow);
+ }
+ }
+ }
+
public static <T> Output<T> constant(Graph g, String name, Object value) {
try (Tensor<?> t = Tensor.create(value)) {
return g.opBuilder("Const", name)
@@ -36,7 +61,7 @@ public class TestUtil {
.<T>output(0);
}
- public static Output<?> addN(Graph g, Output<?>... inputs) {
+ public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
@@ -58,6 +83,13 @@ public class TestUtil {
.setAttr("num_split", numSplit)
.build();
}
+
+ public static <T> Output<T> square(Graph g, String name, Output<T> value) {
+ return g.opBuilder("Square", name)
+ .addInput(value)
+ .build()
+ .<T>output(0);
+ }
public static void transpose_A_times_X(Graph g, int[][] a) {
Output<Integer> aa = constant(g, "A", a);