aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java')
-rw-r--r--tensorflow/java/BUILD15
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java112
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java173
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java131
4 files changed, 430 insertions, 1 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD
index 9fb4821cb1..64b3767735 100644
--- a/tensorflow/java/BUILD
+++ b/tensorflow/java/BUILD
@@ -34,7 +34,7 @@ filegroup(
filegroup(
name = "java_op_sources",
- srcs = glob(["src/main/java/org/tensorflow/op/*.java"]),
+ srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]),
visibility = [
"//tensorflow/java:__pkg__",
],
@@ -191,6 +191,19 @@ java_test(
],
)
+java_test(
+ name = "ConstantTest",
+ size = "small",
+ srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"],
+ javacopts = JAVACOPTS,
+ test_class = "org.tensorflow.op.core.ConstantTest",
+ deps = [
+ ":tensorflow",
+ ":testutil",
+ "@junit",
+ ],
+)
+
filegroup(
name = "libtensorflow_jni",
srcs = select({
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java
new file mode 100644
index 0000000000..59476fb43d
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java
@@ -0,0 +1,112 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.annotation;
+
+import java.lang.annotation.Documented;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
+/**
+ * Annotation used by classes to make TensorFlow operations conveniently accessible via {@code
+ * org.tensorflow.op.Ops}.
+ *
+ * <p>An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by
+ * aggregating all classes annotated as {@code @Operator}s. Each annotated class <b>must</b> have at
+ * least one public static factory method named {@code create} that accepts a {@link
+ * org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in
+ * the {@code Ops} class. For example:
+ *
+ * <pre>{@code
+ * @Operator
+ * public final class MyOp implements Op {
+ * public static MyOp create(Scope scope, Operand operand) {
+ * ...
+ * }
+ * }
+ * }</pre>
+ *
+ * <p>results in a method in the {@code Ops} class
+ *
+ * <pre>{@code
+ * import org.tensorflow.op.Ops;
+ * ...
+ * Ops ops = new Ops(graph);
+ * ...
+ * ops.myOp(operand);
+ * // and has exactly the same effect as calling
+ * // MyOp.create(ops.getScope(), operand);
+ * }</pre>
+ */
+@Documented
+@Target(ElementType.TYPE)
+@Retention(RetentionPolicy.CLASS)
+public @interface Operator {
+ /**
+ * Specify an optional group within the {@code Ops} class.
+ *
+ * <p>By default, an annotation processor will create convenience methods directly in the {@code
+ * Ops} class. An annotated operator may optionally choose to place the method within a group. For
+ * example:
+ *
+ * <pre>{@code
+ * @Operator(group="math")
+ * public final class Add extends PrimitiveOp implements Operand {
+ * ...
+ * }
+ * }</pre>
+ *
+ * <p>results in the {@code add} method placed within a {@code math} group within the {@code Ops}
+ * class.
+ *
+ * <pre>{@code
+ * ops.math().add(...);
+ * }</pre>
+ *
+ * <p>The group name must be a <a
+ * href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
+ * identifier</a>.
+ */
+ String group() default "";
+
+ /**
+ * Name for the wrapper method used in the {@code Ops} class.
+ *
+ * <p>By default, a processor derives the method name in the {@code Ops} class from the class name
+ * of the operator. This attribute allow you to provide a different name instead. For example:
+ *
+ * <pre>{@code
+ * @Operator(name="myOperation")
+ * public final class MyRealOperation implements Operand {
+ * public static MyRealOperation create(...)
+ * }
+ * }</pre>
+ *
+ * <p>results in this method added to the {@code Ops} class
+ *
+ * <pre>{@code
+ * ops.myOperation(...);
+ * // and is the same as calling
+ * // MyRealOperation.create(...)
+ * }</pre>
+ *
+ * <p>The name must be a <a
+ * href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java
+ * identifier</a>.
+ */
+ String name() default "";
+}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
new file mode 100644
index 0000000000..cd7931d3bb
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java
@@ -0,0 +1,173 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import org.tensorflow.DataType;
+import org.tensorflow.Operand;
+import org.tensorflow.Operation;
+import org.tensorflow.Output;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.PrimitiveOp;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Operator;
+
+/** An operator producing a constant value. */
+@Operator
+public final class Constant extends PrimitiveOp implements Operand {
+ /**
+ * Create a constant from a Java object.
+ *
+ * <p>The argument {@code object} is first converted into a Tensor using {@link
+ * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be
+ * provided. For example:
+ *
+ * <pre>{@code
+ * Constant.create(scope, 7); // returns a constant scalar tensor 7
+ * }</pre>
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param object a Java object representing the constant.
+ * @see org.tensorflow.Tensor#create(Object) Tensor.create
+ */
+ public static Constant create(Scope scope, Object object) {
+ try (Tensor value = Tensor.create(object)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Create a {@link DataType#INT32} constant with data from the given buffer.
+ *
+ * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
+ * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
+ * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
+ * method.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
+ */
+ public static Constant create(Scope scope, long[] shape, IntBuffer data) {
+ try (Tensor value = Tensor.create(shape, data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Create a {@link DataType#FLOAT} constant with data from the given buffer.
+ *
+ * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
+ * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
+ * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
+ * method.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
+ */
+ public static Constant create(Scope scope, long[] shape, FloatBuffer data) {
+ try (Tensor value = Tensor.create(shape, data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Create a {@link DataType#DOUBLE} constant with data from the given buffer.
+ *
+ * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
+ * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
+ * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
+ * method.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
+ */
+ public static Constant create(Scope scope, long[] shape, DoubleBuffer data) {
+ try (Tensor value = Tensor.create(shape, data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Create a {@link DataType#INT64} constant with data from the given buffer.
+ *
+ * <p>Creates a constant with the given shape by copying elements from the buffer (starting from
+ * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents
+ * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
+ * method.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
+ */
+ public static Constant create(Scope scope, long[] shape, LongBuffer data) {
+ try (Tensor value = Tensor.create(shape, data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ /**
+ * Create a constant with data from the given buffer.
+ *
+ * <p>Creates a Constant with the provided shape of any type where the constant data has been
+ * encoded into {@code data} as per the specification of the TensorFlow <a
+ * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>.
+ *
+ * @param scope is a scope used to add the underlying operation.
+ * @param dataType the tensor datatype.
+ * @param shape the tensor shape.
+ * @param data a buffer containing the tensor data.
+ * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
+ * buffer
+ */
+ public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) {
+ try (Tensor value = Tensor.create(dataType, shape, data)) {
+ return createWithTensor(scope, value);
+ }
+ }
+
+ private static Constant createWithTensor(Scope scope, Tensor value) {
+ return new Constant(
+ scope
+ .graph()
+ .opBuilder("Const", scope.makeOpName("Const"))
+ .setAttr("value", value)
+ .setAttr("dtype", value.dataType())
+ .build());
+ }
+
+ @Override
+ public Output asOutput() {
+ return output;
+ }
+
+ private Constant(Operation operation) {
+ super(operation);
+ output = operation.output(0);
+ }
+
+ private final Output output;
+}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
new file mode 100644
index 0000000000..ec23792485
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.DataType;
+import org.tensorflow.Graph;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class ConstantTest {
+ private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createIntBuffer() {
+ int[] ints = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ int[] actual = new int[ints.length];
+ assertArrayEquals(ints, result.copyTo(actual));
+ }
+ }
+
+ @Test
+ public void createFloatBuffer() {
+ float[] floats = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ float[] actual = new float[floats.length];
+ assertArrayEquals(floats, result.copyTo(actual), EPSILON);
+ }
+ }
+
+ @Test
+ public void createDoubleBuffer() {
+ double[] doubles = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
+ }
+ }
+
+ @Test
+ public void createLongBuffer() {
+ long[] longs = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ long[] actual = new long[longs.length];
+ assertArrayEquals(longs, result.copyTo(actual));
+ }
+ }
+
+ @Test
+ public void createStringBuffer() throws IOException {
+
+ byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
+ long[] shape = {};
+
+ // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
+ // followed by a varint encoded size, followed by the data.
+ ByteArrayOutputStream baout = new ByteArrayOutputStream();
+ DataOutputStream out = new DataOutputStream(baout);
+ // Offset in array.
+ out.writeLong(0L);
+ // Varint encoded length of buffer.
+ // For any number < 0x80, the varint encoding is simply the number itself.
+ // https://developers.google.com/protocol-buffers/docs/encoding#varints
+ assertTrue(data.length < 0x80);
+ out.write(data.length);
+ out.write(data);
+ out.close();
+ byte[] content = baout.toByteArray();
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ assertArrayEquals(data, result.bytesValue());
+ }
+ }
+}