From 73e38c29c74d9d9bf7128bf4737a410ff005611e Mon Sep 17 00:00:00 2001 From: Yifei Feng Date: Mon, 2 Jul 2018 17:07:06 -0700 Subject: Merge changes from github. PiperOrigin-RevId: 203037623 --- tensorflow/java/src/gen/cc/source_writer.cc | 1 + .../java/src/main/java/org/tensorflow/Graph.java | 79 +++++++++++ .../java/org/tensorflow/op/core/Gradients.java | 153 +++++++++++++++++++++ tensorflow/java/src/main/native/graph_jni.cc | 54 ++++++++ tensorflow/java/src/main/native/graph_jni.h | 9 ++ tensorflow/java/src/main/native/session_jni.cc | 32 +---- tensorflow/java/src/main/native/utils_jni.cc | 53 +++++++ tensorflow/java/src/main/native/utils_jni.h | 33 +++++ .../src/test/java/org/tensorflow/GraphTest.java | 103 ++++++++++++++ .../src/test/java/org/tensorflow/SessionTest.java | 38 +---- .../src/test/java/org/tensorflow/TestUtil.java | 34 ++++- 11 files changed, 526 insertions(+), 63 deletions(-) create mode 100644 tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java create mode 100644 tensorflow/java/src/main/native/utils_jni.cc create mode 100644 tensorflow/java/src/main/native/utils_jni.h (limited to 'tensorflow/java') diff --git a/tensorflow/java/src/gen/cc/source_writer.cc b/tensorflow/java/src/gen/cc/source_writer.cc index 66401bdba7..8e5fba7e32 100644 --- a/tensorflow/java/src/gen/cc/source_writer.cc +++ b/tensorflow/java/src/gen/cc/source_writer.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include #include diff --git a/tensorflow/java/src/main/java/org/tensorflow/Graph.java b/tensorflow/java/src/main/java/org/tensorflow/Graph.java index d4fd3db5f7..7d19696749 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Graph.java @@ -143,6 +143,82 @@ public final class Graph implements AutoCloseable { } } + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * {@code dx} are used as initial gradients (which represent the symbolic partial derivatives of some loss function + * {@code L} w.r.t. {@code y}). {@code dx} must be null or have the size of {@code y}. + *

+ * If {@code dx} is null, the implementation will use dx of {@link org.tensorflow.op.core.OnesLike OnesLike} for all + * shapes in {@code y}. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param dx if not null, the partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output[] y, Output[] x, Output[] dx) { + Output[] dy = new Output[x.length]; + final long[] yHandles = new long[y.length]; + final int[] yIndices = new int[y.length]; + final long[] xHandles = new long[x.length]; + final int[] xIndices = new int[x.length]; + long[] dxHandles = null; + int[] dxIndices = null; + + try (Reference ref = ref()) { + for (int i = 0; i < y.length; ++i) { + yHandles[i] = y[i].op().getUnsafeNativeHandle(); + yIndices[i] = y[i].index(); + } + for (int i = 0; i < x.length; ++i) { + xHandles[i] = x[i].op().getUnsafeNativeHandle(); + xIndices[i] = x[i].index(); + } + if (dx != null && dx.length > 0) { + dxHandles = new long[dx.length]; + dxIndices = new int[dx.length]; + + for (int i = 0; i < dx.length; ++i) { + dxHandles[i] = dx[i].op().getUnsafeNativeHandle(); + dxIndices[i] = dx[i].index(); + } + } + // Gradient outputs are returned in two continuous arrays concatenated into one. The first holds the native handles + // of the gradient operations while the second holds the index of their output + // e.g. given xHandles = [x0Handle, x1Handle, ...] and xIndices = [x0Index, x1Index, ..], we obtain + // dy = [dy0Handle, dy1Handle, ..., dy0Index, dy1Index, ...] 
+ long[] dyHandlesAndIndices = + addGradients(ref.nativeHandle(), yHandles, yIndices, xHandles, xIndices, dxHandles, dxIndices); + int ndy = dyHandlesAndIndices.length >> 1; + if (ndy != dy.length) { + throw new IllegalStateException(String.valueOf(ndy) + " gradients were added to the graph when " + dy.length + + " were expected"); + } + for (int i = 0, j = ndy; i < ndy; ++i, ++j) { + Operation op = new Operation(this, dyHandlesAndIndices[i]); + dy[i] = new Output<>(op, (int) dyHandlesAndIndices[j]); + } + } + return dy; + } + + /** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code dy/dx_1, dy/dx_2...} + *

+ * This is a simplified version of {@link #addGradients(Output[], Output[], Output[])} where {@code y} is + * a single output and {@code dx} is null. + * + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @return the partial derivatives {@code dy} with the size of {@code x} + */ + public Output[] addGradients(Output y, Output[] x) { + return addGradients(new Output[]{y}, x, null); + } + private final Object nativeHandleLock = new Object(); private long nativeHandle; private int refcount = 0; @@ -254,6 +330,9 @@ public final class Graph implements AutoCloseable { private static native byte[] toGraphDef(long handle); + private static native long[] addGradients(long handle, long[] inputHandles, int[] inputIndices, + long[] outputHandles, int[] outputIndices, long[] gradInputHandles, int[] gradInputIndices); + static { TensorFlow.init(); } diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java new file mode 100644 index 0000000000..f4671c8af9 --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Gradients.java @@ -0,0 +1,153 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package org.tensorflow.op.core; + +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Operands; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * Adds operations to compute the partial derivatives of sum of {@code y}s w.r.t {@code x}s, + * i.e., {@code d(y_1 + y_2 + ...)/dx_1, d(y_1 + y_2 + ...)/dx_2...} + *

+ * If {@code Options.dx()} values are set, they are used as the initial symbolic partial derivatives of some loss + * function {@code L} w.r.t. {@code y}. {@code Options.dx()} must have the size of {@code y}. + *

+ * If {@code Options.dx()} is not set, the implementation will use dx of {@code OnesLike} for all + * shapes in {@code y}. + *

+ * The partial derivatives are returned in output {@code dy}, with the size of {@code x}. + *

+ * Example of usage: + *

{@code
+ * Gradients gradients = Gradients.create(scope, Arrays.asList(loss), Arrays.asList(w, b));
+ * 
+ * Constant alpha = ops.constant(1.0f, Float.class);
+ * ApplyGradientDescent.create(scope, w, alpha, gradients.dy(0));
+ * ApplyGradientDescent.create(scope, b, alpha, gradients.dy(1));
+ * }
+ */ +@Operator +public class Gradients implements Op, Iterable> { + + /** + * Optional attributes for {@link Gradients} + */ + public static class Options { + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. {@code y} + * @return this option builder + */ + public Options dx(Iterable> dx) { + this.dx = dx; + return this; + } + + private Iterable> dx; + + private Options() { + } + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * @param scope current graph scope + * @param y outputs of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + public static Gradients create(Scope scope, Iterable> y, Iterable> x, Options... options) { + Output[] dx = null; + if (options != null) { + for (Options opts : options) { + if (opts.dx != null) { + dx = Operands.asOutputs(opts.dx); + } + } + } + Output[] gradOutputs = scope.graph().addGradients(Operands.asOutputs(y), Operands.asOutputs(x), dx); + return new Gradients(Arrays.asList(gradOutputs)); + } + + /** + * Adds gradients computation ops to the graph according to scope. + * + * This is a simplified version of {@link #create(Scope, Iterable, Iterable, Options...)} where {@code y} is + * a single output. + * + * @param scope current graph scope + * @param y output of the function to derive + * @param x inputs of the function for which partial derivatives are computed + * @param options carries optional attributes values + * @return a new instance of {@code Gradients} + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static Gradients create(Scope scope, Operand y, Iterable> x, Options... options) { + return create(scope, (Iterable) Arrays.asList(y), x, options); + } + + /** + * @param dx partial derivatives of some loss function {@code L} w.r.t. 
{@code y} + * @return builder to add more options to this operation + */ + public Options dx(Iterable> dx) { + return new Options().dx(dx); + } + + @Override + @SuppressWarnings({"rawtypes", "unchecked"}) + public Iterator> iterator() { + return (Iterator) dy.iterator(); + } + + /** + * Partial derivatives of {@code y}s w.r.t. {@code x}s, with the size of {@code x} + */ + public List> dy() { + return dy; + } + + /** + * Returns a symbolic handle to one of the gradient operation outputs + *

+ * Warning: Does not check that the type of the tensor matches T. It is recommended to call + * this method with an explicit type parameter rather than letting it be inferred, e.g. {@code + * gradients.dy(0)} + * + * @param The expected element type of the tensors produced by this output. + * @param index The index of the output among the gradients added by this operation + */ + @SuppressWarnings("unchecked") + public Output dy(int index) { + return (Output) dy.get(index); + } + + private List> dy; + + private Gradients(List> dy) { + this.dy = dy; + } +} diff --git a/tensorflow/java/src/main/native/graph_jni.cc b/tensorflow/java/src/main/native/graph_jni.cc index 0fef155275..dac6a345e9 100644 --- a/tensorflow/java/src/main/native/graph_jni.cc +++ b/tensorflow/java/src/main/native/graph_jni.cc @@ -16,7 +16,9 @@ limitations under the License. #include "tensorflow/java/src/main/native/graph_jni.h" #include +#include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" namespace { @@ -130,3 +132,55 @@ Java_org_tensorflow_Graph_toGraphDef(JNIEnv* env, jclass clazz, jlong handle) { TF_DeleteBuffer(buf); return ret; } + +JNIEXPORT jlongArray JNICALL +Java_org_tensorflow_Graph_addGradients(JNIEnv* env, jclass clazz, jlong handle, + jlongArray y_handles, jintArray y_indices, + jlongArray x_handles, jintArray x_indices, + jlongArray dx_handles, jintArray dx_indices) { + + TF_Graph* g = requireHandle(env, handle); + if (g == nullptr) return nullptr; + + const jint ny = env->GetArrayLength(y_handles); + const jint nx = env->GetArrayLength(x_handles); + + std::unique_ptr y(new TF_Output[ny]); + std::unique_ptr x(new TF_Output[nx]); + std::unique_ptr dx(nullptr); + std::unique_ptr dy(new TF_Output[nx]); + + resolveOutputs(env, "y", y_handles, y_indices, y.get(), ny); + resolveOutputs(env, "x", x_handles, x_indices, x.get(), nx); + if (dx_handles != nullptr) { + if 
(env->GetArrayLength(dx_handles) != ny) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d dx handles", ny, + env->GetArrayLength(dx_handles)); + } + dx.reset(new TF_Output[ny]); + resolveOutputs(env, "dx", dx_handles, dx_indices, dx.get(), ny); + } + if (env->ExceptionCheck()) return nullptr; + + TF_Status* status = TF_NewStatus(); + TF_AddGradients(g, y.get(), ny, x.get(), nx, dx.get(), status, dy.get()); + + if (!throwExceptionIfNotOK(env, status)) { + TF_DeleteStatus(status); + return nullptr; + } + TF_DeleteStatus(status); + + // returned array contains both op handles and output indices, in pair + jlongArray dy_handles_and_indices = env->NewLongArray(nx << 1); + jlong* dy_elems = env->GetLongArrayElements(dy_handles_and_indices, nullptr); + for (int i = 0, j = nx; i < nx; ++i, ++j) { + TF_Output dy_output = dy.get()[i]; + dy_elems[i] = reinterpret_cast(dy_output.oper); + dy_elems[j] = static_cast(dy_output.index); + } + env->ReleaseLongArrayElements(dy_handles_and_indices, dy_elems, 0); + + return dy_handles_and_indices; +} diff --git a/tensorflow/java/src/main/native/graph_jni.h b/tensorflow/java/src/main/native/graph_jni.h index dd2e038332..4f87e8d5a7 100644 --- a/tensorflow/java/src/main/native/graph_jni.h +++ b/tensorflow/java/src/main/native/graph_jni.h @@ -73,6 +73,15 @@ JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Graph_toGraphDef(JNIEnv *, jclass, jlong); +/* + * Class: org_tensorflow_Graph + * Method: addGradients + * Signature: (J[J[I[J[I[J[I)[J + */ +JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Graph_addGradients(JNIEnv *, + jclass, jlong, jlongArray, jintArray, jlongArray, jintArray, jlongArray, + jintArray); + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/java/src/main/native/session_jni.cc b/tensorflow/java/src/main/native/session_jni.cc index 2cd542d3c9..cb54daf137 100644 --- a/tensorflow/java/src/main/native/session_jni.cc +++ b/tensorflow/java/src/main/native/session_jni.cc @@ -17,6 
+17,7 @@ limitations under the License. #include #include "tensorflow/c/c_api.h" +#include "tensorflow/java/src/main/native/utils_jni.h" #include "tensorflow/java/src/main/native/exception_jni.h" #include "tensorflow/java/src/main/native/session_jni.h" @@ -55,37 +56,6 @@ void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array, env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT); } -void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, - jintArray src_index, TF_Output* dst, jint n) { - if (env->ExceptionCheck()) return; - jint len = env->GetArrayLength(src_op); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operations", n, len, type); - return; - } - len = env->GetArrayLength(src_index); - if (len != n) { - throwException(env, kIllegalArgumentException, - "expected %d, got %d %s Operation output indices", n, len, - type); - return; - } - jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); - jint* indices = env->GetIntArrayElements(src_index, nullptr); - for (int i = 0; i < n; ++i) { - if (op_handles[i] == 0) { - throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, - i, n); - break; - } - dst[i] = TF_Output{reinterpret_cast(op_handles[i]), - static_cast(indices[i])}; - } - env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); - env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); -} - void TF_MaybeDeleteBuffer(TF_Buffer* buf) { if (buf == nullptr) return; TF_DeleteBuffer(buf); diff --git a/tensorflow/java/src/main/native/utils_jni.cc b/tensorflow/java/src/main/native/utils_jni.cc new file mode 100644 index 0000000000..069ac05a1c --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/java/src/main/native/utils_jni.h" + +#include "tensorflow/java/src/main/native/exception_jni.h" + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n) { + if (env->ExceptionCheck()) return; + jint len = env->GetArrayLength(src_op); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operations", n, len, type); + return; + } + len = env->GetArrayLength(src_index); + if (len != n) { + throwException(env, kIllegalArgumentException, + "expected %d, got %d %s Operation output indices", n, len, + type); + return; + } + jlong* op_handles = env->GetLongArrayElements(src_op, nullptr); + jint* indices = env->GetIntArrayElements(src_index, nullptr); + for (int i = 0; i < n; ++i) { + if (op_handles[i] == 0) { + throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type, + i, n); + break; + } + dst[i] = TF_Output{reinterpret_cast(op_handles[i]), + static_cast(indices[i])}; + } + env->ReleaseIntArrayElements(src_index, indices, JNI_ABORT); + env->ReleaseLongArrayElements(src_op, op_handles, JNI_ABORT); +} + + + + diff --git a/tensorflow/java/src/main/native/utils_jni.h b/tensorflow/java/src/main/native/utils_jni.h new file mode 100644 index 0000000000..352298e7de --- /dev/null +++ b/tensorflow/java/src/main/native/utils_jni.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_JAVA_UTILS_JNI_H_ +#define TENSORFLOW_JAVA_UTILS_JNI_H_ + +#include + +#include "tensorflow/c/c_api.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +void resolveOutputs(JNIEnv* env, const char* type, jlongArray src_op, + jintArray src_index, TF_Output* dst, jint n); + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif /* TENSORFLOW_JAVA_UTILS_JNI_H_ */ diff --git a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java index c540299bdc..c2e52c22c6 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/GraphTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertTrue; import java.util.HashSet; import java.util.Iterator; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -129,4 +130,106 @@ public class GraphTest { // expected exception. 
} } + + @Test + public void addGradientsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x1 = TestUtil.placeholder(g, "x1", Float.class); + Output x2 = TestUtil.placeholder(g, "x2", Float.class); + Output y0 = TestUtil.square(g, "y0", x1); + Output y1 = TestUtil.square(g, "y1", y0); + Output y2 = TestUtil.addN(g, y0, x2); + + Output[] grads0 = g.addGradients(y1, toArray(x1)); + assertNotNull(grads0); + assertEquals(1, grads0.length); + assertEquals(DataType.FLOAT, grads0[0].dataType()); + + Output[] grads1 = g.addGradients(y2, toArray(x1, x2)); + assertNotNull(grads1); + assertEquals(2, grads1.length); + assertEquals(DataType.FLOAT, grads1[0].dataType()); + assertEquals(DataType.FLOAT, grads1[1].dataType()); + + try (Tensor c1 = Tensors.create(3.0f); + Tensor c2 = Tensors.create(2.0f); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList<>( + s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run())) { + + assertEquals(3, outputs.size()); + assertEquals(108.0f, outputs.get(0).floatValue(), 0.0f); + assertEquals(6.0f, outputs.get(1).floatValue(), 0.0f); + assertEquals(1.0f, outputs.get(2).floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientSumsToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { + + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad = g.addGradients(toArray(y0, y1), toArray(x), null); + assertNotNull(grad); + assertEquals(1, grad.length); + assertEquals(DataType.FLOAT, grad[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad[0]) + .run() + .get(0)) { + + assertEquals(114.0f, output.floatValue(), 0.0f); + } + } + } + + @Test + public void addGradientsWithInitialValuesToGraph() { + try (Graph g = new Graph(); + Session s = new Session(g)) { 
+ + Output x = TestUtil.placeholder(g, "x", Float.class); + Output y0 = TestUtil.square(g, "y0", x); + Output y1 = TestUtil.square(g, "y1", y0); + + Output[] grad0 = g.addGradients(y1, toArray(y0)); + assertNotNull(grad0); + assertEquals(1, grad0.length); + assertEquals(DataType.FLOAT, grad0[0].dataType()); + + Output[] grad1 = g.addGradients(toArray(y0), toArray(x), toArray(grad0[0])); + assertNotNull(grad1); + assertEquals(1, grad1.length); + assertEquals(DataType.FLOAT, grad1[0].dataType()); + + try (Tensor c = Tensors.create(3.0f); + Tensor output = s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run() + .get(0)) { + + assertEquals(108.0f, output.floatValue(), 0.0f); + } + } + } + + private static Output[] toArray(Output... outputs) { + return outputs; + } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java index e8cc76c2a6..7d5980bcde 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/SessionTest.java @@ -20,8 +20,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.util.ArrayList; -import java.util.Collection; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,8 +34,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.transpose_A_times_X(g, new int[][] {{2}, {3}}); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -53,8 +51,8 @@ public class SessionTest { Output feed = 
g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (Tensor x = Tensors.create(new int[][] {{5}, {7}}); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -112,7 +110,7 @@ public class SessionTest { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList> outputs = new AutoCloseableList>(result.outputs); + TestUtil.AutoCloseableList> outputs = new TestUtil.AutoCloseableList>(result.outputs); assertEquals(1, outputs.size()); final int[][] expected = {{31}}; assertArrayEquals(expected, outputs.get(0).copyTo(new int[1][1])); @@ -135,8 +133,8 @@ public class SessionTest { Session s = new Session(g)) { TestUtil.constant(g, "c1", 2718); TestUtil.constant(g, "c2", 31415); - AutoCloseableList> outputs = - new AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); + TestUtil.AutoCloseableList> outputs = + new TestUtil.AutoCloseableList>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); assertEquals(31415, outputs.get(0).intValue()); assertEquals(2718, outputs.get(1).intValue()); @@ -164,28 +162,6 @@ public class SessionTest { Session s = new Session(g, singleThreadConfigProto())) {} } - private static final class AutoCloseableList extends ArrayList - implements AutoCloseable { - AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } - } - private static byte[] fullTraceRunOptions() { // Ideally this would use the generated Java 
sources for protocol buffers // and end up with something like the snippet below. However, generating diff --git a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java index c973b5a3d8..4e84886416 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TestUtil.java @@ -16,9 +16,34 @@ limitations under the License. package org.tensorflow; import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.Collection; /** Static utility functions. */ public class TestUtil { + + public static final class AutoCloseableList extends ArrayList + implements AutoCloseable { + AutoCloseableList(Collection c) { + super(c); + } + + @Override + public void close() { + Exception toThrow = null; + for (AutoCloseable c : this) { + try { + c.close(); + } catch (Exception e) { + toThrow = e; + } + } + if (toThrow != null) { + throw new RuntimeException(toThrow); + } + } + } + public static Output constant(Graph g, String name, Object value) { try (Tensor t = Tensor.create(value)) { return g.opBuilder("Const", name) @@ -36,7 +61,7 @@ public class TestUtil { .output(0); } - public static Output addN(Graph g, Output... inputs) { + public static Output addN(Graph g, Output... inputs) { return g.opBuilder("AddN", "AddN").addInputList(inputs).build().output(0); } @@ -58,6 +83,13 @@ public class TestUtil { .setAttr("num_split", numSplit) .build(); } + + public static Output square(Graph g, String name, Output value) { + return g.opBuilder("Square", name) + .addInput(value) + .build() + .output(0); + } public static void transpose_A_times_X(Graph g, int[][] a) { Output aa = constant(g, "A", a); -- cgit v1.2.3