diff options
Diffstat (limited to 'tensorflow/java')
4 files changed, 430 insertions, 1 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 9fb4821cb1..64b3767735 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -34,7 +34,7 @@ filegroup( filegroup( name = "java_op_sources", - srcs = glob(["src/main/java/org/tensorflow/op/*.java"]), + srcs = glob(["src/main/java/org/tensorflow/op/**/*.java"]), visibility = [ "//tensorflow/java:__pkg__", ], @@ -191,6 +191,19 @@ java_test( ], ) +java_test( + name = "ConstantTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/ConstantTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.ConstantTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "libtensorflow_jni", srcs = select({ diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java new file mode 100644 index 0000000000..59476fb43d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation used by classes to make TensorFlow operations conveniently accessible via {@code + * org.tensorflow.op.Ops}. + * + * <p>An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by + * aggregating all classes annotated as {@code @Operator}s. Each annotated class <b>must</b> have at + * least one public static factory method named {@code create} that accepts a {@link + * org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in + * the {@code Ops} class. For example: + * + * <pre>{@code + * @Operator + * public final class MyOp implements Op { + * public static MyOp create(Scope scope, Operand operand) { + * ... + * } + * } + * }</pre> + * + * <p>results in a method in the {@code Ops} class + * + * <pre>{@code + * import org.tensorflow.op.Ops; + * ... + * Ops ops = new Ops(graph); + * ... + * ops.myOp(operand); + * // and has exactly the same effect as calling + * // MyOp.create(ops.getScope(), operand); + * }</pre> + */ +@Documented +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.CLASS) +public @interface Operator { + /** + * Specify an optional group within the {@code Ops} class. + * + * <p>By default, an annotation processor will create convenience methods directly in the {@code + * Ops} class. An annotated operator may optionally choose to place the method within a group. For + * example: + * + * <pre>{@code + * @Operator(group="math") + * public final class Add extends PrimitiveOp implements Operand { + * ... + * } + * }</pre> + * + * <p>results in the {@code add} method placed within a {@code math} group within the {@code Ops} + * class. + * + * <pre>{@code + * ops.math().add(...); + * }</pre> + * + * <p>The group name must be a <a + * href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java + * identifier</a>. + */ + String group() default ""; + + /** + * Name for the wrapper method used in the {@code Ops} class. + * + * <p>By default, a processor derives the method name in the {@code Ops} class from the class name + * of the operator. This attribute allow you to provide a different name instead. For example: + * + * <pre>{@code + * @Operator(name="myOperation") + * public final class MyRealOperation implements Operand { + * public static MyRealOperation create(...) + * } + * }</pre> + * + * <p>results in this method added to the {@code Ops} class + * + * <pre>{@code + * ops.myOperation(...); + * // and is the same as calling + * // MyRealOperation.create(...) + * }</pre> + * + * <p>The name must be a <a + * href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java + * identifier</a>. + */ + String name() default ""; +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java new file mode 100644 index 0000000000..cd7931d3bb --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.op.PrimitiveOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** An operator producing a constant value. */ +@Operator +public final class Constant extends PrimitiveOp implements Operand { + /** + * Create a constant from a Java object. + * + * <p>The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + * <pre>{@code + * Constant.create(scope, 7); // returns a constant scalar tensor 7 + * }</pre> + * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static Constant create(Scope scope, Object object) { + try (Tensor value = Tensor.create(object)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT32} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, IntBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#FLOAT} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, FloatBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#DOUBLE} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, DoubleBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT64} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, LongBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a constant with data from the given buffer. + * + * <p>Creates a Constant with the provided shape of any type where the constant data has been + * encoded into {@code data} as per the specification of the TensorFlow <a + * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>. + * + * @param scope is a scope used to add the underlying operation. + * @param dataType the tensor datatype. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the + * buffer + */ + public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) { + try (Tensor value = Tensor.create(dataType, shape, data)) { + return createWithTensor(scope, value); + } + } + + private static Constant createWithTensor(Scope scope, Tensor value) { + return new Constant( + scope + .graph() + .opBuilder("Const", scope.makeOpName("Const")) + .setAttr("value", value) + .setAttr("dtype", value.dataType()) + .build()); + } + + @Override + public Output asOutput() { + return output; + } + + private Constant(Operation operation) { + super(operation); + output = operation.output(0); + } + + private final Output output; +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java new file mode 100644 index 0000000000..ec23792485 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -0,0 +1,131 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertTrue; + +import java.io.ByteArrayOutputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.DataType; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.op.Scope; + +@RunWith(JUnit4.class) +public class ConstantTest { + private static final float EPSILON = 1e-7f; + + @Test + public void createIntBuffer() { + int[] ints = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + int[] actual = new int[ints.length]; + assertArrayEquals(ints, result.copyTo(actual)); + } + } + + @Test + public void createFloatBuffer() { + float[] floats = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + float[] actual = new float[floats.length]; + assertArrayEquals(floats, result.copyTo(actual), EPSILON); + } + } + + @Test + public void createDoubleBuffer() { + double[] doubles = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + double[] actual = new double[doubles.length]; + assertArrayEquals(doubles, result.copyTo(actual), EPSILON); + } + } + + @Test + public void createLongBuffer() { + long[] longs = {1, 2, 3, 4}; + long[] shape = {4}; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + long[] actual = new long[longs.length]; + assertArrayEquals(longs, result.copyTo(actual)); + } + } + + @Test + public void createStringBuffer() throws IOException { + + byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; + long[] shape = {}; + + // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer, + // followed by a varint encoded size, followed by the data. + ByteArrayOutputStream baout = new ByteArrayOutputStream(); + DataOutputStream out = new DataOutputStream(baout); + // Offset in array. + out.writeLong(0L); + // Varint encoded length of buffer. + // For any number < 0x80, the varint encoding is simply the number itself. + // https://developers.google.com/protocol-buffers/docs/encoding#varints + assertTrue(data.length < 0x80); + out.write(data.length); + out.write(data); + out.close(); + byte[] content = baout.toByteArray(); + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content)); + Tensor result = sess.runner().fetch(op.asOutput()).run().get(0); + assertArrayEquals(data, result.bytesValue()); + } + } +} |