diff options
author | 2018-08-03 14:05:53 -0700 | |
---|---|---|
committer | 2018-08-03 14:06:00 -0700 | |
commit | 9c5b3ce84d1e795b60e5f86a3c43925734862414 (patch) | |
tree | 1e67b99d3968c0910570225f2a6712b792ab883c /tensorflow | |
parent | d4830a56b4d2dbbe3c54ad0090be645d3b314f45 (diff) | |
parent | e3bc2b0e764cacafb1156bc84299790fd9e60b89 (diff) |
Merge pull request #21092 from karllessard:java-constants
PiperOrigin-RevId: 207319780
Diffstat (limited to 'tensorflow')
8 files changed, 880 insertions, 51 deletions
diff --git a/tensorflow/java/BUILD b/tensorflow/java/BUILD index 7ceba3903d..87e6107c2d 100644 --- a/tensorflow/java/BUILD +++ b/tensorflow/java/BUILD @@ -305,6 +305,19 @@ tf_java_test( ], ) +tf_java_test( + name = "ZerosTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/op/core/ZerosTest.java"], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.op.core.ZerosTest", + deps = [ + ":tensorflow", + ":testutil", + "@junit", + ], +) + filegroup( name = "processor_test_resources", srcs = glob([ diff --git a/tensorflow/java/src/main/java/org/tensorflow/DataType.java b/tensorflow/java/src/main/java/org/tensorflow/DataType.java index 7b92be6d38..516655040b 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/DataType.java +++ b/tensorflow/java/src/main/java/org/tensorflow/DataType.java @@ -17,40 +17,54 @@ package org.tensorflow; import java.util.HashMap; import java.util.Map; + import org.tensorflow.types.UInt8; /** Represents the type of elements in a {@link Tensor} as an enum. */ public enum DataType { /** 32-bit single precision floating point. */ - FLOAT(1), + FLOAT(1, 4), /** 64-bit double precision floating point. */ - DOUBLE(2), + DOUBLE(2, 8), /** 32-bit signed integer. */ - INT32(3), + INT32(3, 4), /** 8-bit unsigned integer. */ - UINT8(4), + UINT8(4, 1), /** * A sequence of bytes. * * <p>TensorFlow uses the STRING type for an arbitrary sequence of bytes. */ - STRING(7), + STRING(7, -1), /** 64-bit signed integer. */ - INT64(9), + INT64(9, 8), /** Boolean. */ - BOOL(10); + BOOL(10, 1); private final int value; + + private final int byteSize; - // The integer value must match the corresponding TF_* value in the TensorFlow C API. - DataType(int value) { + /** + * @param value must match the corresponding TF_* value in the TensorFlow C API. + * @param byteSize size of an element of this type, in bytes, -1 if unknown + */ + DataType(int value, int byteSize) { this.value = value; + this.byteSize = byteSize; + } + + /** + * Returns the size of an element of this type, in bytes, or -1 if element size is variable. + */ + public int byteSize() { + return byteSize; } /** Corresponding value of the TF_DataType enum in the TensorFlow C API. */ diff --git a/tensorflow/java/src/main/java/org/tensorflow/Session.java b/tensorflow/java/src/main/java/org/tensorflow/Session.java index 73324f23e6..a660d25f98 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Session.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Session.java @@ -185,11 +185,20 @@ public final class Session implements AutoCloseable { return this; } - /** Makes {@link #run()} return the Tensor referred to by {@code output}. */ + /** + * Makes {@link #run()} return the Tensor referred to by {@code output}. + */ public Runner fetch(Output<?> output) { outputs.add(output); return this; } + + /** + * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + */ + public Runner fetch(Operand<?> operand) { + return fetch(operand.asOutput()); + } /** * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor}s. @@ -209,6 +218,13 @@ public final class Session implements AutoCloseable { targets.add(operation); return this; } + + /** + * Make {@link #run()} execute {@code operand}, but not return any evaluated {@link Tensor}s. + */ + public Runner addTarget(Operand<?> operand) { + return addTarget(operand.asOutput().op()); + } /** * (Experimental method): set options (typically for debugging) for this run. diff --git a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java index 24a3775db6..8987253768 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Tensor.java @@ -595,20 +595,11 @@ public final class Tensor<T> implements AutoCloseable { } private static int elemByteSize(DataType dataType) { - switch (dataType) { - case FLOAT: - case INT32: - return 4; - case DOUBLE: - case INT64: - return 8; - case BOOL: - case UINT8: - return 1; - case STRING: + int size = dataType.byteSize(); + if (size < 0) { throw new IllegalArgumentException("STRING tensors do not have a fixed element size"); } - throw new IllegalArgumentException("DataType " + dataType + " is not supported yet"); + return size; } private static void throwExceptionIfNotByteOfByteArrays(Object array) { diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java index de4049f66b..00b6726be3 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Constant.java @@ -15,11 +15,15 @@ limitations under the License. package org.tensorflow.op.core; +import static java.nio.charset.StandardCharsets.UTF_8; + import java.nio.ByteBuffer; import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; +import java.nio.charset.Charset; + import org.tensorflow.DataType; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -32,25 +36,82 @@ import org.tensorflow.op.annotation.Operator; /** An operator producing a constant value. */ @Operator public final class Constant<T> extends PrimitiveOp implements Operand<T> { + /** - * Create a constant from a Java object. + * Creates a constant containing a single {@code int} element. * - * <p>The argument {@code object} is first converted into a Tensor using {@link - * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be - * provided. For example: + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return an integer constant + */ + public static Constant<Integer> create(Scope scope, int data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-1 constant of {@code int} elements. * - * <pre>{@code - * Constant.create(scope, 7); // returns a constant scalar tensor 7 - * }</pre> + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param object a Java object representing the constant. - * @see org.tensorflow.Tensor#create(Object) Tensor.create + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. */ - public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) { - try (Tensor<T> value = Tensor.create(object, type)) { - return createWithTensor(scope, value); - } + public static Constant<Integer> create(Scope scope, int[][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-3 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-4 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-5 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][][][] data) { + return create(scope, data, Integer.class); + } + + /** + * Creates a rank-6 constant of {@code int} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Integer> create(Scope scope, int[][][][][][] data) { + return create(scope, data, Integer.class); } /** @@ -64,6 +125,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return an integer constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Integer> create(Scope scope, long[] shape, IntBuffer data) { @@ -73,6 +135,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code float} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a float constant + */ + public static Constant<Float> create(Scope scope, float data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-1 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-2 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-3 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-4 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-5 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][][][] data) { + return create(scope, data, Float.class); + } + + /** + * Creates a rank-6 constant of {@code float} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Float> create(Scope scope, float[][][][][][] data) { + return create(scope, data, Float.class); + } + + /** * Create a {@link DataType#FLOAT} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -83,6 +222,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a float constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Float> create(Scope scope, long[] shape, FloatBuffer data) { @@ -92,6 +232,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code double} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a double constant + */ + public static Constant<Double> create(Scope scope, double data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-1 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-2 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-3 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-4 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-5 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][][][] data) { + return create(scope, data, Double.class); + } + + /** + * Creates a rank-6 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Double> create(Scope scope, double[][][][][][] data) { + return create(scope, data, Double.class); + } + + /** * Create a {@link DataType#DOUBLE} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -102,6 +319,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a double constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Double> create(Scope scope, long[] shape, DoubleBuffer data) { @@ -111,6 +329,83 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code long} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a long constant + */ + public static Constant<Long> create(Scope scope, long data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-1 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-2 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-3 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-4 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-5 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][][][] data) { + return create(scope, data, Long.class); + } + + /** + * Creates a rank-6 constant of {@code long} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Long> create(Scope scope, long[][][][][][] data) { + return create(scope, data, Long.class); + } + + /** * Create a {@link DataType#INT64} constant with data from the given buffer. * * <p>Creates a constant with the given shape by copying elements from the buffer (starting from @@ -121,6 +416,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ public static Constant<Long> create(Scope scope, long[] shape, LongBuffer data) { @@ -130,6 +426,174 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } /** + * Creates a constant containing a single {@code boolean} element. + * + * @param scope is a scope used to add the underlying operation. + * @param data The value to put into the new constant. + * @return a boolean constant + */ + public static Constant<Boolean> create(Scope scope, boolean data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-1 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-2 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-3 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-4 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-5 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a rank-6 constant of {@code boolean} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the + * new constant will match those of the array. + */ + public static Constant<Boolean> create(Scope scope, boolean[][][][][][] data) { + return create(scope, data, Boolean.class); + } + + /** + * Creates a {@code String} constant using the default, UTF-8 encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param data The string to put into the new constant. + * @return a string constant + */ + public static Constant<String> create(Scope scope, String data) { + return create(scope, data, UTF_8); + } + + /** + * Creates a {@code String} constant using a specified encoding. + * + * @param scope is a scope used to add the underlying operation. + * @param charset The encoding from String to bytes. + * @param data The string to put into the new constant. + * @return a string constant + */ + public static Constant<String> create(Scope scope, String data, Charset charset) { + try (Tensor<String> value = Tensor.create(data.getBytes(charset), String.class)) { + return createWithTensor(scope, Tensor.create(data.getBytes(charset), String.class)); + } + } + + /** + * Creates a constant containing a single {@code String} element, represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-1 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-2 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-3 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-4 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][][][] data) { + return create(scope, data, String.class); + } + + /** + * Creates a rank-5 constant of {@code String} elements, each represented as an array of {@code byte}s. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. String elements are + * sequences of bytes from the last array dimension. + */ + public static Constant<String> create(Scope scope, byte[][][][][][] data) { + return create(scope, data, String.class); + } + + /** * Create a constant with data from the given buffer. * * <p>Creates a Constant with the provided shape of any type where the constant data has been @@ -141,6 +605,7 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { * @param type the tensor datatype. * @param shape the tensor shape. * @param data a buffer containing the tensor data. + * @return a constant of type `type` * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the * buffer */ @@ -150,6 +615,28 @@ public final class Constant<T> extends PrimitiveOp implements Operand<T> { } } + /** + * Create a constant from a Java object. + * + * <p>The argument {@code object} is first converted into a Tensor using {@link + * org.tensorflow.Tensor#create(Object)}, so only Objects supported by this method must be + * provided. For example: + * + * <pre>{@code + * Constant.create(scope, new int[]{{1, 2}, {3, 4}}, Integer.class); // returns a 2x2 integer matrix + * }</pre> + * + * @param scope is a scope used to add the underlying operation. + * @param object a Java object representing the constant. + * @return a constant of type `type` + * @see org.tensorflow.Tensor#create(Object) Tensor.create + */ + public static <T> Constant<T> create(Scope scope, Object object, Class<T> type) { + try (Tensor<T> value = Tensor.create(object, type)) { + return createWithTensor(scope, value); + } + } + private static <T> Constant<T> createWithTensor(Scope scope, Tensor<T> value) { return new Constant<T>( scope diff --git a/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java new file mode 100644 index 0000000000..b7c6beb9bc --- /dev/null +++ b/tensorflow/java/src/main/java/org/tensorflow/op/core/Zeros.java @@ -0,0 +1,68 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import java.nio.ByteBuffer; + +import org.tensorflow.DataType; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Operator; + +/** + * An operator creating a constant initialized with zeros of the shape given by `dims`. + * + * <p>For example, the following expression + * <pre>{@code ops.zeros(ops.constant(new long[]{2, 2}), Float.class)</pre> + * is the equivalent of + * <pre>{@code ops.fill(ops.constant(new long[]{2, 2}), ops.constant(0.0f))</pre> + * + * @param <T> constant type + */ +@Operator +public class Zeros<T> implements Op, Operand<T> { + + /** + * Creates a zeroed tensor given its type and shape. + * + * @param scope is a scope used to add the underlying operation + * @param dims a 1-D operand that represents the shape of the output tensor + * @param type the output tensor datatype + * @return a constant tensor initialized with zeros + * @throws IllegalArgumentException if the tensor type or shape cannot be initialized with zeros. + */ + public static <T, U extends Number> Zeros<T> create(Scope scope, Operand<U> dims, Class<T> type) { + Scope childScope = scope.withSubScope("Zeros"); // If scope had an op name set, it will prevail on "Zeros" + int zeroSize = DataType.fromClass(type).byteSize(); + if (zeroSize < 0) { + throw new IllegalArgumentException(type.getSimpleName() + " tensors cannot be initialized with zeros"); + } + Constant<T> zero = Constant.create(childScope.withName("Zero"), type, new long[]{}, ByteBuffer.allocate(zeroSize)); + return new Zeros<T>(Fill.create(childScope, dims, zero)); + } + + @Override + public Output<T> asOutput() { + return fill.asOutput(); + } + + private final Fill<T> fill; + + private Zeros(Fill<T> fill) { + this.fill = fill; + } +} diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java index ca54214e06..7d3b26de8d 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow.op.core; import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import java.io.ByteArrayOutputStream; @@ -26,6 +27,7 @@ import java.nio.DoubleBuffer; import java.nio.FloatBuffer; import java.nio.IntBuffer; import java.nio.LongBuffer; + import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,6 +39,20 @@ import org.tensorflow.op.Scope; @RunWith(JUnit4.class) public class ConstantTest { private static final float EPSILON = 1e-7f; + + @Test + public void createInt() { + int value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Integer> op = Constant.create(scope, value); + try (Tensor<Integer> result = sess.runner().fetch(op).run().get(0).expect(Integer.class)) { + assertEquals(value, result.intValue()); + } + } + } @Test public void createIntBuffer() { @@ -47,10 +63,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Integer> op = Constant.create(scope, shape, IntBuffer.wrap(ints)); - Tensor<Integer> result = sess.runner().fetch(op.asOutput()) - .run().get(0).expect(Integer.class); - int[] actual = new int[ints.length]; - assertArrayEquals(ints, result.copyTo(actual)); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + int[] actual = new int[ints.length]; + assertArrayEquals(ints, result.expect(Integer.class).copyTo(actual)); + } + } + } + + @Test + public void createFloat() { + float value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Float> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Float.class).floatValue(), 0.0f); + } } } @@ -63,9 +93,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Float> op = Constant.create(scope, shape, FloatBuffer.wrap(floats)); - Tensor<Float> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Float.class); - float[] actual = new float[floats.length]; - assertArrayEquals(floats, result.copyTo(actual), EPSILON); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + float[] actual = new float[floats.length]; + assertArrayEquals(floats, result.expect(Float.class).copyTo(actual), EPSILON); + } + } + } + + @Test + public void createDouble() { + double value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Double> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Double.class).doubleValue(), 0.0); + } } } @@ -78,9 +123,24 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Double> op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles)); - Tensor<Double> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Double.class); - double[] actual = new double[doubles.length]; - assertArrayEquals(doubles, result.copyTo(actual), EPSILON); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + double[] actual = new double[doubles.length]; + assertArrayEquals(doubles, result.expect(Double.class).copyTo(actual), EPSILON); + } + } + } + + @Test + public void createLong() { + long value = 1; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Long> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Long.class).longValue()); + } } } @@ -93,15 +153,29 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<Long> op = Constant.create(scope, shape, LongBuffer.wrap(longs)); - Tensor<Long> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(Long.class); - long[] actual = new long[longs.length]; - assertArrayEquals(longs, result.copyTo(actual)); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + long[] actual = new long[longs.length]; + assertArrayEquals(longs, result.expect(Long.class).copyTo(actual)); + } } } @Test - public void createStringBuffer() throws IOException { + public void createBoolean() { + boolean value = true; + + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + Constant<Boolean> op = Constant.create(scope, value); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertEquals(value, result.expect(Boolean.class).booleanValue()); + } + } + } + @Test + public void createStringBuffer() throws IOException { byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4}; long[] shape = {}; @@ -124,8 +198,9 @@ public class ConstantTest { Session sess = new Session(g)) { Scope scope = new Scope(g); Constant<String> op = Constant.create(scope, String.class, shape, ByteBuffer.wrap(content)); - Tensor<String> result = sess.runner().fetch(op.asOutput()).run().get(0).expect(String.class); - assertArrayEquals(data, result.bytesValue()); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + assertArrayEquals(data, result.expect(String.class).bytesValue()); + } } } } diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java new file mode 100644 index 0000000000..cf3910b594 --- /dev/null +++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package org.tensorflow.op.core; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import java.util.List; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.op.Scope; +import org.tensorflow.types.UInt8; + +@RunWith(JUnit4.class) +public class ZerosTest { + private static final float EPSILON = 1e-7f; + + @Test + public void createIntZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Integer> op = Zeros.create(scope, Constant.create(scope, shape), Integer.class); + try (Tensor<?> result = sess.runner().fetch(op).run().get(0)) { + int[][] actual = result.expect(Integer.class).copyTo(new int[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } + } + } + } + } + + @Test + public void createFloatZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Float> op = Zeros.create(scope, Constant.create(scope, shape), Float.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + float[][] actual = result.expect(Float.class).copyTo(new float[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0f, actual[i][j], EPSILON); + } + } + } + } + } + + @Test + public void createDoubleZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Double> op = Zeros.create(scope, Constant.create(scope, shape), Double.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + double[][] actual = result.expect(Double.class).copyTo(new double[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0.0, actual[i][j], EPSILON); + } + } + } + } + } + + @Test + public void createLongZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Long> op = Zeros.create(scope, Constant.create(scope, shape), Long.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + long[][] actual = result.expect(Long.class).copyTo(new long[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0L, actual[i][j]); + } + } + } + } + } + + @Test + public void createBooleanZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Boolean> op = Zeros.create(scope, Constant.create(scope, shape), Boolean.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + boolean[][] actual = result.expect(Boolean.class).copyTo(new boolean[(int)shape[0]][(int)shape[1]]); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertFalse(actual[i][j]); + } + } + } + } + } + + @Test + public void createUInt8Zeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<UInt8> op = Zeros.create(scope, Constant.create(scope, shape), UInt8.class); + try (Tensor<?> result = sess.runner().fetch(op.asOutput()).run().get(0)) { + byte[][] actual = result.expect(UInt8.class).copyTo(new byte[(int)shape[0]][(int)shape[1]]); + result.copyTo(actual); + for (int i = 0; i < actual.length; ++i) { + for (int j = 0; j < actual[i].length; ++j) { + assertEquals(0, actual[i][j]); + } + } + } + } + } + + @Test(expected = IllegalArgumentException.class) + public void cannotCreateStringZeros() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros.create(scope, Constant.create(scope, shape), String.class); + } + } + + @Test + public void operationsComposingZerosAreCorrectlyNamed() { + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {2, 2}; + Zeros<Float> zeros = Zeros.create(scope.withSubScope("test"), Constant.create(scope, shape), Float.class); + List<Tensor<?>> results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + } + } +} |