aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-07-02 17:07:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-02 17:10:57 -0700
commit73e38c29c74d9d9bf7128bf4737a410ff005611e (patch)
treef84c84429850d1b38cb4c0f0df24aadfefc7db8e /tensorflow/java
parenteacdfdf6c0353ac0578afbd962dbbafa6121c28f (diff)
Merge changes from github.
PiperOrigin-RevId: 203037623
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/gen/cc/source_writer.cc1
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java79
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java153
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc54
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h9
-rw-r--r--tensorflow/java/src/main/native/session_jni.cc32
-rw-r--r--tensorflow/java/src/main/native/utils_jni.cc53
-rw-r--r--tensorflow/java/src/main/native/utils_jni.h33
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java103
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/SessionTest.java38
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/TestUtil.java34
11 files changed, 526 insertions, 63 deletions
diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc
index 66401bdba7..8e5fba7e32 100644
--- a/tensorflow/java/src/gen/cc/source_writer.cc
+++ b/tensorflow/java/src/gen/cc/source_writer.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <string>
#include <algorithm>
#include <list>
#include <string>
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index d4fd3db5f7..7d19696749 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable {
}
}
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function
+ * {@code L} w.r.t. {@code y}). {@code dx} must be null or have size of {@code y}.
+ * <p>
+ * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all
+ * shapes in {@code y}.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ Output<?>[] dy = new Output<?>[x.length];
+ final long[] yHandles = new long[y.length];
+ final int[] yIndices = new int[y.length];
+ final long[] xHandles = new long[x.length];
+ final int[] xIndices = new int[x.length];
+ long[] dxHandles = null;
+ int[] dxIndices = null;
+
+ try (Reference ref = ref()) {
+ for (int i = 0; i < y.length; ++i) {
+ yHandles[i] = y[i].op().getUnsafeNativeHandle();
+ yIndices[i] = y[i].index();
+ }
+ for (int i = 0; i < x.length; ++i) {
+ xHandles[i] = x[i].op().getUnsafeNativeHandle();
+ xIndices[i] = x[i].index();
+ }
+ if (dx != null && dx.length > 0) {
+ dxHandles = new long[dx.length];
+ dxIndices = new int[dx.length];
+
+ for (int i = 0; i < dx.length; ++i) {
+ dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
+ dxIndices[i] = dx[i].index();
+ }
+ }
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
+ // of the gradient operations while the second holds the index of their output
+ // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
+ long[] dyHandlesAndIndices =
+ addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ int ndy = dyHandlesAndIndices.length >> 1;
+ if (ndy != dy.length) {
+ throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length
+ + " were expected");
+ }
+ for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
+ Operation op = new Operation(this, dyHandlesAndIndices[i]);
+ dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
+ }
+ }
+ return dy;
+ }
+
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code dy/dx_1, dy/dx_2...}
+ * <p>
+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[]) where {@code y} is
+ * a single output and {@code dx} is null.
+ *
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?> y, Output<?>[] x) {
+ return addGradients(new Output<?>[]{y}, x, null);
+ }
+
private final Object nativeHandleLock = new Object();
private long nativeHandle;
private int refcount = 0;
@@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
+ private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+ long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+
static {
TensorFlow.init();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
new file mode 100644
index 0000000000..f4671c8af9
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java
@@ -0,0 +1,153 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Operands;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss
+ * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}.
+ * <p>
+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all
+ * shapes in {@code y}.
+ * <p>
+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}.
+ * <p>
+ * Example of usage:
+ * <pre>{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ *
+ * Constant<Float> alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));
+ * }</pre>
+ */
+@Operator
+public class Gradients implements Op, Iterable<Operand<?>> {
+
+ /**
+ * Optional attributes for {@link Gradients}
+ */
+ public static class Options {
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return this option builder
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ this.dx = dx;
+ return this;
+ }
+
+ private Iterable<Operand<?>> dx;
+
+ private Options() {
+ }
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * @param scope current graph scope
+ * @param y outputs of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ public static Gradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ Output<?>[] dx = null;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.dx != null) {
+ dx = Operands.asOutputs(opts.dx);
+ }
+ }
+ }
+ Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new Gradients(Arrays.asList(gradOutputs));
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is
+ * a single output.
+ *
+ * @param scope current graph scope
+ * @param y output of the function to derive
+ * @param x inputs of the function for which partial derivatives are computed
+ * @param options carries optional attributes values
+ * @return a new instance of {@code Gradients}
+ */
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public static Gradients create(Scope scope, Operand<?> y, Iterable<Operand<?>> x, Options... options) {
+ return create(scope, (Iterable) Arrays.asList(y), x, options);
+ }
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return builder to add more options to this operation
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ return new Options().dx(dx);
+ }
+
+ @Override
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public Iterator<Operand<?>> iterator() {
+ return (Iterator) dy.iterator();
+ }
+
+ /**
+ * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x}
+ */
+ public List<Output<?>> dy() {
+ return dy;
+ }
+
+ /**
+ * Returns a symbolic handle to one of the gradient operation output
+ * <p>
+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
+ * gradients.<Integer>dy(0)}
+ *
+ * @param <T> The expected element type of the tensors produced by this output.
+ * @param index The index of the output among the gradients added by this operation
+ */
+ @SuppressWarnings("unchecked")
+ public <T> Output<T> dy(int index) {
+ return (Output<T>) dy.get(index);
+ }
+
+ private List<Output<?>> dy;
+
+ private Gradients(List<Output<?>> dy) {
+ this.dy = dy;
+ }
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 0fef155275..dac6a345e9 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/java/src/main/native/graph_jni.h"
#include <limits>
+#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
@@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
TF_DeleteBuffer(buf);
return ret;
}
+
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
+ jlongArray y_handles, jintArray y_indices,
+ jlongArray x_handles, jintArray x_indices,
+ jlongArray dx_handles, jintArray dx_indices) {
+
+ TF_Graph* g = requireHandle(env, handle);
+ if (g == nullptr) return nullptr;
+
+ const jint ny = env->GetArrayLength(y_handles);
+ const jint nx = env->GetArrayLength(x_handles);
+
+ std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
+ std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
+ std::unique_ptr<TF_Output[]> dx(nullptr);
+ std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
+
+ resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
+ resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
+ if (dx_handles != nullptr) {
+ if (env->GetArrayLength(dx_handles) != ny) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d dx handles", ny,
+ env->GetArrayLength(dx_handles));
+ }
+ dx.reset(new TF_Output[ny]);
+ resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
+ }
+ if (env->ExceptionCheck()) return nullptr;
+
+ TF_Status* status = TF_NewStatus();
+ TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
+
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteStatus(status);
+ return nullptr;
+ }
+ TF_DeleteStatus(status);
+
+ // returned array contains both op handles and output indices, in pair
+ jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
+ jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
+ for (int i = 0, j = nx; i < nx; ++i, ++j) {
+ TF_Output dy_output = dy.get()[i];
+ dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
+ dy_elems[j] = static_cast<jlong>(dy_output.index);
+ }
+ env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
+
+ return dy_handles_and_indices;
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index dd2e038332..4f87e8d5a7 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
jclass,
jlong);
+/*
+ * Class: org_tensorflow_Graph
+ * Method: name
+ * Signature: (J[J[I[J[I[J[I)[J
+ */
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
+ jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
+ jintArray);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc
index 2cd542d3c9..cb54daf137 100644
--- a/tensorflow/java/src/main/native/session_jni.cc
+++ b/tensorflow/java/src/main/native/session_jni.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include "tensorflow/java/src/main/native/session_jni.h"
@@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
}
-void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
- jintArray src_index, TF_Output* dst, jint n) {
- if (env->ExceptionCheck()) return;
- jint len = env->GetArrayLength(src_op);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operations", n, len, type);
- return;
- }
- len = env->GetArrayLength(src_index);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operation output indices", n, len,
- type);
- return;
- }
- jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
- jint* indices = env->GetIntArrayElements(src_index, nullptr);
- for (int i = 0; i < n; ++i) {
- if (op_handles[i] == 0) {
- throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
- i, n);
- break;
- }
- dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
- static_cast<int>(indices[i])};
- }
- env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
- env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
-}
-
void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
if (buf == nullptr) return;
TF_DeleteBuffer(buf);
diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc
new file mode 100644
index 0000000000..069ac05a1c
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/java/src/main/native/utils_jni.h"
+
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n) {
+ if (env->ExceptionCheck()) return;
+ jint len = env->GetArrayLength(src_op);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operations", n, len, type);
+ return;
+ }
+ len = env->GetArrayLength(src_index);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operation output indices", n, len,
+ type);
+ return;
+ }
+ jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
+ jint* indices = env->GetIntArrayElements(src_index, nullptr);
+ for (int i = 0; i < n; ++i) {
+ if (op_handles[i] == 0) {
+ throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
+ i, n);
+ break;
+ }
+ dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
+ static_cast<int>(indices[i])};
+ }
+ env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
+ env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
+}
+
+
+
+
diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h
new file mode 100644
index 0000000000..352298e7de
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_
+#define TENSORFLOW_JAVA_UTILS_JNI_H_
+
+#include <jni.h>
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c540299bdc..c2e52c22c6 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -129,4 +130,106 @@ public class GraphTest {
// expected exception.
}
}
+
+ @Test
+ public void addGradientsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x1 = TestUtil.placeholder(g, "x1", Float.class);
+ Output<Float> x2 = TestUtil.placeholder(g, "x2", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x1);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+ Output<Float> y2 = TestUtil.addN(g, y0, x2);
+
+ Output<?>[] grads0 = g.addGradients(y1, toArray(x1));
+ assertNotNull(grads0);
+ assertEquals(1, grads0.length);
+ assertEquals(DataType.FLOAT, grads0[0].dataType());
+
+ Output<?>[] grads1 = g.addGradients(y2, toArray(x1, x2));
+ assertNotNull(grads1);
+ assertEquals(2, grads1.length);
+ assertEquals(DataType.FLOAT, grads1[0].dataType());
+ assertEquals(DataType.FLOAT, grads1[1].dataType());
+
+ try (Tensor<Float> c1 = Tensors.create(3.0f);
+ Tensor<Float> c2 = Tensors.create(2.0f);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<>(
+ s.runner()
+ .feed(x1, c1)
+ .feed(x2, c2)
+ .fetch(grads0[0])
+ .fetch(grads1[0])
+ .fetch(grads1[1])
+ .run())) {
+
+ assertEquals(3, outputs.size());
+ assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f);
+ assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f);
+ assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientSumsToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad = g.addGradients(toArray(y0, y1), toArray(x), null);
+ assertNotNull(grad);
+ assertEquals(1, grad.length);
+ assertEquals(DataType.FLOAT, grad[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(114.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ @Test
+ public void addGradientsWithInitialValuesToGraph() {
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+
+ Output<Float> x = TestUtil.placeholder(g, "x", Float.class);
+ Output<Float> y0 = TestUtil.square(g, "y0", x);
+ Output<Float> y1 = TestUtil.square(g, "y1", y0);
+
+ Output<?>[] grad0 = g.addGradients(y1, toArray(y0));
+ assertNotNull(grad0);
+ assertEquals(1, grad0.length);
+ assertEquals(DataType.FLOAT, grad0[0].dataType());
+
+ Output<?>[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0]));
+ assertNotNull(grad1);
+ assertEquals(1, grad1.length);
+ assertEquals(DataType.FLOAT, grad1[0].dataType());
+
+ try (Tensor<Float> c = Tensors.create(3.0f);
+ Tensor<?> output = s.runner()
+ .feed(x, c)
+ .fetch(grad1[0])
+ .run()
+ .get(0)) {
+
+ assertEquals(108.0f, output.floatValue(), 0.0f);
+ }
+ }
+ }
+
+ private static Output<?>[] toArray(Output<?>... outputs) {
+ return outputs;
+ }
}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
index e8cc76c2a6..7d5980bcde 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java
@@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
-import java.util.ArrayList;
-import java.util.Collection;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -36,8 +34,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}});
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed("X", x).fetch("Y").run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -53,8 +51,8 @@ public class SessionTest {
Output<Integer> feed = g.operation("X").output(0);
Output<Integer> fetch = g.operation("Y").output(0);
try (Tensor<Integer> x = Tensors.create(new int[][] {{5}, {7}});
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().feed(feed, x).fetch(fetch).run())) {
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -112,7 +110,7 @@ public class SessionTest {
.setOptions(fullTraceRunOptions())
.runAndFetchMetadata();
// Sanity check on outputs.
- AutoCloseableList<Tensor<?>> outputs = new AutoCloseableList<Tensor<?>>(result.outputs);
+ TestUtil.AutoCloseableList<Tensor<?>> outputs = new TestUtil.AutoCloseableList<Tensor<?>>(result.outputs);
assertEquals(1, outputs.size());
final int[][] expected = {{31}};
assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1]));
@@ -135,8 +133,8 @@ public class SessionTest {
Session s = new Session(g)) {
TestUtil.constant(g, "c1", 2718);
TestUtil.constant(g, "c2", 31415);
- AutoCloseableList<Tensor<?>> outputs =
- new AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
+ TestUtil.AutoCloseableList<Tensor<?>> outputs =
+ new TestUtil.AutoCloseableList<Tensor<?>>(s.runner().fetch("c2").fetch("c1").run());
assertEquals(2, outputs.size());
assertEquals(31415, outputs.get(0).intValue());
assertEquals(2718, outputs.get(1).intValue());
@@ -164,28 +162,6 @@ public class SessionTest {
Session s = new Session(g, singleThreadConfigProto())) {}
}
- private static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
- implements AutoCloseable {
- AutoCloseableList(Collection<? extends E> c) {
- super(c);
- }
-
- @Override
- public void close() {
- Exception toThrow = null;
- for (AutoCloseable c : this) {
- try {
- c.close();
- } catch (Exception e) {
- toThrow = e;
- }
- }
- if (toThrow != null) {
- throw new RuntimeException(toThrow);
- }
- }
- }
-
private static byte[] fullTraceRunOptions() {
// Ideally this would use the generated Java sources for protocol buffers
// and end up with something like the snippet below. However, generating
diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
index c973b5a3d8..4e84886416 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java
@@ -16,9 +16,34 @@ limitations under the License.
package org.tensorflow;
import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Collection;
/** Static utility functions. */
public class TestUtil {
+
+ public static final class AutoCloseableList<E extends AutoCloseable> extends ArrayList<E>
+ implements AutoCloseable {
+ AutoCloseableList(Collection<? extends E> c) {
+ super(c);
+ }
+
+ @Override
+ public void close() {
+ Exception toThrow = null;
+ for (AutoCloseable c : this) {
+ try {
+ c.close();
+ } catch (Exception e) {
+ toThrow = e;
+ }
+ }
+ if (toThrow != null) {
+ throw new RuntimeException(toThrow);
+ }
+ }
+ }
+
public static <T> Output<T> constant(Graph g, String name, Object value) {
try (Tensor<?> t = Tensor.create(value)) {
return g.opBuilder("Const", name)
@@ -36,7 +61,7 @@ public class TestUtil {
.<T>output(0);
}
- public static Output<?> addN(Graph g, Output<?>... inputs) {
+ public static <T> Output<T> addN(Graph g, Output<?>... inputs) {
return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0);
}
@@ -58,6 +83,13 @@ public class TestUtil {
.setAttr("num_split", numSplit)
.build();
}
+
+ public static <T> Output<T> square(Graph g, String name, Output<T> value) {
+ return g.opBuilder("Square", name)
+ .addInput(value)
+ .build()
+ .<T>output(0);
+ }
public static void transpose_A_times_X(Graph g, int[][] a) {
Output<Integer> aa = constant(g, "A", a);