diff options
Diffstat (limited to 'tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java')
-rw-r--r-- | tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java | 173 |
1 files changed, 173 insertions, 0 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java new file mode 100644 index 0000000000..cd7931d3bb --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -0,0 +1,173 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.Output; +import org.tensorflow.Tensor; +import org.tensorflow.op.PrimitiveOp; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** An operator producing a constant value. */ +@Operator +public final class Constant extends PrimitiveOp implements Operand { + /** + * Create a constant from a Java object. + * + * <p>The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + * <pre>{@code + * Constant.create(scope, 7); // returns a constant scalar tensor 7 + * }</pre> + * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static Constant create(Scope scope, Object object) { + try (Tensor value = Tensor.create(object)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT32} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, IntBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#FLOAT} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, FloatBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#DOUBLE} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, DoubleBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a {@link DataType#INT64} constant with data from the given buffer. + * + * <p>Creates a constant with the given shape by copying elements from the buffer (starting from + * its current position) into the tensor. For example, if {@code shape = {2,3} } (which represents + * a 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this + * method. + * + * @param scope is a scope used to add the underlying operation. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer + */ + public static Constant create(Scope scope, long[] shape, LongBuffer data) { + try (Tensor value = Tensor.create(shape, data)) { + return createWithTensor(scope, value); + } + } + + /** + * Create a constant with data from the given buffer. + * + * <p>Creates a Constant with the provided shape of any type where the constant data has been + * encoded into {@code data} as per the specification of the TensorFlow <a + * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API</a>. + * + * @param scope is a scope used to add the underlying operation. + * @param dataType the tensor datatype. + * @param shape the tensor shape. + * @param data a buffer containing the tensor data. + * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the + * buffer + */ + public static Constant create(Scope scope, DataType dataType, long[] shape, ByteBuffer data) { + try (Tensor value = Tensor.create(dataType, shape, data)) { + return createWithTensor(scope, value); + } + } + + private static Constant createWithTensor(Scope scope, Tensor value) { + return new Constant( + scope + .graph() + .opBuilder("Const", scope.makeOpName("Const")) + .setAttr("value", value) + .setAttr("dtype", value.dataType()) + .build()); + } + + @Override + public Output asOutput() { + return output; + } + + private Constant(Operation operation) { + super(operation); + output = operation.output(0); + } + + private final Output output; +} |