aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD26
-rw-r--r--tensorflow/java/maven/README.md22
-rw-r--r--tensorflow/java/maven/hadoop/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml2
-rw-r--r--tensorflow/java/maven/run_inside_container.sh68
-rw-r--r--tensorflow/java/maven/spark-connector/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow-android/update.py17
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java2
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/DataType.java32
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java64
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java73
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Session.java18
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Tensor.java15
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/Scope.java13
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java513
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java48
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java68
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFString.java27
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFType.java20
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/Types.java52
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc21
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h8
-rw-r--r--tensorflow/java/src/main/native/saved_model_bundle_jni.cc15
-rw-r--r--tensorflow/java/src/main/native/saved_model_bundle_jni.h4
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java34
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java56
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java2
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java107
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java131
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java165
41 files changed, 1323 insertions, 494 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index b2b7ee3fa5..9dce78b9a3 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -295,6 +295,32 @@ tf_java_test(
],
)
+tf_java_test(
+ name = "GradientsTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.GradientsTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
+tf_java_test(
+ name = "ZerosTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.ZerosTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
filegroup(
name = "processor_test_resources",
srcs = glob([
diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md
index 3e030dcd09..cbc64a284f 100644
--- a/tensorflow/java/maven/README.md
+++ b/tensorflow/java/maven/README.md
@@ -151,16 +151,6 @@ conducted in a [Docker](https://www.docker.com) container.
7. Upon successful release, commit changes to all the `pom.xml` files
(which should have the updated version number).
-### Snapshots
-
-If the `TF_VERSION` provided to the `release.sh` script ends in `-SNAPSHOT`,
-then instead of using official release files, the nightly build artifacts from
-https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/,
-https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ and
-https://ci.tensorflow.org/view/Nightly/job/nightly-android
-will be used to upload to the Maven Central snapshots repository. (Note that
-snapshots are only uploaded to Maven Central, not Bintray.)
-
### Skip deploying to a repository
Should you need, setting environment variables `DEPLOY_OSSRH=0` or
@@ -173,12 +163,12 @@ cannot skip deploying to OSSRH for a `-SNAPSHOT` version.
This section provides some pointers around how artifacts are currently
assembled.
-All native and java code is first built and tested on
-a [Tensorflow Jenkins server](https://ci.tensorflow.org/) which run various
-scripts under the [`tools/ci_build`](../../tools/ci_build/) directory. Of
-particular interest may be `tools/ci_build/builds/libtensorflow.sh` which
-bundles Java-related build sources and outputs into archives, and
-`tools/ci_build/builds/android_full.sh` which produces an Android AAR package.
+All native and java code is first built and tested by the release process
+which run various scripts under the [`tools/ci_build`](../../tools/ci_build/)
+directory. Of particular interest may be
+`tools/ci_build/builds/libtensorflow.sh` which bundles Java-related build
+sources and outputs into archives, and `tools/ci_build/builds/android_full.sh`
+which produces an Android AAR package.
Maven artifacts however are not created in Jenkins. Instead, artifacts are
created and deployed externally on-demand, when a maintainer runs the
diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml
index 7391dfb965..7fa751a46a 100644
--- a/tensorflow/java/maven/hadoop/pom.xml
+++ b/tensorflow/java/maven/hadoop/pom.xml
@@ -5,7 +5,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>hadoop</artifactId>
<packaging>jar</packaging>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<name>tensorflow-hadoop</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description>
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index d44bdf8f81..8ecabfd399 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index e8925c6fb1..e03ce32216 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 3bf4a2590c..fee840f547 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index b96dcf2888..0c33819b2b 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 5581d864d7..2af7a5cd2e 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh
index 2240d6b7b9..f4794d68a9 100644
--- a/tensorflow/java/maven/run_inside_container.sh
+++ b/tensorflow/java/maven/run_inside_container.sh
@@ -26,12 +26,6 @@ TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git"
DEPLOY_BINTRAY="${DEPLOY_BINTRAY:-true}"
DEPLOY_OSSRH="${DEPLOY_OSSRH:-true}"
-IS_SNAPSHOT="false"
-if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then
- IS_SNAPSHOT="true"
- # Bintray does not allow snapshots.
- DEPLOY_BINTRAY="false"
-fi
PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip"
if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then
echo "Must deploy to at least one of Bintray or OSSRH" >&2
@@ -69,11 +63,7 @@ mvn_property() {
}
download_libtensorflow() {
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow-src.jar"
- else
- URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar"
- fi
+ URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar"
curl -L "${URL}" -o /tmp/src.jar
cd "${DIR}/libtensorflow"
jar -xvf /tmp/src.jar
@@ -101,17 +91,9 @@ download_libtensorflow_jni() {
mkdir windows-x86_64
mkdir darwin-x86_64
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/
- # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=mac-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-darwin-x86_64.tar.gz" | tar -xvz -C darwin-x86_64
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-windows-x86_64.zip" -o /tmp/windows.zip
- else
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
- fi
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip
unzip /tmp/windows.zip -d windows-x86_64
rm -f /tmp/windows.zip
@@ -129,13 +111,7 @@ download_libtensorflow_jni_gpu() {
mkdir linux-x86_64
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/
- # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/
- curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=gpu-linux/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-gpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64
- else
- curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
- fi
+ curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64
# Updated timestamps seem to be required to get Maven to pick up the file.
touch linux-x86_64/*
@@ -165,11 +141,7 @@ generate_java_protos() {
rm -f "/tmp/protoc.zip"
# Download the release archive of TensorFlow protos.
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_proto.zip"
- else
- URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip"
- fi
+ URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip"
curl -L "${URL}" -o /tmp/libtensorflow_proto.zip
mkdir -p "${DIR}/proto/tmp/src"
unzip -d "${DIR}/proto/tmp/src" "/tmp/libtensorflow_proto.zip"
@@ -238,11 +210,7 @@ deploy_profile() {
# Determine the correct pom file property to use
# for the repository url.
local rtype
- if [[ "${IS_SNAPSHOT}" == "true" ]]; then
- rtype='snapshotRepository'
- else
- rtype='repository'
- fi
+ rtype='repository'
local url=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.url")
local repositoryId=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.id")
mvn gpg:sign-and-deploy-file \
@@ -300,17 +268,13 @@ mvn verify
deploy_artifacts
set +ex
-if [[ "${IS_SNAPSHOT}" == "false" ]]; then
- echo "Uploaded to the staging repository"
- echo "After validating the release: "
- if [[ "${DEPLOY_OSSRH}" == "true" ]]; then
- echo "* Login to https://oss.sonatype.org/#stagingRepositories"
- echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort"
- fi
- if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then
- echo "* Login to https://bintray.com/google/tensorflow/tensorflow"
- echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort"
- fi
-else
- echo "Uploaded to the snapshot repository"
+echo "Uploaded to the staging repository"
+echo "After validating the release: "
+if [[ "${DEPLOY_OSSRH}" == "true" ]]; then
+ echo "* Login to https://oss.sonatype.org/#stagingRepositories"
+ echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort"
+fi
+if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then
+ echo "* Login to https://bintray.com/google/tensorflow/tensorflow"
+ echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort"
fi
diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml
index 64956be02c..27d9b54c6c 100644
--- a/tensorflow/java/maven/spark-connector/pom.xml
+++ b/tensorflow/java/maven/spark-connector/pom.xml
@@ -6,7 +6,7 @@
<groupId>org.tensorflow</groupId>
<artifactId>spark-connector_2.11</artifactId>
<packaging>jar</packaging>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<name>spark-tensorflow-connector</name>
<url>https://www.tensorflow.org</url>
<description>TensorFlow TFRecord connector for Apache Spark DataFrames</description>
diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py
index 2206d800ca..c620564072 100644
--- a/tensorflow/java/maven/tensorflow-android/update.py
+++ b/tensorflow/java/maven/tensorflow-android/update.py
@@ -86,19 +86,10 @@ def read_template(path):
def main():
args = get_args()
- # Artifacts are downloaded from the ci build. A SNAPSHOT release is
- # associated with artifacts from the last successful nightly build. Otherwise,
- # it comes from the officially blessed release artifacts.
- if args.version.endswith('SNAPSHOT'):
- info_url = ('https://ci.tensorflow.org/view/Nightly/job/nightly-android'
- '/lastSuccessfulBuild/api/json')
- aar_url = None
- build_type = 'nightly-android'
- else:
- release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow'
- info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version)
- aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version)
- build_type = 'release-android'
+ release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow'
+ info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version)
+ aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version)
+ build_type = 'release-android'
# Retrieve build information
build_info = get_json(info_url)
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 92e15aa2c7..c952545bc6 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.9.0</version>
+ <version>1.10.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
index 796d6a62dc..1b7bcdab35 100644
--- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
+++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java
@@ -290,7 +290,7 @@ public final class OperatorProcessor extends AbstractProcessor {
javadoc.append(tag).append('\n');
}
}
- javadoc.append("@see {@link ").append(opClassName).append("}\n");
+ javadoc.append("@see ").append(opClassName).append("\n");
return javadoc.toString();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
index 7b92be6d38..516655040b 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java
@@ -17,40 +17,54 @@ package org.tensorflow;
import java.util.HashMap;
import java.util.Map;
+
import org.tensorflow.types.UInt8;
/** Represents the type of elements in a {@link Tensor} as an enum. */
public enum DataType {
/** 32-bit single precision floating point. */
- FLOAT(1),
+ FLOAT(1, 4),
/** 64-bit double precision floating point. */
- DOUBLE(2),
+ DOUBLE(2, 8),
/** 32-bit signed integer. */
- INT32(3),
+ INT32(3, 4),
/** 8-bit unsigned integer. */
- UINT8(4),
+ UINT8(4, 1),
/**
* A sequence of bytes.
*
* <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes.
*/
- STRING(7),
+ STRING(7, -1),
/** 64-bit signed integer. */
- INT64(9),
+ INT64(9, 8),
/** Boolean. */
- BOOL(10);
+ BOOL(10, 1);
private final int value;
+
+ private final int byteSize;
- // The integer value must match the corresponding TF_* value in the TensorFlow C API.
- DataType(int value) {
+ /**
+ * @param value must match the corresponding TF_* value in the TensorFlow C API.
+ * @param byteSize size of an element of this type, in bytes, -1 if unknown
+ */
+ DataType(int value, int byteSize) {
this.value = value;
+ this.byteSize = byteSize;
+ }
+
+ /**
+ * Returns the size of an element of this type, in bytes, or -1 if element size is variable.
+ */
+ public int byteSize() {
+ return byteSize;
}
/** Corresponding value of the TF_DataType enum in the TensorFlow C API. */
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index 7d19696749..752b49af04 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable {
}
/**
- * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
- * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
- * <p>
- * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
- * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
- * <p>
- * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
- * shapes in {@code y}.
- *
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e.,
+ * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ *
+ * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives
+ * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of
+ * {@code y}.
+ *
+ * <p>If {@code dx} is null, the implementation will use dx of {@link
+ * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}.
+ *
+ * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute
+ * gradients. It must be unique within the provided graph or the operation will fail.
+ *
+ * <p>If {@code prefix} is null, then one will be chosen automatically.
+ *
+ * @param prefix unique string prefix applied before the names of nodes added to the graph to
+ * compute gradients. If null, a default one will be chosen.
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
- public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
Output<?>[] dy = new Output<?>[x.length];
final long[] yHandles = new long[y.length];
final int[] yIndices = new int[y.length];
@@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable {
dxIndices[i] = dx[i].index();
}
}
- // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
- // of the gradient operations while the second holds the index of their output
- // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first
+ // holds the native handles of the gradient operations while the second holds the index of
+ // their output e.g. given
+ // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
// dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
long[] dyHandlesAndIndices =
- addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ addGradients(
+ ref.nativeHandle(),
+ prefix,
+ yHandles,
+ yIndices,
+ xHandles,
+ xIndices,
+ dxHandles,
+ dxIndices);
int ndy = dyHandlesAndIndices.length >> 1;
if (ndy != dy.length) {
throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
@@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable {
/**
* Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
* i.e., {@code dy/dx_1, dy/dx_2...}
- * <p>
+ * <p>
* This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
- * a single output and {@code dx} is null.
- *
+ * a single output, {@code dx} is null and {@code prefix} is null.
+ *
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @return the partial derivatives {@code dy} with the size of {@code x}
*/
public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
- return addGradients(new Output<?>[]{y}, x, null);
+ return addGradients(null, new Output<?>[] {y}, x, null);
}
private final Object nativeHandleLock = new Object();
@@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
- private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
- long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+ private static native long[] addGradients(
+ long handle,
+ String prefix,
+ long[] inputHandles,
+ int[] inputIndices,
+ long[] outputHandles,
+ int[] outputIndices,
+ long[] gradInputHandles,
+ int[] gradInputIndices);
static {
TensorFlow.init();
diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
index c8b9126f03..49594e6b47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java
@@ -25,18 +25,86 @@ package org.tensorflow;
* protocol buffer</a>).
*/
public class SavedModelBundle implements AutoCloseable {
+ /** Options for loading a SavedModel. */
+ public static final class Loader {
+ /** Load a <code>SavedModelBundle</code> with the configured options. */
+ public SavedModelBundle load() {
+ return SavedModelBundle.load(exportDir, tags, configProto, runOptions);
+ }
+
+ /**
+ * Sets options to use when executing model initialization operations.
+ *
+ * @param options Serialized <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
+ * protocol buffer</a>.
+ */
+ public Loader withRunOptions(byte[] options) {
+ this.runOptions = options;
+ return this;
+ }
+
+ /**
+ * Set configuration of the <code>Session</code> object created when loading the model.
+ *
+ * @param configProto Serialized <a
+ * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto
+ * protocol buffer</a>.
+ */
+ public Loader withConfigProto(byte[] configProto) {
+ this.configProto = configProto;
+ return this;
+ }
+
+ /**
+ * Sets the set of tags that identify the specific graph in the saved model to load.
+ *
+ * @param tags the tags identifying the specific MetaGraphDef to load.
+ */
+ public Loader withTags(String... tags) {
+ this.tags = tags;
+ return this;
+ }
+
+ private Loader(String exportDir) {
+ this.exportDir = exportDir;
+ }
+
+ private String exportDir = null;
+ private String[] tags = null;
+ private byte[] configProto = null;
+ private byte[] runOptions = null;
+ }
/**
* Load a saved model from an export directory. The model that is being loaded should be created
* using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model
* API</a>.
*
+ * <p>This method is a shorthand for:
+ *
+ * <pre>{@code
+ * SavedModelBundler.loader().withTags(tags).load();
+ * }</pre>
+ *
* @param exportDir the directory path containing a saved model.
* @param tags the tags identifying the specific metagraphdef to load.
* @return a bundle containing the graph and associated session.
*/
public static SavedModelBundle load(String exportDir, String... tags) {
- return load(exportDir, tags, null);
+ return loader(exportDir).withTags(tags).load();
+ }
+
+ /**
+ * Load a saved model.
+ *
+ * <p/>Returns a <code>Loader</code> object that can set configuration options before actually
+ * loading the model,
+ *
+ * @param exportDir the directory path containing a saved model.
+ */
+ public static Loader loader(String exportDir) {
+ return new Loader(exportDir);
}
/**
@@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable {
return new SavedModelBundle(graph, session, metaGraphDef);
}
- private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions);
+ private static native SavedModelBundle load(
+ String exportDir, String[] tags, byte[] config, byte[] runOptions);
static {
TensorFlow.init();
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java
index 73324f23e6..a660d25f98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Session.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java
@@ -185,11 +185,20 @@ public final class Session implements AutoCloseable {
return this;
}
- /** Makes {@link #run()} return the Tensor referred to by {@code output}. */
+ /**
+ * Makes {@link #run()} return the Tensor referred to by {@code output}.
+ */
public Runner fetch(Output<?> output) {
outputs.add(output);
return this;
}
+
+ /**
+ * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
+ */
+ public Runner fetch(Operand<?> operand) {
+ return fetch(operand.asOutput());
+ }
/**
* Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s.
@@ -209,6 +218,13 @@ public final class Session implements AutoCloseable {
targets.add(operation);
return this;
}
+
+ /**
+ * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s.
+ */
+ public Runner addTarget(Operand<?> operand) {
+ return addTarget(operand.asOutput().op());
+ }
/**
* (Experimental method): set options (typically for debugging) for this run.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
index 24a3775db6..8987253768 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java
@@ -595,20 +595,11 @@ public final class Tensor<T> implements AutoCloseable {
}
private static int elemByteSize(DataType dataType) {
- switch (dataType) {
- case FLOAT:
- case INT32:
- return 4;
- case DOUBLE:
- case INT64:
- return 8;
- case BOOL:
- case UINT8:
- return 1;
- case STRING:
+ int size = dataType.byteSize();
+ if (size < 0) {
throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
}
- throw new IllegalArgumentException("DataType " + dataType + " is not supported yet");
+ return size;
}
private static void throwExceptionIfNotByteOfByteArrays(Object array) {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
index 8de2eaeb79..5a233bcc98 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java
@@ -135,17 +135,8 @@ public final class Scope {
* }</pre>
*
* <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a
- * set of related operations to the graph by calling other operator building code) you should also
- * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a
- * meaningful name.
- *
- * <pre>{@code
- * public static Stddev create(Scope scope, ...) {
- * // group sub-operations under a common name
- * Scope group = scope.withSubScope("stddev");
- * ... Sqrt.create(group, Mean.create(group, ...))
- * }
- * }</pre>
+ * set of related operations to the graph by calling other operator building code), the provided
+ * name will act as a subscope to all underlying operators.
*
* @param defaultName name for the underlying operator.
* @return unique name for the operator.
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
index de4049f66b..00b6726be3 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -15,11 +15,15 @@ limitations under the License.
package org.tensorflow.op.core;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+import java.nio.charset.Charset;
+
import org.tensorflow.DataType;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
@@ -32,25 +36,82 @@ import org.tensorflow.op.annotation.Operator;
/** An operator producing a constant value. */
@Operator
public final class Constant<T> extends PrimitiveOp implements Operand<T> {
+
/**
- * Create a constant from a Java object.
+ * Creates a constant containing a single {@code int} element.
*
- * <p>The argument {@code object} is first converted into a Tensor using {@link
- * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
- * provided. For example:
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return an integer constant
+ */
+ public static Constant<Integer> create(Scope scope, int data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code int} elements.
*
- * <pre>{@code
- * Constant.create(scope, 7); // returns a constant scalar tensor 7
- * }</pre>
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code int} elements.
*
* @param scope is a scope used to add the underlying operation.
- * @param object a Java object representing the constant.
- * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
*/
- public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
- try (Tensor<T> value = Tensor.create(object, type)) {
- return createWithTensor(scope, value);
- }
+ public static Constant<Integer> create(Scope scope, int[][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][][] data) {
+ return create(scope, data, Integer.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code int} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Integer> create(Scope scope, int[][][][][][] data) {
+ return create(scope, data, Integer.class);
}
/**
@@ -64,6 +125,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return an integer constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) {
@@ -73,6 +135,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code float} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a float constant
+ */
+ public static Constant<Float> create(Scope scope, float data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code float} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Float> create(Scope scope, float[][][][][][] data) {
+ return create(scope, data, Float.class);
+ }
+
+ /**
* Create a {@link DataType#FLOAT} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -83,6 +222,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a float constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) {
@@ -92,6 +232,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code double} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a double constant
+ */
+ public static Constant<Double> create(Scope scope, double data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code double} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Double> create(Scope scope, double[][][][][][] data) {
+ return create(scope, data, Double.class);
+ }
+
+ /**
* Create a {@link DataType#DOUBLE} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -102,6 +319,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a double constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) {
@@ -111,6 +329,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code long} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a long constant
+ */
+ public static Constant<Long> create(Scope scope, long data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code long} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Long> create(Scope scope, long[][][][][][] data) {
+ return create(scope, data, Long.class);
+ }
+
+ /**
* Create a {@link DataType#INT64} constant with data from the given buffer.
*
* <p>Creates a constant with the given shape by copying elements from the buffer (starting from
@@ -121,6 +416,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param scope is a scope used to add the underlying operation.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a long constant
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) {
@@ -130,6 +426,174 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
/**
+ * Creates a constant containing a single {@code boolean} element.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The value to put into the new constant.
+ * @return a boolean constant
+ */
+ public static Constant<Boolean> create(Scope scope, boolean data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a rank-6 constant of {@code boolean} elements.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. The dimensions of the
+ * new constant will match those of the array.
+ */
+ public static Constant<Boolean> create(Scope scope, boolean[][][][][][] data) {
+ return create(scope, data, Boolean.class);
+ }
+
+ /**
+ * Creates a {@code String} constant using the default, UTF-8 encoding.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data The string to put into the new constant.
+ * @return a string constant
+ */
+ public static Constant<String> create(Scope scope, String data) {
+ return create(scope, data, UTF_8);
+ }
+
+ /**
+ * Creates a {@code String} constant using a specified encoding.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param charset The encoding from String to bytes.
+ * @param data The string to put into the new constant.
+ * @return a string constant
+ */
+ public static Constant<String> create(Scope scope, String data, Charset charset) {
+ try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) {
+ return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class));
+ }
+ }
+
+ /**
+ * Creates a constant containing a single {@code String} element, represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-1 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-2 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-3 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-4 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
+ * Creates a rank-5 constant of {@code String} elements, each represented as an array of {@code byte}s.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param data An array containing the values to put into the new constant. String elements are
+ * sequences of bytes from the last array dimension.
+ */
+ public static Constant<String> create(Scope scope, byte[][][][][][] data) {
+ return create(scope, data, String.class);
+ }
+
+ /**
* Create a constant with data from the given buffer.
*
* <p>Creates a Constant with the provided shape of any type where the constant data has been
@@ -141,6 +605,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
* @param type the tensor datatype.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
+ * @return a constant of type `type`
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
@@ -150,6 +615,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> {
}
}
+ /**
+ * Create a constant from a Java object.
+ *
+ * <p>The argument {@code object} is first converted into a Tensor using {@link
+ * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
+ * provided. For example:
+ *
+ * <pre>{@code
+ * Constant.create(scope, new int[]{{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix
+ * }</pre>
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param object a Java object representing the constant.
+ * @return a constant of type `type`
+ * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ */
+ public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) {
+ try (Tensor<T> value = Tensor.create(object, type)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) {
return new Constant<T>(
scope
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
index f4671c8af9..eea9dc1c47 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -18,7 +18,6 @@ package org.tensorflow.op.core;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
-
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.op.Op;
@@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* Optional attributes for {@link Gradients}
*/
public static class Options {
-
+
/**
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return this option builder
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public Options dx(Iterable<? extends Operand<?>> dx) {
this.dx = dx;
return this;
}
-
- private Iterable<Operand<?>> dx;
-
+
+ private Iterable<? extends Operand<?>> dx;
+
private Options() {
}
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
+ *
* @param scope current graph scope
* @param y outputs of the function to derive
* @param x inputs of the function for which partial derivatives are computed
* @param options carries optional attributes values
* @return a new instance of {@code Gradients}
*/
- public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope,
+ Iterable<? extends Operand<?>> y,
+ Iterable<? extends Operand<?>> x,
+ Options... options) {
Output<?>[] dx = null;
if (options != null) {
for (Options opts : options) {
@@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> {
}
}
}
- Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
- return new Gradients(Arrays.asList(gradOutputs));
+ Output<?>[] dy =
+ scope
+ .graph()
+ .addGradients(
+ scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(dy));
}
/**
* Adds gradients computation ops to the graph according to scope.
- *
- * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
- * a single output.
- *
+ *
+ * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where
+ * {@code y} is a single output.
+ *
* @param scope current graph scope
* @param y output of the function to derive
* @param x inputs of the function for which partial derivatives are computed
@@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @return a new instance of {@code Gradients}
*/
@SuppressWarnings({"unchecked", "rawtypes"})
- public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ public static Gradients create(
+ Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) {
return create(scope, (Iterable) Arrays.asList(y), x, options);
}
@@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> {
* @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
* @return builder to add more options to this operation
*/
- public Options dx(Iterable<Operand<?>> dx) {
+ public static Options dx(Iterable<? extends Operand<?>> dx) {
return new Options().dx(dx);
}
@@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> {
public List<Output<?>> dy() {
return dy;
}
-
+
/**
* Returns a symbolic handle to one of the gradient operation output
- * <p>
- * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ *
+ * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call
* this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
- * gradients.<Integer>dy(0)}
+ * gradients.<Float>dy(0)}
*
* @param <T> The expected element type of the tensors produced by this output.
* @param index The index of the output among the gradients added by this operation
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
new file mode 100644
index 0000000000..b7c6beb9bc
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.op.core;
+
+import java.nio.ByteBuffer;
+
+import org.tensorflow.DataType;
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * An operator creating a constant initialized with zeros of the shape given by `dims`.
+ *
+ * <p>For example, the following expression
+ * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)</pre>
+ * is the equivalent of
+ * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))</pre>
+ *
+ * @param <T> constant type
+ */
+@Operator
+public class Zeros<T> implements Op, Operand<T> {
+
+ /**
+ * Creates a zeroed tensor given its type and shape.
+ *
+ * @param scope is a scope used to add the underlying operation
+ * @param dims a 1-D operand that represents the shape of the output tensor
+ * @param type the output tensor datatype
+ * @return a constant tensor initialized with zeros
+ * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros.
+ */
+ public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) {
+ Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros"
+ int zeroSize = DataType.fromClass(type).byteSize();
+ if (zeroSize < 0) {
+ throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros");
+ }
+ Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize));
+ return new Zeros<T>(Fill.create(childScope, dims, zero));
+ }
+
+ @Override
+ public Output<T> asOutput() {
+ return fill.asOutput();
+ }
+
+ private final Fill<T> fill;
+
+ private Zeros(Fill<T> fill) {
+ this.fill = fill;
+ }
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
deleted file mode 100644
index ab34f6aa12..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a boolean. */
-public class TFBool implements TFType {
- private TFBool() {}
- static {
- Types.typeCodes.put(TFBool.class, DataType.BOOL);
- }
- static {
- Types.scalars.put(TFBool.class, false);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
deleted file mode 100644
index 49e5d9f2f3..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 64-bit double precision floating point number. */
-public class TFDouble implements TFType {
- private TFDouble() {}
- static {
- Types.typeCodes.put(TFDouble.class, DataType.DOUBLE);
- }
- static {
- Types.scalars.put(TFDouble.class, 0.0);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
deleted file mode 100644
index 8426ee41f0..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 32-bit single precision floating point number. */
-public class TFFloat implements TFType {
- private TFFloat() {}
- static {
- Types.typeCodes.put(TFFloat.class, DataType.FLOAT);
- }
- static {
- Types.scalars.put(TFFloat.class, 0f);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
deleted file mode 100644
index 3947b6ad09..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 32-bit signed integer. */
-public class TFInt32 implements TFType {
- private TFInt32() {}
- static {
- Types.typeCodes.put(TFInt32.class, DataType.INT32);
- }
- static {
- Types.scalars.put(TFInt32.class, 0);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
deleted file mode 100644
index ccdded8693..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 64-bit signed integer. */
-public class TFInt64 implements TFType {
- private TFInt64() {}
- static {
- Types.typeCodes.put(TFInt64.class, DataType.INT64);
- }
- static {
- Types.scalars.put(TFInt64.class, 0L);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
deleted file mode 100644
index e7327e8c57..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
+++ /dev/null
@@ -1,27 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents an arbitrary sequence of bytes. */
-public class TFString implements TFType {
- private TFString() {}
- static {
- Types.typeCodes.put(TFString.class, DataType.STRING);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
deleted file mode 100644
index 562953ac9d..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
+++ /dev/null
@@ -1,20 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-package org.tensorflow.types;
-
-/**
- * A marker interface for classes representing TensorFlow types.
- */
-public interface TFType {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
deleted file mode 100644
index d7305ca5a8..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents an 8-bit unsigned integer. */
-public class TFUInt8 implements TFType {
- private TFUInt8() {}
- static {
- Types.typeCodes.put(TFUInt8.class, DataType.UINT8);
- }
- static {
- Types.scalars.put(TFUInt8.class, (byte)0);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
deleted file mode 100644
index 976cd9fd34..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-package org.tensorflow.types;
-
-import java.util.HashMap;
-import java.util.Map;
-import org.tensorflow.DataType;
-
-/**
- * Utility class for managing the representation of TensorFlow types as Java
- * types. For each TensorFlow type (e.g., int32), there is a corresponding Java
- * type (e.g., TFInt32) that represents it at compile time and a corresponding
- * class object (e.g., TFInt32.class) that represents it at run time. There is
- * also an enumeration value in DataType that can be used to represent the
- * type, though that should rarely be required.
- */
-public class Types {
-
- private Types() {} // not instantiable
-
- static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
-
- /** Returns the DataType value corresponding to a TensorFlow type class. */
- public static DataType dataType(Class<? extends TFType> c) {
- DataType dtype = typeCodes.get(c);
- if (dtype == null) {
- throw new IllegalArgumentException("" + c + " is not a TensorFlow type.");
- }
- return dtype;
- }
-
- static final Map<Class<?>, Object> scalars = new HashMap<>();
-
- /** Returns the zero value of type described by {@code c}, or null if
- * the type (e.g., string) is not numeric and therefore has no zero value.
- */
- public static Object zeroValue(Class<? extends TFType> c) {
- return scalars.get(c);
- }
-}
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index dac6a345e9..f1744d8769 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
return ret;
}
-JNIEXPORT jlongArray JNICALL
-Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
- jlongArray y_handles, jintArray y_indices,
- jlongArray x_handles, jintArray x_indices,
- jlongArray dx_handles, jintArray dx_indices) {
-
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv* env, jclass clazz, jlong handle, jstring prefix,
+ jlongArray y_handles, jintArray y_indices, jlongArray x_handles,
+ jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) {
TF_Graph* g = requireHandle(env, handle);
if (g == nullptr) return nullptr;
@@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
}
if (env->ExceptionCheck()) return nullptr;
+ const char* cprefix = nullptr;
+ if (prefix != nullptr) {
+ cprefix = env->GetStringUTFChars(prefix, nullptr);
+ }
TF_Status* status = TF_NewStatus();
- TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
-
+ TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(),
+ status, dy.get());
+ if (prefix != nullptr) {
+ env->ReleaseStringUTFChars(prefix, cprefix);
+ }
if (!throwExceptionIfNotOK(env, status)) {
TF_DeleteStatus(status);
return nullptr;
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index 4f87e8d5a7..215695cdfd 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
/*
* Class: org_tensorflow_Graph
* Method: name
- * Signature: (J[J[I[J[I[J[I)[J
+ * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J
*/
-JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
- jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
- jintArray);
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(
+ JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray,
+ jintArray, jlongArray, jintArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
index de6382a79c..68999fb2da 100644
--- a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
+++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc
@@ -22,12 +22,25 @@ limitations under the License.
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags,
- jbyteArray run_options) {
+ jbyteArray config, jbyteArray run_options) {
TF_Status* status = TF_NewStatus();
jobject bundle = nullptr;
// allocate parameters for TF_LoadSessionFromSavedModel
TF_SessionOptions* opts = TF_NewSessionOptions();
+ if (config != nullptr) {
+ size_t sz = env->GetArrayLength(config);
+ if (sz > 0) {
+ jbyte* config_data = env->GetByteArrayElements(config, nullptr);
+ TF_SetConfig(opts, static_cast<void*>(config_data), sz, status);
+ env->ReleaseByteArrayElements(config, config_data, JNI_ABORT);
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteSessionOptions(opts);
+ TF_DeleteStatus(status);
+ return nullptr;
+ }
+ }
+ }
TF_Buffer* crun_options = nullptr;
if (run_options != nullptr) {
size_t sz = env->GetArrayLength(run_options);
diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.h b/tensorflow/java/src/main/native/saved_model_bundle_jni.h
index 6cce6a81bd..a4b05d0409 100644
--- a/tensorflow/java/src/main/native/saved_model_bundle_jni.h
+++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.h
@@ -26,10 +26,10 @@ extern "C" {
* Class: org_tensorflow_SavedModelBundle
* Method: load
* Signature:
- * (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle;
+ * (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle;
*/
JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load(
- JNIEnv *, jclass, jstring, jobjectArray, jbyteArray);
+ JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray);
#ifdef __cplusplus
} // extern "C"
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c2e52c22c6..7c05c1deaf 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
-
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -180,8 +179,8 @@ public class GraphTest {
Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
Output<Float> y0 = TestUtil.square(g, "y0", x);
Output<Float> y1 = TestUtil.square(g, "y1", y0);
-
- Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+
+ Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null);
assertNotNull(grad);
assertEquals(1, grad.length);
assertEquals(DataType.FLOAT, grad[0].dataType());
@@ -212,7 +211,7 @@ public class GraphTest {
assertEquals(1, grad0.length);
assertEquals(DataType.FLOAT, grad0[0].dataType());
- Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0]));
assertNotNull(grad1);
assertEquals(1, grad1.length);
assertEquals(DataType.FLOAT, grad1[0].dataType());
@@ -228,6 +227,33 @@ public class GraphTest {
}
}
}
+
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+
+ Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad0[0].op().name().startsWith("gradients/"));
+
+ Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null);
+ assertTrue(grad1[0].op().name().startsWith("gradients_1/"));
+
+ Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad2[0].op().name().startsWith("more_gradients/"));
+
+ Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+ assertTrue(grad3[0].op().name().startsWith("even_more_gradients/"));
+
+ try {
+ g.addGradients("even_more_gradients", toArray(y0), toArray(x), null);
+ } catch (IllegalArgumentException e) {
+ // expected exception
+ }
+ }
+ }
private static Output<?>[] toArray(Output<?>... outputs) {
return outputs;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
index 7922f3329c..7d936867a7 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java
@@ -47,7 +47,61 @@ public class SavedModelBundleTest {
fail("not expected");
} catch (org.tensorflow.TensorFlowException e) {
// expected exception
- assertTrue(e.getMessage().contains("SavedModel not found"));
+ assertTrue(e.getMessage().contains("Could not find SavedModel"));
}
}
+
+ @Test
+ public void loader() {
+ try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH)
+ .withTags("serve")
+ .withConfigProto(sillyConfigProto())
+ .withRunOptions(sillyRunOptions())
+ .load()) {
+ assertNotNull(bundle.session());
+ assertNotNull(bundle.graph());
+ assertNotNull(bundle.metaGraphDef());
+ }
+ }
+
+ private static byte[] sillyRunOptions() {
+ // Ideally this would use the generated Java sources for protocol buffers
+ // and end up with something like the snippet below. However, generating
+ // the Java files for the .proto files in tensorflow/core:protos_all is
+ // a bit cumbersome in bazel until the proto_library rule is setup.
+ //
+ // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
+ //
+ // For this test, for now, the use of specific bytes suffices.
+ return new byte[] {0x08, 0x03};
+ /*
+ return org.tensorflow.framework.RunOptions.newBuilder()
+ .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE)
+ .build()
+ .toByteArray();
+ */
+ }
+
+ public static byte[] sillyConfigProto() {
+ // Ideally this would use the generated Java sources for protocol buffers
+ // and end up with something like the snippet below. However, generating
+ // the Java files for the .proto files in tensorflow/core:protos_all is
+ // a bit cumbersome in bazel until the proto_library rule is setup.
+ //
+ // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362
+ // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558
+ //
+ // For this test, for now, the use of specific bytes suffices.
+ return new byte[] {0x10, 0x01, 0x28, 0x01};
+ /*
+ return org.tensorflow.framework.ConfigProto.newBuilder()
+ .setInterOpParallelismThreads(1)
+ .setIntraOpParallelismThreads(1)
+ .build()
+ .toByteArray();
+ */
+ }
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index 4e84886416..f984c508ee 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -24,7 +24,7 @@ public class TestUtil {
public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
+ public AutoCloseableList(Collection<? extends E> c) {
super(c);
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
index ca54214e06..7d3b26de8d 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow.op.core;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import java.io.ByteArrayOutputStream;
@@ -26,6 +27,7 @@ import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -37,6 +39,20 @@ import org.tensorflow.op.Scope;
@RunWith(JUnit4.class)
public class ConstantTest {
private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createInt() {
+ int value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Integer> op = Constant.create(scope, value);
+ try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) {
+ assertEquals(value, result.intValue());
+ }
+ }
+ }
@Test
public void createIntBuffer() {
@@ -47,10 +63,24 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints));
- Tensor<Integer> result = sess.runner().fetch(op.asOutput())
- .run().get(0).expect(Integer.class);
- int[] actual = new int[ints.length];
- assertArrayEquals(ints, result.copyTo(actual));
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[] actual = new int[ints.length];
+ assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual));
+ }
+ }
+ }
+
+ @Test
+ public void createFloat() {
+ float value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Float> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Float.class).floatValue(), 0.0f);
+ }
}
}
@@ -63,9 +93,24 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
- Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class);
- float[] actual = new float[floats.length];
- assertArrayEquals(floats, result.copyTo(actual), EPSILON);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ float[] actual = new float[floats.length];
+ assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void createDouble() {
+ double value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Double> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Double.class).doubleValue(), 0.0);
+ }
}
}
@@ -78,9 +123,24 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
- Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class);
- double[] actual = new double[doubles.length];
- assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON);
+ }
+ }
+ }
+
+ @Test
+ public void createLong() {
+ long value = 1;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Long> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Long.class).longValue());
+ }
}
}
@@ -93,15 +153,29 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs));
- Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class);
- long[] actual = new long[longs.length];
- assertArrayEquals(longs, result.copyTo(actual));
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ long[] actual = new long[longs.length];
+ assertArrayEquals(longs, result.expect(Long.class).copyTo(actual));
+ }
}
}
@Test
- public void createStringBuffer() throws IOException {
+ public void createBoolean() {
+ boolean value = true;
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant<Boolean> op = Constant.create(scope, value);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertEquals(value, result.expect(Boolean.class).booleanValue());
+ }
+ }
+ }
+ @Test
+ public void createStringBuffer() throws IOException {
byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
long[] shape = {};
@@ -124,8 +198,9 @@ public class ConstantTest {
Session sess = new Session(g)) {
Scope scope = new Scope(g);
Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content));
- Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class);
- assertArrayEquals(data, result.bytesValue());
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ assertArrayEquals(data, result.expect(String.class).bytesValue());
+ }
}
}
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
new file mode 100644
index 0000000000..3f49790b29
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import java.util.Arrays;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.Tensors;
+import org.tensorflow.TestUtil;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class GradientsTest {
+
+ @Test
+ public void createGradients() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(2, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithSum() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x));
+
+ assertNotNull(grads);
+ assertNotNull(grads.dy());
+ assertEquals(1, grads.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) {
+
+ assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void createGradientsWithInitialValues() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0));
+ Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy()));
+
+ assertNotNull(grads1);
+ assertNotNull(grads1.dy());
+ assertEquals(1, grads1.dy().size());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<>(
+ sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) {
+
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void validateGradientsNames() {
+ try (Graph g = new Graph()) {
+ Scope scope = new Scope(g).withSubScope("sub");
+
+ Output<Float> x = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> y = TestUtil.square(g, "y", x);
+
+ Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x));
+ assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/"));
+
+ Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x));
+ assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/"));
+ }
+ }
+}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
new file mode 100644
index 0000000000..cf3910b594
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java
@@ -0,0 +1,165 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+
+import java.util.List;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.Graph;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.Scope;
+import org.tensorflow.types.UInt8;
+
+@RunWith(JUnit4.class)
+public class ZerosTest {
+ private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createIntZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class);
+ try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) {
+ int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createFloatZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0f, actual[i][j], EPSILON);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createDoubleZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0.0, actual[i][j], EPSILON);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createLongZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0L, actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createBooleanZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertFalse(actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void createUInt8Zeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class);
+ try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) {
+ byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]);
+ result.copyTo(actual);
+ for (int i = 0; i < actual.length; ++i) {
+ for (int j = 0; j < actual[i].length; ++j) {
+ assertEquals(0, actual[i][j]);
+ }
+ }
+ }
+ }
+ }
+
+ @Test(expected = IllegalArgumentException.class)
+ public void cannotCreateStringZeros() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros.create(scope, Constant.create(scope, shape), String.class);
+ }
+ }
+
+ @Test
+ public void operationsComposingZerosAreCorrectlyNamed() {
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ long[] shape = {2, 2};
+ Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class);
+ List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run();
+ }
+ }
+}