aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java')
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java131
1 files changed, 131 insertions, 0 deletions
diff --git a/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
new file mode 100644
index 0000000000..ec23792485
--- /dev/null
+++ b/tensorflow/java/src/test/java/org/tensorflow/op/core/ConstantTest.java
@@ -0,0 +1,131 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.op.core;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertTrue;
+
+import java.io.ByteArrayOutputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.tensorflow.DataType;
+import org.tensorflow.Graph;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.op.Scope;
+
+@RunWith(JUnit4.class)
+public class ConstantTest {
+ private static final float EPSILON = 1e-7f;
+
+ @Test
+ public void createIntBuffer() {
+ int[] ints = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, IntBuffer.wrap(ints));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ int[] actual = new int[ints.length];
+ assertArrayEquals(ints, result.copyTo(actual));
+ }
+ }
+
+ @Test
+ public void createFloatBuffer() {
+ float[] floats = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, FloatBuffer.wrap(floats));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ float[] actual = new float[floats.length];
+ assertArrayEquals(floats, result.copyTo(actual), EPSILON);
+ }
+ }
+
+ @Test
+ public void createDoubleBuffer() {
+ double[] doubles = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, DoubleBuffer.wrap(doubles));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ double[] actual = new double[doubles.length];
+ assertArrayEquals(doubles, result.copyTo(actual), EPSILON);
+ }
+ }
+
+ @Test
+ public void createLongBuffer() {
+ long[] longs = {1, 2, 3, 4};
+ long[] shape = {4};
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, shape, LongBuffer.wrap(longs));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ long[] actual = new long[longs.length];
+ assertArrayEquals(longs, result.copyTo(actual));
+ }
+ }
+
+ @Test
+ public void createStringBuffer() throws IOException {
+
+ byte[] data = {(byte) 1, (byte) 2, (byte) 3, (byte) 4};
+ long[] shape = {};
+
+ // byte arrays (DataType.STRING in Tensorflow) are encoded as an offset in the data buffer,
+ // followed by a varint encoded size, followed by the data.
+ ByteArrayOutputStream baout = new ByteArrayOutputStream();
+ DataOutputStream out = new DataOutputStream(baout);
+ // Offset in array.
+ out.writeLong(0L);
+ // Varint encoded length of buffer.
+ // For any number < 0x80, the varint encoding is simply the number itself.
+ // https://developers.google.com/protocol-buffers/docs/encoding#varints
+ assertTrue(data.length < 0x80);
+ out.write(data.length);
+ out.write(data);
+ out.close();
+ byte[] content = baout.toByteArray();
+
+ try (Graph g = new Graph();
+ Session sess = new Session(g)) {
+ Scope scope = new Scope(g);
+ Constant op = Constant.create(scope, DataType.STRING, shape, ByteBuffer.wrap(content));
+ Tensor result = sess.runner().fetch(op.asOutput()).run().get(0);
+ assertArrayEquals(data, result.bytesValue());
+ }
+ }
+}