aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
authorGravatar karl@kubx.ca <karl@kubx.ca>2018-06-14 21:08:59 -0400
committerGravatar karl@kubx.ca <karl@kubx.ca>2018-06-27 21:47:59 -0400
commitfac56f9c9ab58fe7406a826683559de4cef85637 (patch)
treeb11c2c509e9fe4f6adf876603f182c381f74f239 /tensorflow/java
parent6929e91e74a33a3a7c7ce309af923d47c6ff81a5 (diff)
Support addition of gradient operations in a graph
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Graph.java65
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java137
-rw-r--r--tensorflow/java/src/main/native/graph_jni.cc54
-rw-r--r--tensorflow/java/src/main/native/graph_jni.h9
-rw-r--r--tensorflow/java/src/main/native/session_jni.cc32
-rw-r--r--tensorflow/java/src/main/native/utils_jni.cc53
-rw-r--r--tensorflow/java/src/main/native/utils_jni.h33
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/GraphTest.java21
8 files changed, 373 insertions, 31 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
index d4fd3db5f7..92ab4ef4d7 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java
@@ -143,6 +143,68 @@ public final class Graph implements AutoCloseable {
}
}
+ /**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * {@code dx} are used as initial gradients (which represent the symbolic partial
+ * derivatives of some loss function {@code L} w.r.t. {@code y}).
+ * {@code dx} must be null or have size of {@code y}.
+ * <p>
+ * If {@code dx} is null, the implementation will use dx of {@code OnesLike} for all
+ * shapes in {@code y}.
+ *
+ * @param y
+ * @param x
+ * @param dx
+ * @return the partial derivatives {@code dy} with the size of {@code x}
+ */
+ public Output<?>[] addGradients(Output<?>[] y, Output<?>[] x, Output<?>[] dx) {
+ final long[] yHandles = new long[y.length];
+ final int[] yIndices = new int[y.length];
+ final long[] xHandles = new long[x.length];
+ final int[] xIndices = new int[x.length];
+ long[] dxHandles = null;
+ int[] dxIndices = null;
+
+ for (int i = 0; i < y.length; ++i) {
+ yHandles[i] = y[i].op().getUnsafeNativeHandle();
+ yIndices[i] = y[i].index();
+ }
+ for (int i = 0; i < x.length; ++i) {
+ xHandles[i] = x[i].op().getUnsafeNativeHandle();
+ xIndices[i] = x[i].index();
+ }
+ if (dx != null && dx.length > 0) {
+ dxHandles = new long[dx.length];
+ dxIndices = new int[dx.length];
+
+ for (int i = 0; i < dx.length; ++i) {
+ dxHandles[i] = dx[i].op().getUnsafeNativeHandle();
+ dxIndices[i] = dx[i].index();
+ }
+ }
+ // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles
+ // of the gradient operations while the second holds the index of their output
+ // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain
+ // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...]
+ long[] dyHandlesAndIndices;
+ synchronized (nativeHandleLock) {
+ dyHandlesAndIndices = addGradients(nativeHandle, yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices);
+ }
+ int ndy = dyHandlesAndIndices.length >> 1;
+ if (ndy != x.length) {
+ throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + x.length
+ + " were expected");
+ }
+ Output<?>[] dy = new Output<?>[ndy];
+ for (int i = 0, j = ndy; i < ndy; ++i, ++j) {
+ Operation op = new Operation(this, dyHandlesAndIndices[i]);
+ dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]);
+ }
+ return dy;
+ }
+
private final Object nativeHandleLock = new Object();
private long nativeHandle;
private int refcount = 0;
@@ -254,6 +316,9 @@ public final class Graph implements AutoCloseable {
private static native byte[] toGraphDef(long handle);
+ private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices,
+ long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices);
+
static {
TensorFlow.init();
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java
new file mode 100644
index 0000000000..2db34bf188
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/training/AddGradients.java
@@ -0,0 +1,137 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.training;
+
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import org.tensorflow.Operand;
+import org.tensorflow.Output;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Operands;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/**
+ * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s,
+ * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...}
+ * <p>
+ * If {@code Options.dx()} values are set, they are as the initial symbolic partial derivatives of some loss
+ * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}.
+ * <p>
+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all
+ * shapes in {@code y}.
+ * <p>
+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}.
+ * <p>
+ * Example of usage:
+ * <pre>{@code
+ * AddGradients gradients = AddGradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ *
+ * Constant<Float> alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.<Float>dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.<Float>dy(1));
+ * }</pre>
+ */
+@Operator
+public class AddGradients implements Op, Iterable<Operand<?>> {
+
+ /**
+ * Optional attributes for {@link AddGradients}
+ */
+ public static class Options {
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return this option builder
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ this.dx = dx;
+ return this;
+ }
+
+ private Iterable<Operand<?>> dx;
+
+ private Options() {
+ }
+ }
+
+ /**
+ * Adds gradients computation ops to the graph according to scope.
+ *
+ * @param scope current graph scope
+ * @param y
+ * @param x
+ * @param dx
+ * @return a new instance of {@code AddGradients}
+ */
+ public static AddGradients create(Scope scope, Iterable<Operand<?>> y, Iterable<Operand<?>> x, Options... options) {
+ Output<?>[] dx = null;
+ if (options != null) {
+ for (Options opts : options) {
+ if (opts.dx != null) {
+ dx = Operands.asOutputs(opts.dx);
+ }
+ }
+ }
+ Output<?>[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx);
+ return new AddGradients(Arrays.asList(gradOutputs));
+ }
+
+ /**
+ * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y}
+ * @return builder to add more options to this operation
+ */
+ public Options dx(Iterable<Operand<?>> dx) {
+ return new Options().dx(dx);
+ }
+
+ @Override
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ public Iterator<Operand<?>> iterator() {
+ return (Iterator) dy.iterator();
+ }
+
+ /**
+ * {@code dy} of size {@code x}, i.e. the outputs of the operations added to the graph to compute gradients for each
+ * {@code x} nodes respectively.
+ */
+ public List<Output<?>> dy() {
+ return dy;
+ }
+
+ /**
+ * Returns a symbolic handle to one of the gradient operation output
+ * <p>
+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call
+ * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code
+ * gradients.<Integer>dy(0)}
+ *
+ * @param <T> The expected element type of the tensors produced by this output.
+ * @param index The index of the output among the gradients added by this operation
+ */
+ @SuppressWarnings("unchecked")
+ public <T> Output<T> dy(int index) {
+ return (Output<T>) dy.get(index);
+ }
+
+ private List<Output<?>> dy;
+
+ private AddGradients(List<Output<?>> dy) {
+ this.dy = dy;
+ }
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc
index 0fef155275..dac6a345e9 100644
--- a/tensorflow/java/src/main/native/graph_jni.cc
+++ b/tensorflow/java/src/main/native/graph_jni.cc
@@ -16,7 +16,9 @@ limitations under the License.
#include "tensorflow/java/src/main/native/graph_jni.h"
#include <limits>
+#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
namespace {
@@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) {
TF_DeleteBuffer(buf);
return ret;
}
+
+JNIEXPORT jlongArray JNICALL
+Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle,
+ jlongArray y_handles, jintArray y_indices,
+ jlongArray x_handles, jintArray x_indices,
+ jlongArray dx_handles, jintArray dx_indices) {
+
+ TF_Graph* g = requireHandle(env, handle);
+ if (g == nullptr) return nullptr;
+
+ const jint ny = env->GetArrayLength(y_handles);
+ const jint nx = env->GetArrayLength(x_handles);
+
+ std::unique_ptr<TF_Output[]> y(new TF_Output[ny]);
+ std::unique_ptr<TF_Output[]> x(new TF_Output[nx]);
+ std::unique_ptr<TF_Output[]> dx(nullptr);
+ std::unique_ptr<TF_Output[]> dy(new TF_Output[nx]);
+
+ resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny);
+ resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx);
+ if (dx_handles != nullptr) {
+ if (env->GetArrayLength(dx_handles) != ny) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d dx handles", ny,
+ env->GetArrayLength(dx_handles));
+ }
+ dx.reset(new TF_Output[ny]);
+ resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny);
+ }
+ if (env->ExceptionCheck()) return nullptr;
+
+ TF_Status* status = TF_NewStatus();
+ TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get());
+
+ if (!throwExceptionIfNotOK(env, status)) {
+ TF_DeleteStatus(status);
+ return nullptr;
+ }
+ TF_DeleteStatus(status);
+
+ // returned array contains both op handles and output indices, in pair
+ jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1);
+ jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr);
+ for (int i = 0, j = nx; i < nx; ++i, ++j) {
+ TF_Output dy_output = dy.get()[i];
+ dy_elems[i] = reinterpret_cast<jlong>(dy_output.oper);
+ dy_elems[j] = static_cast<jlong>(dy_output.index);
+ }
+ env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0);
+
+ return dy_handles_and_indices;
+}
diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h
index dd2e038332..4f87e8d5a7 100644
--- a/tensorflow/java/src/main/native/graph_jni.h
+++ b/tensorflow/java/src/main/native/graph_jni.h
@@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *,
jclass,
jlong);
+/*
+ * Class: org_tensorflow_Graph
+ * Method: name
+ * Signature: (J[J[I[J[I[J[I)[J
+ */
+JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *,
+ jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray,
+ jintArray);
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc
index 2cd542d3c9..cb54daf137 100644
--- a/tensorflow/java/src/main/native/session_jni.cc
+++ b/tensorflow/java/src/main/native/session_jni.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/c/c_api.h"
+#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include "tensorflow/java/src/main/native/session_jni.h"
@@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
}
-void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
- jintArray src_index, TF_Output* dst, jint n) {
- if (env->ExceptionCheck()) return;
- jint len = env->GetArrayLength(src_op);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operations", n, len, type);
- return;
- }
- len = env->GetArrayLength(src_index);
- if (len != n) {
- throwException(env, kIllegalArgumentException,
- "expected %d, got %d %s Operation output indices", n, len,
- type);
- return;
- }
- jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
- jint* indices = env->GetIntArrayElements(src_index, nullptr);
- for (int i = 0; i < n; ++i) {
- if (op_handles[i] == 0) {
- throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
- i, n);
- break;
- }
- dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
- static_cast<int>(indices[i])};
- }
- env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
- env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
-}
-
void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
if (buf == nullptr) return;
TF_DeleteBuffer(buf);
diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc
new file mode 100644
index 0000000000..069ac05a1c
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/java/src/main/native/utils_jni.h"
+
+#include "tensorflow/java/src/main/native/exception_jni.h"
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n) {
+ if (env->ExceptionCheck()) return;
+ jint len = env->GetArrayLength(src_op);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operations", n, len, type);
+ return;
+ }
+ len = env->GetArrayLength(src_index);
+ if (len != n) {
+ throwException(env, kIllegalArgumentException,
+ "expected %d, got %d %s Operation output indices", n, len,
+ type);
+ return;
+ }
+ jlong* op_handles = env->GetLongArrayElements(src_op, nullptr);
+ jint* indices = env->GetIntArrayElements(src_index, nullptr);
+ for (int i = 0; i < n; ++i) {
+ if (op_handles[i] == 0) {
+ throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
+ i, n);
+ break;
+ }
+ dst[i] = TF_Output{reinterpret_cast<TF_Operation*>(op_handles[i]),
+ static_cast<int>(indices[i])};
+ }
+ env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT);
+ env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT);
+}
+
+
+
+
diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h
new file mode 100644
index 0000000000..352298e7de
--- /dev/null
+++ b/tensorflow/java/src/main/native/utils_jni.h
@@ -0,0 +1,33 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_
+#define TENSORFLOW_JAVA_UTILS_JNI_H_
+
+#include <jni.h>
+
+#include "tensorflow/c/c_api.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif // __cplusplus
+
+void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op,
+ jintArray src_index, TF_Output* dst, jint n);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif // __cplusplus
+#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */
diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
index c540299bdc..aa6e5f0235 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java
@@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue;
import java.util.HashSet;
import java.util.Iterator;
+
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
@@ -129,4 +130,24 @@ public class GraphTest {
// expected exception.
}
}
+
+ @Test
+ public void addGradientsComputationOpsToGraph() {
+ try (Graph g = new Graph()) {
+ Output<Integer> a = TestUtil.constant(g, "A", new int[][] {{1},{2}});
+ Output<Integer> b = TestUtil.placeholder(g, "B", Integer.class);
+ Output<Integer> c = TestUtil.placeholder(g, "C", Integer.class);
+ Output<Integer> ab = TestUtil.matmul(g, "AxB", a, b, false, false);
+ Output<Integer> abc = TestUtil.matmul(g, "AxBxC", ab, c, false, false);
+
+ Output<?>[] grad = g.addGradients(new Output<?>[] {abc}, new Output<?>[] {b, c}, null);
+
+ assertNotNull(grad);
+ assertEquals(2, grad.length);
+ assertNotNull(grad[0]);
+ assertEquals(DataType.INT32, grad[0].dataType());
+ assertNotNull(grad[1]);
+ assertEquals(DataType.INT32, grad[1].dataType());
+ }
+ }
}