diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java | 112 | ||||
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java | 173 |
2 files changed, 285 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java new file mode 100644 index 0000000000..59476fb43d --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/annotation/Operator.java @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.annotation; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation used by classes to make TensorFlow operations conveniently accessible via {@code + * org.tensorflow.op.Ops}. + * + * <p>An annotation processor (TODO: not yet implemented) builds the {@code Ops} class by + * aggregating all classes annotated as {@code @Operator}s. Each annotated class <b>must</b> have at + * least one public static factory method named {@code create} that accepts a {@link + * org.tensorflow.op.Scope} as its first argument. The processor then adds a convenience method in + * the {@code Ops} class. For example: + * + * <pre>{@code + * @Operator + * public final class MyOp implements Op { + * public static MyOp create(Scope scope, Operand operand) { + * ... + * } + * } + * }</pre> + * + * <p>results in a method in the {@code Ops} class + * + * <pre>{@code + * import org.tensorflow.op.Ops; + * ... + * Ops ops = new Ops(graph); + * ... + * ops.myOp(operand); + * // and has exactly the same effect as calling + * // MyOp.create(ops.getScope(), operand); + * }</pre> + */ +@Documented +@Target(ElementType.TYPE) +@Retention(RetentionPolicy.CLASS) +public @interface Operator { + /** + * Specify an optional group within the {@code Ops} class. + * + * <p>By default, an annotation processor will create convenience methods directly in the {@code + * Ops} class. An annotated operator may optionally choose to place the method within a group. For + * example: + * + * <pre>{@code + * @Operator(group="math") + * public final class Add extends PrimitiveOp implements Operand { + * ... + * } + * }</pre> + * + * <p>results in the {@code add} method placed within a {@code math} group within the {@code Ops} + * class. + * + * <pre>{@code + * ops.math().add(...); + * }</pre> + * + * <p>The group name must be a <a + * href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java + * identifier</a>. + */ + String group() default ""; + + /** + * Name for the wrapper method used in the {@code Ops} class. + * + * <p>By default, a processor derives the method name in the {@code Ops} class from the class name + * of the operator. This attribute allow you to provide a different name instead. For example: + * + * <pre>{@code + * @Operator(name="myOperation") + * public final class MyRealOperation implements Operand { + * public static MyRealOperation create(...) + * } + * }</pre> + * + * <p>results in this method added to the {@code Ops} class + * + * <pre>{@code + * ops.myOperation(...); + * // and is the same as calling + * // MyRealOperation.create(...) + * }</pre> + * + * <p>The name must be a <a + * href="https://docs.oracle.com/javase/specs/jls/se7/html/jls-3.html#jls-3.8">valid Java + * identifier</a>. + */ + String name() default ""; +} diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java new file mode 100644 index 0000000000..cd7931d3bb --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.op.PrimitiveOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** An operator producing a constant value. */ +@Operator +public final class Constant extends PrimitiveOp implements Operand { + /** + * Create a constant from a Java object. + * + * <p>The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + * <pre>{@code + * Constant.create(scope, 7); // returns a constant scalar tensor 7 + * }</pre> + * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static Constant create(Scope scope, Object object) { + try (Tensor value = Tensor.create(object)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT32} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, IntBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#FLOAT} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, FloatBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#DOUBLE} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, DoubleBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT64} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, LongBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a constant with data from the given buffer. + * + * <p>Creates a Constant with the provided shape of any type where the constant data has been + * encoded into {@code data} as per the specification of the TensorFlow <a + * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>. + * + * @param scope is a scope used to add the underlying operation. + * @param dataType the tensor datatype. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the + * buffer + */ + public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) { + try (Tensor value = Tensor.create(dataType, shape, data)) { + return createWithTensor(scope, value); + } + } + + private static Constant createWithTensor(Scope scope, Tensor value) { + return new Constant( + scope + .graph() + .opBuilder("Const", scope.makeOpName("Const")) + .setAttr("value", value) + .setAttr("dtype", value.dataType()) + .build()); + } + + @Override + public Output asOutput() { + return output; + } + + private Constant(Operation operation) { + super(operation); + output = operation.output(0); + } + + private final Output output; +} |