diff options
Diffstat (limited to 'tensorflow/java')
41 files changed, 1323 insertions, 494 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index b2b7ee3fa5..9dce78b9a3 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -295,6 +295,32 @@ tf_java_test( ], ) +tf_java_test( + name = "GradientsTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/GradientsTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.GradientsTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + +tf_java_test( + name = "ZerosTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.ZerosTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "processor_test_resources", srcs = glob([ diff --git a/tensorflow/java/maven/README.md b/tensorflow/java/maven/README.md index 3e030dcd09..cbc64a284f 100644 --- a/tensorflow/java/maven/README.md +++ b/tensorflow/java/maven/README.md @@ -151,16 +151,6 @@ conducted in a [Docker](https://www.docker.com) container. 7. Upon successful release, commit changes to all the `pom.xml` files (which should have the updated version number). -### Snapshots - -If the `TF_VERSION` provided to the `release.sh` script ends in `-SNAPSHOT`, -then instead of using official release files, the nightly build artifacts from -https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/, -https://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ and -https://ci.tensorflow.org/view/Nightly/job/nightly-android -will be used to upload to the Maven Central snapshots repository. (Note that -snapshots are only uploaded to Maven Central, not Bintray.) - ### Skip deploying to a repository Should you need, setting environment variables `DEPLOY_OSSRH=0` or @@ -173,12 +163,12 @@ cannot skip deploying to OSSRH for a `-SNAPSHOT` version. This section provides some pointers around how artifacts are currently assembled. -All native and java code is first built and tested on -a [Tensorflow Jenkins server](https://ci.tensorflow.org/) which run various -scripts under the [`tools/ci_build`](../../tools/ci_build/) directory. Of -particular interest may be `tools/ci_build/builds/libtensorflow.sh` which -bundles Java-related build sources and outputs into archives, and -`tools/ci_build/builds/android_full.sh` which produces an Android AAR package. +All native and java code is first built and tested by the release process +which run various scripts under the [`tools/ci_build`](../../tools/ci_build/) +directory. Of particular interest may be +`tools/ci_build/builds/libtensorflow.sh` which bundles Java-related build +sources and outputs into archives, and `tools/ci_build/builds/android_full.sh` +which produces an Android AAR package. Maven artifacts however are not created in Jenkins. Instead, artifacts are created and deployed externally on-demand, when a maintainer runs the diff --git a/tensorflow/java/maven/hadoop/pom.xml b/tensorflow/java/maven/hadoop/pom.xml index 7391dfb965..7fa751a46a 100644 --- a/tensorflow/java/maven/hadoop/pom.xml +++ b/tensorflow/java/maven/hadoop/pom.xml @@ -5,7 +5,7 @@ <groupId>org.tensorflow</groupId> <artifactId>hadoop</artifactId> <packaging>jar</packaging> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <name>tensorflow-hadoop</name> <url>https://www.tensorflow.org</url> <description>TensorFlow TFRecord InputFormat/OutputFormat for Apache Hadoop</description> diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index d44bdf8f81..8ecabfd399 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow</artifactId> diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index e8925c6fb1..e03ce32216 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow_jni</artifactId> diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 3bf4a2590c..fee840f547 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <relativePath>../</relativePath> </parent> <artifactId>libtensorflow_jni_gpu</artifactId> diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index b96dcf2888..0c33819b2b 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ <modelVersion>4.0.0</modelVersion> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <packaging>pom</packaging> <url>https://www.tensorflow.org</url> diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 5581d864d7..2af7a5cd2e 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <relativePath>../</relativePath> </parent> <artifactId>proto</artifactId> diff --git a/tensorflow/java/maven/run_inside_container.sh b/tensorflow/java/maven/run_inside_container.sh index 2240d6b7b9..f4794d68a9 100644 --- a/tensorflow/java/maven/run_inside_container.sh +++ b/tensorflow/java/maven/run_inside_container.sh @@ -26,12 +26,6 @@ TF_ECOSYSTEM_URL="https://github.com/tensorflow/ecosystem.git" DEPLOY_BINTRAY="${DEPLOY_BINTRAY:-true}" DEPLOY_OSSRH="${DEPLOY_OSSRH:-true}" -IS_SNAPSHOT="false" -if [[ "${TF_VERSION}" == *"-SNAPSHOT" ]]; then - IS_SNAPSHOT="true" - # Bintray does not allow snapshots. - DEPLOY_BINTRAY="false" -fi PROTOC_RELEASE_URL="https://github.com/google/protobuf/releases/download/v3.5.1/protoc-3.5.1-linux-x86_64.zip" if [[ "${DEPLOY_BINTRAY}" != "true" && "${DEPLOY_OSSRH}" != "true" ]]; then echo "Must deploy to at least one of Bintray or OSSRH" >&2 @@ -69,11 +63,7 @@ mvn_property() { } download_libtensorflow() { - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow-src.jar" - else - URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar" - fi + URL="${RELEASE_URL_PREFIX}/libtensorflow-src-${TF_VERSION}.jar" curl -L "${URL}" -o /tmp/src.jar cd "${DIR}/libtensorflow" jar -xvf /tmp/src.jar @@ -101,17 +91,9 @@ download_libtensorflow_jni() { mkdir windows-x86_64 mkdir darwin-x86_64 - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/ - # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64 - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=mac-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-darwin-x86_64.tar.gz" | tar -xvz -C darwin-x86_64 - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-cpu-windows-x86_64.zip" -o /tmp/windows.zip - else - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64 - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip - fi + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-darwin-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C darwin-x86_64 + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-cpu-windows-x86_64-${TF_VERSION}.zip" -o /tmp/windows.zip unzip /tmp/windows.zip -d windows-x86_64 rm -f /tmp/windows.zip @@ -129,13 +111,7 @@ download_libtensorflow_jni_gpu() { mkdir linux-x86_64 - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - # Nightly builds from http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/ - # and http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow-windows/ - curl -L "http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=gpu-linux/lastSuccessfulBuild/artifact/lib_package/libtensorflow_jni-gpu-linux-x86_64.tar.gz" | tar -xvz -C linux-x86_64 - else - curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 - fi + curl -L "${RELEASE_URL_PREFIX}/libtensorflow_jni-gpu-linux-x86_64-${TF_VERSION}.tar.gz" | tar -xvz -C linux-x86_64 # Updated timestamps seem to be required to get Maven to pick up the file. touch linux-x86_64/* @@ -165,11 +141,7 @@ generate_java_protos() { rm -f "/tmp/protoc.zip" # Download the release archive of TensorFlow protos. - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - URL="http://ci.tensorflow.org/view/Nightly/job/nightly-libtensorflow/TYPE=cpu-slave/lastSuccessfulBuild/artifact/lib_package/libtensorflow_proto.zip" - else - URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip" - fi + URL="${RELEASE_URL_PREFIX}/libtensorflow_proto-${TF_VERSION}.zip" curl -L "${URL}" -o /tmp/libtensorflow_proto.zip mkdir -p "${DIR}/proto/tmp/src" unzip -d "${DIR}/proto/tmp/src" "/tmp/libtensorflow_proto.zip" @@ -238,11 +210,7 @@ deploy_profile() { # Determine the correct pom file property to use # for the repository url. local rtype - if [[ "${IS_SNAPSHOT}" == "true" ]]; then - rtype='snapshotRepository' - else - rtype='repository' - fi + rtype='repository' local url=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.url") local repositoryId=$(mvn_property "${profile}" "project.distributionManagement.${rtype}.id") mvn gpg:sign-and-deploy-file \ @@ -300,17 +268,13 @@ mvn verify deploy_artifacts set +ex -if [[ "${IS_SNAPSHOT}" == "false" ]]; then - echo "Uploaded to the staging repository" - echo "After validating the release: " - if [[ "${DEPLOY_OSSRH}" == "true" ]]; then - echo "* Login to https://oss.sonatype.org/#stagingRepositories" - echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort" - fi - if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then - echo "* Login to https://bintray.com/google/tensorflow/tensorflow" - echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort" - fi -else - echo "Uploaded to the snapshot repository" +echo "Uploaded to the staging repository" +echo "After validating the release: " +if [[ "${DEPLOY_OSSRH}" == "true" ]]; then + echo "* Login to https://oss.sonatype.org/#stagingRepositories" + echo "* Find the 'org.tensorflow' staging release and click either 'Release' to release or 'Drop' to abort" +fi +if [[ "${DEPLOY_BINTRAY}" == "true" ]]; then + echo "* Login to https://bintray.com/google/tensorflow/tensorflow" + echo "* Either 'Publish' unpublished items to release, or 'Discard' to abort" fi diff --git a/tensorflow/java/maven/spark-connector/pom.xml b/tensorflow/java/maven/spark-connector/pom.xml index 64956be02c..27d9b54c6c 100644 --- a/tensorflow/java/maven/spark-connector/pom.xml +++ b/tensorflow/java/maven/spark-connector/pom.xml @@ -6,7 +6,7 @@ <groupId>org.tensorflow</groupId> <artifactId>spark-connector_2.11</artifactId> <packaging>jar</packaging> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <name>spark-tensorflow-connector</name> <url>https://www.tensorflow.org</url> <description>TensorFlow TFRecord connector for Apache Spark DataFrames</description> diff --git a/tensorflow/java/maven/tensorflow-android/update.py b/tensorflow/java/maven/tensorflow-android/update.py index 2206d800ca..c620564072 100644 --- a/tensorflow/java/maven/tensorflow-android/update.py +++ b/tensorflow/java/maven/tensorflow-android/update.py @@ -86,19 +86,10 @@ def read_template(path): def main(): args = get_args() - # Artifacts are downloaded from the ci build. A SNAPSHOT release is - # associated with artifacts from the last successful nightly build. Otherwise, - # it comes from the officially blessed release artifacts. - if args.version.endswith('SNAPSHOT'): - info_url = ('https://ci.tensorflow.org/view/Nightly/job/nightly-android' - '/lastSuccessfulBuild/api/json') - aar_url = None - build_type = 'nightly-android' - else: - release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' - info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) - aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) - build_type = 'release-android' + release_prefix = 'https://storage.googleapis.com/tensorflow/libtensorflow' + info_url = '%s/android_buildinfo-%s.json' % (release_prefix, args.version) + aar_url = '%s/tensorflow-%s.aar' % (release_prefix, args.version) + build_type = 'release-android' # Retrieve build information build_info = get_json(info_url) diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 92e15aa2c7..c952545bc6 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ <parent> <groupId>org.tensorflow</groupId> <artifactId>parentpom</artifactId> - <version>1.9.0</version> + <version>1.10.0-rc1</version> <relativePath>../</relativePath> </parent> <artifactId>tensorflow</artifactId> diff --git a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java index 796d6a62dc..1b7bcdab35 100644 --- a/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java +++ b/tensorflow/java/src/gen/java/org/tensorflow/processor/OperatorProcessor.java @@ -290,7 +290,7 @@ public final class OperatorProcessor extends AbstractProcessor { javadoc.append(tag).append('\n'); } } - javadoc.append("@see {@link ").append(opClassName).append("}\n"); + javadoc.append("@see ").append(opClassName).append("\n"); return javadoc.toString(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index 7b92be6d38..516655040b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -17,40 +17,54 @@ package org.tensorflow; import java.util.HashMap; import java.util.Map; + import org.tensorflow.types.UInt8; /** Represents the type of elements in a {@link Tensor} as an enum. */ public enum DataType { /** 32-bit single precision floating point. */ - FLOAT(1), + FLOAT(1, 4), /** 64-bit double precision floating point. */ - DOUBLE(2), + DOUBLE(2, 8), /** 32-bit signed integer. */ - INT32(3), + INT32(3, 4), /** 8-bit unsigned integer. */ - UINT8(4), + UINT8(4, 1), /** * A sequence of bytes. * * <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes. */ - STRING(7), + STRING(7, -1), /** 64-bit signed integer. */ - INT64(9), + INT64(9, 8), /** Boolean. */ - BOOL(10); + BOOL(10, 1); private final int value; + + private final int byteSize; - // The integer value must match the corresponding TF_* value in the TensorFlow C API. - DataType(int value) { + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + * @param byteSize size of an element of this type, in bytes, -1 if unknown + */ + DataType(int value, int byteSize) { this.value = value; + this.byteSize = byteSize; + } + + /** + * Returns the size of an element of this type, in bytes, or -1 if element size is variable. + */ + public int byteSize() { + return byteSize; } /** Corresponding value of the TF_DataType enum in the TensorFlow C API. */ diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index 7d19696749..752b49af04 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -144,21 +144,29 @@ public final class Graph implements AutoCloseable { } /** - * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, - * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} - * <p> - * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function - * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}. - * <p> - * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all - * shapes in {@code y}. - * + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, i.e., + * {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + * + * <p>{@code dx} are used as initial gradients (which represent the symbolic partial derivatives + * of some loss function {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of + * {@code y}. + * + * <p>If {@code dx} is null, the implementation will use dx of {@link + * org.tensorflow.op.core.OnesLike OnesLike} for all shapes in {@code y}. + * + * <p>{@code prefix} is used as the name prefix applied to all nodes added to the graph to compute + * gradients. It must be unique within the provided graph or the operation will fail. + * + * <p>If {@code prefix} is null, then one will be chosen automatically. + * + * @param prefix unique string prefix applied before the names of nodes added to the graph to + * compute gradients. If null, a default one will be chosen. * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return the partial derivatives {@code dy} with the size of {@code x} */ - public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) { + public Output<?>[] addGradients(String prefix, Output<?>[] y, Output<?>[] x, Output<?>[] dx) { Output<?>[] dy = new Output<?>[x.length]; final long[] yHandles = new long[y.length]; final int[] yIndices = new int[y.length]; @@ -185,12 +193,21 @@ public final class Graph implements AutoCloseable { dxIndices[i] = dx[i].index(); } } - // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles - // of the gradient operations while the second holds the index of their output - // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // Gradient outputs are returned in two continuous arrays concatenated into one. The first + // holds the native handles of the gradient operations while the second holds the index of + // their output e.g. given + // xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] long[] dyHandlesAndIndices = - addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + addGradients( + ref.nativeHandle(), + prefix, + yHandles, + yIndices, + xHandles, + xIndices, + dxHandles, + dxIndices); int ndy = dyHandlesAndIndices.length >> 1; if (ndy != dy.length) { throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length @@ -207,16 +224,16 @@ public final class Graph implements AutoCloseable { /** * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, * i.e., {@code dy/dx_1, dy/dx_2...} - * <p> + * <p> * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is - * a single output and {@code dx} is null. - * + * a single output, {@code dx} is null and {@code prefix} is null. + * * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed * @return the partial derivatives {@code dy} with the size of {@code x} */ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) { - return addGradients(new Output<?>[]{y}, x, null); + return addGradients(null, new Output<?>[] {y}, x, null); } private final Object nativeHandleLock = new Object(); @@ -330,8 +347,15 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); - private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, - long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); + private static native long[] addGradients( + long handle, + String prefix, + long[] inputHandles, + int[] inputIndices, + long[] outputHandles, + int[] outputIndices, + long[] gradInputHandles, + int[] gradInputIndices); static { TensorFlow.init(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java index c8b9126f03..49594e6b47 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow/java/src/main/java/org/tensorflow/SavedModelBundle.java @@ -25,18 +25,86 @@ package org.tensorflow; * protocol buffer</a>). */ public class SavedModelBundle implements AutoCloseable { + /** Options for loading a SavedModel. */ + public static final class Loader { + /** Load a <code>SavedModelBundle</code> with the configured options. */ + public SavedModelBundle load() { + return SavedModelBundle.load(exportDir, tags, configProto, runOptions); + } + + /** + * Sets options to use when executing model initialization operations. + * + * @param options Serialized <a + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions + * protocol buffer</a>. + */ + public Loader withRunOptions(byte[] options) { + this.runOptions = options; + return this; + } + + /** + * Set configuration of the <code>Session</code> object created when loading the model. + * + * @param configProto Serialized <a + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto + * protocol buffer</a>. + */ + public Loader withConfigProto(byte[] configProto) { + this.configProto = configProto; + return this; + } + + /** + * Sets the set of tags that identify the specific graph in the saved model to load. + * + * @param tags the tags identifying the specific MetaGraphDef to load. + */ + public Loader withTags(String... tags) { + this.tags = tags; + return this; + } + + private Loader(String exportDir) { + this.exportDir = exportDir; + } + + private String exportDir = null; + private String[] tags = null; + private byte[] configProto = null; + private byte[] runOptions = null; + } /** * Load a saved model from an export directory. The model that is being loaded should be created * using the <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model">Saved Model * API</a>. * + * <p>This method is a shorthand for: + * + * <pre>{@code + * SavedModelBundler.loader().withTags(tags).load(); + * }</pre> + * * @param exportDir the directory path containing a saved model. * @param tags the tags identifying the specific metagraphdef to load. * @return a bundle containing the graph and associated session. */ public static SavedModelBundle load(String exportDir, String... tags) { - return load(exportDir, tags, null); + return loader(exportDir).withTags(tags).load(); + } + + /** + * Load a saved model. + * + * <p/>Returns a <code>Loader</code> object that can set configuration options before actually + * loading the model, + * + * @param exportDir the directory path containing a saved model. + */ + public static Loader loader(String exportDir) { + return new Loader(exportDir); } /** @@ -95,7 +163,8 @@ public class SavedModelBundle implements AutoCloseable { return new SavedModelBundle(graph, session, metaGraphDef); } - private static native SavedModelBundle load(String exportDir, String[] tags, byte[] runOptions); + private static native SavedModelBundle load( + String exportDir, String[] tags, byte[] config, byte[] runOptions); static { TensorFlow.init(); diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java index 73324f23e6..a660d25f98 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Session.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java @@ -185,11 +185,20 @@ public final class Session implements AutoCloseable { return this; } - /** Makes {@link #run()} return the Tensor referred to by {@code output}. */ + /** + * Makes {@link #run()} return the Tensor referred to by {@code output}. + */ public Runner fetch(Output<?> output) { outputs.add(output); return this; } + + /** + * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + */ + public Runner fetch(Operand<?> operand) { + return fetch(operand.asOutput()); + } /** * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s. @@ -209,6 +218,13 @@ public final class Session implements AutoCloseable { targets.add(operation); return this; } + + /** + * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s. + */ + public Runner addTarget(Operand<?> operand) { + return addTarget(operand.asOutput().op()); + } /** * (Experimental method): set options (typically for debugging) for this run. diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 24a3775db6..8987253768 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -595,20 +595,11 @@ public final class Tensor<T> implements AutoCloseable { } private static int elemByteSize(DataType dataType) { - switch (dataType) { - case FLOAT: - case INT32: - return 4; - case DOUBLE: - case INT64: - return 8; - case BOOL: - case UINT8: - return 1; - case STRING: + int size = dataType.byteSize(); + if (size < 0) { throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); } - throw new IllegalArgumentException("DataType " + dataType + " is not supported yet"); + return size; } private static void throwExceptionIfNotByteOfByteArrays(Object array) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java index 8de2eaeb79..5a233bcc98 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/Scope.java @@ -135,17 +135,8 @@ public final class Scope { * }</pre> * * <p><b>Note:</b> if you provide a composite operator building class (i.e, a class that adds a - * set of related operations to the graph by calling other operator building code) you should also - * create a {@link #withSubScope(String)} scope for the underlying operators to group them under a - * meaningful name. - * - * <pre>{@code - * public static Stddev create(Scope scope, ...) { - * // group sub-operations under a common name - * Scope group = scope.withSubScope("stddev"); - * ... Sqrt.create(group, Mean.create(group, ...)) - * } - * }</pre> + * set of related operations to the graph by calling other operator building code), the provided + * name will act as a subscope to all underlying operators. * * @param defaultName name for the underlying operator. * @return unique name for the operator. diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index de4049f66b..00b6726be3 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -15,11 +15,15 @@ limitations under the License. package org.tensorflow.op.core; +import static java.nio.charset.StandardCharsets.UTF_8; + import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; +import java.nio.charset.Charset; + import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -32,25 +36,82 @@ import org.tensorflow.op.annotation.Operator; /** An operator producing a constant value. */ @Operator public final class Constant<T> extends PrimitiveOp implements Operand<T> { + /** - * Create a constant from a Java object. + * Creates a constant containing a single {@code int} element. * - * <p>The argument {@code object} is first converted into a Tensor using {@link - * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be - * provided. For example: + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return an integer constant + */ + public static Constant<Integer> create(Scope scope, int data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-1 constant of {@code int} elements. * - * <pre>{@code - * Constant.create(scope, 7); // returns a constant scalar tensor 7 - * }</pre> + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param object a Java object representing the constant. - * @see org.tensorflow.Tensor#create(Object) Tensor.create + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. */ - public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) { - try (Tensor<T> value = Tensor.create(object, type)) { - return createWithTensor(scope, value); - } + public static Constant<Integer> create(Scope scope, int[][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-3 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-4 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-5 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-6 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][][][][] data) { + return create(scope, data, Integer.class); } /** @@ -64,6 +125,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return an integer constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) { @@ -73,6 +135,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code float} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a float constant + */ + public static Constant<Float> create(Scope scope, float data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-1 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-2 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-3 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-4 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-5 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-6 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][][][][] data) { + return create(scope, data, Float.class); + } + + /** * Create a {@link DataType#FLOAT} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -83,6 +222,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a float constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) { @@ -92,6 +232,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code double} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a double constant + */ + public static Constant<Double> create(Scope scope, double data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-1 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-2 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-3 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-4 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-5 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-6 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][][][][] data) { + return create(scope, data, Double.class); + } + + /** * Create a {@link DataType#DOUBLE} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -102,6 +319,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a double constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) { @@ -111,6 +329,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code long} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a long constant + */ + public static Constant<Long> create(Scope scope, long data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-1 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-2 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-3 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-4 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-5 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-6 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][][][][] data) { + return create(scope, data, Long.class); + } + + /** * Create a {@link DataType#INT64} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -121,6 +416,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) { @@ -130,6 +426,174 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code boolean} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a boolean constant + */ + public static Constant<Boolean> create(Scope scope, boolean data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-1 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-2 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-3 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-4 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-5 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-6 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a {@code String} constant using the default, UTF-8 encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param data The string to put into the new constant. + * @return a string constant + */ + public static Constant<String> create(Scope scope, String data) { + return create(scope, data, UTF_8); + } + + /** + * Creates a {@code String} constant using a specified encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param charset The encoding from String to bytes. + * @param data The string to put into the new constant. + * @return a string constant + */ + public static Constant<String> create(Scope scope, String data, Charset charset) { + try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) { + return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class)); + } + } + + /** + * Creates a constant containing a single {@code String} element, represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-1 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-2 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-3 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-4 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-5 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][][][][] data) { + return create(scope, data, String.class); + } + + /** * Create a constant with data from the given buffer. * * <p>Creates a Constant with the provided shape of any type where the constant data has been @@ -141,6 +605,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param type the tensor datatype. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a constant of type `type` * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ @@ -150,6 +615,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } } + /** + * Create a constant from a Java object. + * + * <p>The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + * <pre>{@code + * Constant.create(scope, new int[]{{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix + * }</pre> + * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @return a constant of type `type` + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) { + try (Tensor<T> value = Tensor.create(object, type)) { + return createWithTensor(scope, value); + } + } + private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) { return new Constant<T>( scope diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java index f4671c8af9..eea9dc1c47 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -18,7 +18,6 @@ package org.tensorflow.op.core; import java.util.Arrays; import java.util.Iterator; import java.util.List; - import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; @@ -54,32 +53,36 @@ public class Gradients implements Op, Iterable<Operand<?>> { * Optional attributes for {@link Gradients} */ public static class Options { - + /** * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return this option builder */ - public Options dx(Iterable<Operand<?>> dx) { + public Options dx(Iterable<? extends Operand<?>> dx) { this.dx = dx; return this; } - - private Iterable<Operand<?>> dx; - + + private Iterable<? extends Operand<?>> dx; + private Options() { } } /** * Adds gradients computation ops to the graph according to scope. - * + * * @param scope current graph scope * @param y outputs of the function to derive * @param x inputs of the function for which partial derivatives are computed * @param options carries optional attributes values * @return a new instance of {@code Gradients} */ - public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) { + public static Gradients create( + Scope scope, + Iterable<? extends Operand<?>> y, + Iterable<? extends Operand<?>> x, + Options... options) { Output<?>[] dx = null; if (options != null) { for (Options opts : options) { @@ -88,16 +91,20 @@ public class Gradients implements Op, Iterable<Operand<?>> { } } } - Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx); - return new Gradients(Arrays.asList(gradOutputs)); + Output<?>[] dy = + scope + .graph() + .addGradients( + scope.makeOpName("Gradients"), Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(dy)); } /** * Adds gradients computation ops to the graph according to scope. - * - * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is - * a single output. - * + * + * <p>This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where + * {@code y} is a single output. + * * @param scope current graph scope * @param y output of the function to derive * @param x inputs of the function for which partial derivatives are computed @@ -105,7 +112,8 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @return a new instance of {@code Gradients} */ @SuppressWarnings({"unchecked", "rawtypes"}) - public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) { + public static Gradients create( + Scope scope, Operand<?> y, Iterable<? extends Operand<?>> x, Options... options) { return create(scope, (Iterable) Arrays.asList(y), x, options); } @@ -113,7 +121,7 @@ public class Gradients implements Op, Iterable<Operand<?>> { * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} * @return builder to add more options to this operation */ - public Options dx(Iterable<Operand<?>> dx) { + public static Options dx(Iterable<? extends Operand<?>> dx) { return new Options().dx(dx); } @@ -129,13 +137,13 @@ public class Gradients implements Op, Iterable<Operand<?>> { public List<Output<?>> dy() { return dy; } - + /** * Returns a symbolic handle to one of the gradient operation output - * <p> - * Warning: Does not check that the type of the tensor matches T. It is recommended to call + * + * <p>Warning: Does not check that the type of the tensor matches T. It is recommended to call * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code - * gradients.<Integer>dy(0)} + * gradients.<Float>dy(0)} * * @param <T> The expected element type of the tensors produced by this output. * @param index The index of the output among the gradients added by this operation diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java new file mode 100644 index 0000000000..b7c6beb9bc --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * An operator creating a constant initialized with zeros of the shape given by `dims`. + * + * <p>For example, the following expression + * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)</pre> + * is the equivalent of + * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))</pre> + * + * @param <T> constant type + */ +@Operator +public class Zeros<T> implements Op, Operand<T> { + + /** + * Creates a zeroed tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype + * @return a constant tensor initialized with zeros + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. + */ + public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) { + Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros" + int zeroSize = DataType.fromClass(type).byteSize(); + if (zeroSize < 0) { + throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros"); + } + Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize)); + return new Zeros<T>(Fill.create(childScope, dims, zero)); + } + + @Override + public Output<T> asOutput() { + return fill.asOutput(); + } + + private final Fill<T> fill; + + private Zeros(Fill<T> fill) { + this.fill = fill; + } +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java deleted file mode 100644 index ab34f6aa12..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a boolean. */ -public class TFBool implements TFType { - private TFBool() {} - static { - Types.typeCodes.put(TFBool.class, DataType.BOOL); - } - static { - Types.scalars.put(TFBool.class, false); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java deleted file mode 100644 index 49e5d9f2f3..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 64-bit double precision floating point number. */ -public class TFDouble implements TFType { - private TFDouble() {} - static { - Types.typeCodes.put(TFDouble.class, DataType.DOUBLE); - } - static { - Types.scalars.put(TFDouble.class, 0.0); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java deleted file mode 100644 index 8426ee41f0..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 32-bit single precision floating point number. */ -public class TFFloat implements TFType { - private TFFloat() {} - static { - Types.typeCodes.put(TFFloat.class, DataType.FLOAT); - } - static { - Types.scalars.put(TFFloat.class, 0f); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java deleted file mode 100644 index 3947b6ad09..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 32-bit signed integer. */ -public class TFInt32 implements TFType { - private TFInt32() {} - static { - Types.typeCodes.put(TFInt32.class, DataType.INT32); - } - static { - Types.scalars.put(TFInt32.class, 0); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java deleted file mode 100644 index ccdded8693..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 64-bit signed integer. */ -public class TFInt64 implements TFType { - private TFInt64() {} - static { - Types.typeCodes.put(TFInt64.class, DataType.INT64); - } - static { - Types.scalars.put(TFInt64.class, 0L); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java deleted file mode 100644 index e7327e8c57..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents an arbitrary sequence of bytes. */ -public class TFString implements TFType { - private TFString() {} - static { - Types.typeCodes.put(TFString.class, DataType.STRING); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java deleted file mode 100644 index 562953ac9d..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.types; - -/** - * A marker interface for classes representing TensorFlow types. - */ -public interface TFType {} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java deleted file mode 100644 index d7305ca5a8..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents an 8-bit unsigned integer. */ -public class TFUInt8 implements TFType { - private TFUInt8() {} - static { - Types.typeCodes.put(TFUInt8.class, DataType.UINT8); - } - static { - Types.scalars.put(TFUInt8.class, (byte)0); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java deleted file mode 100644 index 976cd9fd34..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.types; - -import java.util.HashMap; -import java.util.Map; -import org.tensorflow.DataType; - -/** - * Utility class for managing the representation of TensorFlow types as Java - * types. For each TensorFlow type (e.g., int32), there is a corresponding Java - * type (e.g., TFInt32) that represents it at compile time and a corresponding - * class object (e.g., TFInt32.class) that represents it at run time. There is - * also an enumeration value in DataType that can be used to represent the - * type, though that should rarely be required. - */ -public class Types { - - private Types() {} // not instantiable - - static final Map<Class<?>, DataType> typeCodes = new HashMap<>(); - - /** Returns the DataType value corresponding to a TensorFlow type class. */ - public static DataType dataType(Class<? extends TFType> c) { - DataType dtype = typeCodes.get(c); - if (dtype == null) { - throw new IllegalArgumentException("" + c + " is not a TensorFlow type."); - } - return dtype; - } - - static final Map<Class<?>, Object> scalars = new HashMap<>(); - - /** Returns the zero value of type described by {@code c}, or null if - * the type (e.g., string) is not numeric and therefore has no zero value. - */ - public static Object zeroValue(Class<? extends TFType> c) { - return scalars.get(c); - } -} diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index dac6a345e9..f1744d8769 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -133,12 +133,10 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { return ret; } -JNIEXPORT jlongArray JNICALL -Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, - jlongArray y_handles, jintArray y_indices, - jlongArray x_handles, jintArray x_indices, - jlongArray dx_handles, jintArray dx_indices) { - +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( + JNIEnv* env, jclass clazz, jlong handle, jstring prefix, + jlongArray y_handles, jintArray y_indices, jlongArray x_handles, + jintArray x_indices, jlongArray dx_handles, jintArray dx_indices) { TF_Graph* g = requireHandle(env, handle); if (g == nullptr) return nullptr; @@ -163,9 +161,16 @@ Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, } if (env->ExceptionCheck()) return nullptr; + const char* cprefix = nullptr; + if (prefix != nullptr) { + cprefix = env->GetStringUTFChars(prefix, nullptr); + } TF_Status* status = TF_NewStatus(); - TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); - + TF_AddGradientsWithPrefix(g, cprefix, y.get(), ny, x.get(), nx, dx.get(), + status, dy.get()); + if (prefix != nullptr) { + env->ReleaseStringUTFChars(prefix, cprefix); + } if (!throwExceptionIfNotOK(env, status)) { TF_DeleteStatus(status); return nullptr; diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index 4f87e8d5a7..215695cdfd 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -76,11 +76,11 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *, /* * Class: org_tensorflow_Graph * Method: name - * Signature: (J[J[I[J[I[J[I)[J + * Signature: (JLjava/lang/String;[J[I[J[I[J[I)[J */ -JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *, - jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray, - jintArray); +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients( + JNIEnv *, jclass, jlong, jstring, jlongArray, jintArray, jlongArray, + jintArray, jlongArray, jintArray); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc index de6382a79c..68999fb2da 100644 --- a/tensorflow/java/src/main/native/saved_model_bundle_jni.cc +++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.cc @@ -22,12 +22,25 @@ limitations under the License. JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( JNIEnv* env, jclass clazz, jstring export_dir, jobjectArray tags, - jbyteArray run_options) { + jbyteArray config, jbyteArray run_options) { TF_Status* status = TF_NewStatus(); jobject bundle = nullptr; // allocate parameters for TF_LoadSessionFromSavedModel TF_SessionOptions* opts = TF_NewSessionOptions(); + if (config != nullptr) { + size_t sz = env->GetArrayLength(config); + if (sz > 0) { + jbyte* config_data = env->GetByteArrayElements(config, nullptr); + TF_SetConfig(opts, static_cast<void*>(config_data), sz, status); + env->ReleaseByteArrayElements(config, config_data, JNI_ABORT); + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteSessionOptions(opts); + TF_DeleteStatus(status); + return nullptr; + } + } + } TF_Buffer* crun_options = nullptr; if (run_options != nullptr) { size_t sz = env->GetArrayLength(run_options); diff --git a/tensorflow/java/src/main/native/saved_model_bundle_jni.h b/tensorflow/java/src/main/native/saved_model_bundle_jni.h index 6cce6a81bd..a4b05d0409 100644 --- a/tensorflow/java/src/main/native/saved_model_bundle_jni.h +++ b/tensorflow/java/src/main/native/saved_model_bundle_jni.h @@ -26,10 +26,10 @@ extern "C" { * Class: org_tensorflow_SavedModelBundle * Method: load * Signature: - * (Ljava/lang/String;[Ljava/lang/String;[B)Lorg/tensorflow/SavedModelBundle; + * (Ljava/lang/String;[Ljava/lang/String;[B;[B)Lorg/tensorflow/SavedModelBundle; */ JNIEXPORT jobject JNICALL Java_org_tensorflow_SavedModelBundle_load( - JNIEnv *, jclass, jstring, jobjectArray, jbyteArray); + JNIEnv *, jclass, jstring, jobjectArray, jbyteArray, jbyteArray); #ifdef __cplusplus } // extern "C" diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c2e52c22c6..7c05c1deaf 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,7 +22,6 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; - import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -180,8 +179,8 @@ public class GraphTest { Output<Float> x = TestUtil.placeholder(g, "x", Float.class); Output<Float> y0 = TestUtil.square(g, "y0", x); Output<Float> y1 = TestUtil.square(g, "y1", y0); - - Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + + Output<?>[] grad = g.addGradients(null, toArray(y0, y1), toArray(x), null); assertNotNull(grad); assertEquals(1, grad.length); assertEquals(DataType.FLOAT, grad[0].dataType()); @@ -212,7 +211,7 @@ public class GraphTest { assertEquals(1, grad0.length); assertEquals(DataType.FLOAT, grad0[0].dataType()); - Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), toArray(grad0[0])); assertNotNull(grad1); assertEquals(1, grad1.length); assertEquals(DataType.FLOAT, grad1[0].dataType()); @@ -228,6 +227,33 @@ public class GraphTest { } } } + + @Test + public void validateGradientsNames() { + try (Graph g = new Graph()) { + + Output<Float> x = TestUtil.placeholder(g, "x", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + + Output<?>[] grad0 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad0[0].op().name().startsWith("gradients/")); + + Output<?>[] grad1 = g.addGradients(null, toArray(y0), toArray(x), null); + assertTrue(grad1[0].op().name().startsWith("gradients_1/")); + + Output<?>[] grad2 = g.addGradients("more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad2[0].op().name().startsWith("more_gradients/")); + + Output<?>[] grad3 = g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + assertTrue(grad3[0].op().name().startsWith("even_more_gradients/")); + + try { + g.addGradients("even_more_gradients", toArray(y0), toArray(x), null); + } catch (IllegalArgumentException e) { + // expected exception + } + } + } private static Output<?>[] toArray(Output<?>... outputs) { return outputs; diff --git a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java index 7922f3329c..7d936867a7 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -47,7 +47,61 @@ public class SavedModelBundleTest { fail("not expected"); } catch (org.tensorflow.TensorFlowException e) { // expected exception - assertTrue(e.getMessage().contains("SavedModel not found")); + assertTrue(e.getMessage().contains("Could not find SavedModel")); } } + + @Test + public void loader() { + try (SavedModelBundle bundle = SavedModelBundle.loader(SAVED_MODEL_PATH) + .withTags("serve") + .withConfigProto(sillyConfigProto()) + .withRunOptions(sillyRunOptions()) + .load()) { + assertNotNull(bundle.session()); + assertNotNull(bundle.graph()); + assertNotNull(bundle.metaGraphDef()); + } + } + + private static byte[] sillyRunOptions() { + // Ideally this would use the generated Java sources for protocol buffers + // and end up with something like the snippet below. However, generating + // the Java files for the .proto files in tensorflow/core:protos_all is + // a bit cumbersome in bazel until the proto_library rule is setup. + // + // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 + // + // For this test, for now, the use of specific bytes suffices. + return new byte[] {0x08, 0x03}; + /* + return org.tensorflow.framework.RunOptions.newBuilder() + .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) + .build() + .toByteArray(); + */ + } + + public static byte[] sillyConfigProto() { + // Ideally this would use the generated Java sources for protocol buffers + // and end up with something like the snippet below. However, generating + // the Java files for the .proto files in tensorflow/core:protos_all is + // a bit cumbersome in bazel until the proto_library rule is setup. + // + // See https://github.com/bazelbuild/bazel/issues/52#issuecomment-194341866 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251515362 + // https://github.com/bazelbuild/rules_go/pull/121#issuecomment-251692558 + // + // For this test, for now, the use of specific bytes suffices. + return new byte[] {0x10, 0x01, 0x28, 0x01}; + /* + return org.tensorflow.framework.ConfigProto.newBuilder() + .setInterOpParallelismThreads(1) + .setIntraOpParallelismThreads(1) + .build() + .toByteArray(); + */ + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index 4e84886416..f984c508ee 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -24,7 +24,7 @@ public class TestUtil { public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E> implements AutoCloseable { - AutoCloseableList(Collection<? extends E> c) { + public AutoCloseableList(Collection<? extends E> c) { super(c); } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index ca54214e06..7d3b26de8d 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.op.core; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.ByteArrayOutputStream; @@ -26,6 +27,7 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,6 +39,20 @@ import org.tensorflow.op.Scope; @RunWith(JUnit4.class) public class ConstantTest { private static final float EPSILON = 1e-7f; + + @Test + public void createInt() { + int value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Integer> op = Constant.create(scope, value); + try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) { + assertEquals(value, result.intValue()); + } + } + } @Test public void createIntBuffer() { @@ -47,10 +63,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints)); - Tensor<Integer> result = sess.runner().fetch(op.asOutput()) - .run().get(0).expect(Integer.class); - int[] actual = new int[ints.length]; - assertArrayEquals(ints, result.copyTo(actual)); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + int[] actual = new int[ints.length]; + assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual)); + } + } + } + + @Test + public void createFloat() { + float value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Float> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Float.class).floatValue(), 0.0f); + } } } @@ -63,9 +93,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); - Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[] actual = new float[floats.length]; - assertArrayEquals(floats, result.copyTo(actual), EPSILON); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + float[] actual = new float[floats.length]; + assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON); + } + } + } + + @Test + public void createDouble() { + double value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Double> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Double.class).doubleValue(), 0.0); + } } } @@ -78,9 +123,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); - Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[] actual = new double[doubles.length]; - assertArrayEquals(doubles, result.copyTo(actual), EPSILON); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + double[] actual = new double[doubles.length]; + assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON); + } + } + } + + @Test + public void createLong() { + long value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Long> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Long.class).longValue()); + } } } @@ -93,15 +153,29 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs)); - Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - long[] actual = new long[longs.length]; - assertArrayEquals(longs, result.copyTo(actual)); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + long[] actual = new long[longs.length]; + assertArrayEquals(longs, result.expect(Long.class).copyTo(actual)); + } } } @Test - public void createStringBuffer() throws IOException { + public void createBoolean() { + boolean value = true; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Boolean> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Boolean.class).booleanValue()); + } + } + } + @Test + public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; long[] shape = {}; @@ -124,8 +198,9 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content)); - Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class); - assertArrayEquals(data, result.bytesValue()); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertArrayEquals(data, result.expect(String.class).bytesValue()); + } } } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java new file mode 100644 index 0000000000..3f49790b29 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -0,0 +1,131 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.Tensors; +import org.tensorflow.TestUtil; +import org.tensorflow.op.Scope; + +@RunWith(JUnit4.class) +public class GradientsTest { + + @Test + public void createGradients() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, y1, Arrays.asList(x, y0)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(2, grads.dy().size()); + + try (Tensor<Float> c = Tensors.create(3.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<>( + sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(18.0f, outputs.get(1).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithSum() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Gradients grads = Gradients.create(scope, Arrays.asList(y0, y1), Arrays.asList(x)); + + assertNotNull(grads); + assertNotNull(grads.dy()); + assertEquals(1, grads.dy().size()); + + try (Tensor<Float> c = Tensors.create(3.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + + assertEquals(114.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void createGradientsWithInitialValues() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y0 = TestUtil.square(g, "y0", x); + Output<Float> y1 = TestUtil.square(g, "y1", y0); + + Gradients grads0 = Gradients.create(scope, y1, Arrays.asList(y0)); + Gradients grads1 = Gradients.create(scope, y0, Arrays.asList(x), Gradients.dx(grads0.dy())); + + assertNotNull(grads1); + assertNotNull(grads1.dy()); + assertEquals(1, grads1.dy().size()); + + try (Tensor<Float> c = Tensors.create(3.0f); + TestUtil.AutoCloseableList<Tensor<?>> outputs = + new TestUtil.AutoCloseableList<>( + sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + } + } + } + + @Test + public void validateGradientsNames() { + try (Graph g = new Graph()) { + Scope scope = new Scope(g).withSubScope("sub"); + + Output<Float> x = TestUtil.placeholder(g, "x1", Float.class); + Output<Float> y = TestUtil.square(g, "y", x); + + Gradients grad0 = Gradients.create(scope, y, Arrays.asList(x)); + assertTrue(grad0.dy(0).op().name().startsWith("sub/Gradients/")); + + Gradients grad1 = Gradients.create(scope.withName("MyGradients"), y, Arrays.asList(x)); + assertTrue(grad1.dy(0).op().name().startsWith("sub/MyGradients/")); + } + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java new file mode 100644 index 0000000000..cf3910b594 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import java.util.List; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.op.Scope; +import org.tensorflow.types.UInt8; + +@RunWith(JUnit4.class) +public class ZerosTest { + private static final float EPSILON = 1e-7f; + + @Test + public void createIntZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } + } + } + } + } + + @Test + public void createFloatZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0f, actual[i][j], EPSILON); + } + } + } + } + } + + @Test + public void createDoubleZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0, actual[i][j], EPSILON); + } + } + } + } + } + + @Test + public void createLongZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0L, actual[i][j]); + } + } + } + } + } + + @Test + public void createBooleanZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertFalse(actual[i][j]); + } + } + } + } + } + + @Test + public void createUInt8Zeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]); + result.copyTo(actual); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } + } + } + } + } + + @Test(expected = IllegalArgumentException.class) + public void cannotCreateStringZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros.create(scope, Constant.create(scope, shape), String.class); + } + } + + @Test + public void operationsComposingZerosAreCorrectlyNamed() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class); + List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + } + } +} |