aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-07-10 12:38:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 12:42:15 -0700
commit75c114c8db0f2f4e810ea80aaabac2a710c2c22e (patch)
tree60ff45ba575ef996f53f0688ea4398ebc3d63eb5 /tensorflow/contrib/lite/java
parentca1b54a83ae352c41bb285f0a6ecace20f706ac1 (diff)
More Tensor Java class refactoring
PiperOrigin-RevId: 203993466
Diffstat (limited to 'tensorflow/contrib/lite/java')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java56
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java64
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java54
-rw-r--r--tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java54
4 files changed, 115 insertions, 113 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index 072cb26bb2..767a220f8c 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -15,7 +15,6 @@ limitations under the License.
package org.tensorflow.lite;
-import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
@@ -205,61 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable {
}
}
- /** Returns the type of the data. */
- static DataType dataTypeOf(Object o) {
- if (o != null) {
- Class<?> c = o.getClass();
- while (c.isArray()) {
- c = c.getComponentType();
- }
- if (float.class.equals(c)) {
- return DataType.FLOAT32;
- } else if (int.class.equals(c)) {
- return DataType.INT32;
- } else if (byte.class.equals(c)) {
- return DataType.UINT8;
- } else if (long.class.equals(c)) {
- return DataType.INT64;
- }
- }
- throw new IllegalArgumentException(
- "DataType error: cannot resolve DataType of " + o.getClass().getName());
- }
-
- /** Returns the shape of an object as an int array. */
- static int[] shapeOf(Object o) {
- int size = numDimensions(o);
- int[] dimensions = new int[size];
- fillShape(o, 0, dimensions);
- return dimensions;
- }
-
- static int numDimensions(Object o) {
- if (o == null || !o.getClass().isArray()) {
- return 0;
- }
- if (Array.getLength(o) == 0) {
- throw new IllegalArgumentException("Array lengths cannot be 0.");
- }
- return 1 + numDimensions(Array.get(o, 0));
- }
-
- static void fillShape(Object o, int dim, int[] shape) {
- if (shape == null || dim == shape.length) {
- return;
- }
- final int len = Array.getLength(o);
- if (shape[dim] == 0) {
- shape[dim] = len;
- } else if (shape[dim] != len) {
- throw new IllegalArgumentException(
- String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
- }
- for (int i = 0; i < len; ++i) {
- fillShape(Array.get(o, i), dim + 1, shape);
- }
- }
-
/**
* Gets the last inference duration in nanoseconds. It returns null if there is no previous
* inference run or the last inference run failed.
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
index 2c74c82417..2403570c52 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java
@@ -15,6 +15,7 @@ limitations under the License.
package org.tensorflow.lite;
+import java.lang.reflect.Array;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
@@ -102,13 +103,70 @@ final class Tensor {
if (isByteBuffer(input)) {
return null;
}
- int[] inputShape = NativeInterpreterWrapper.shapeOf(input);
+ int[] inputShape = shapeOf(input);
if (Arrays.equals(shapeCopy, inputShape)) {
return null;
}
return inputShape;
}
+ /** Returns the type of the data. */
+ static DataType dataTypeOf(Object o) {
+ if (o != null) {
+ Class<?> c = o.getClass();
+ while (c.isArray()) {
+ c = c.getComponentType();
+ }
+ if (float.class.equals(c)) {
+ return DataType.FLOAT32;
+ } else if (int.class.equals(c)) {
+ return DataType.INT32;
+ } else if (byte.class.equals(c)) {
+ return DataType.UINT8;
+ } else if (long.class.equals(c)) {
+ return DataType.INT64;
+ }
+ }
+ throw new IllegalArgumentException(
+ "DataType error: cannot resolve DataType of " + o.getClass().getName());
+ }
+
+ /** Returns the shape of an object as an int array. */
+ static int[] shapeOf(Object o) {
+ int size = numDimensions(o);
+ int[] dimensions = new int[size];
+ fillShape(o, 0, dimensions);
+ return dimensions;
+ }
+
+ /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */
+ static int numDimensions(Object o) {
+ if (o == null || !o.getClass().isArray()) {
+ return 0;
+ }
+ if (Array.getLength(o) == 0) {
+ throw new IllegalArgumentException("Array lengths cannot be 0.");
+ }
+ return 1 + numDimensions(Array.get(o, 0));
+ }
+
+ /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */
+ static void fillShape(Object o, int dim, int[] shape) {
+ if (shape == null || dim == shape.length) {
+ return;
+ }
+ final int len = Array.getLength(o);
+ if (shape[dim] == 0) {
+ shape[dim] = len;
+ } else if (shape[dim] != len) {
+ throw new IllegalArgumentException(
+ String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
+ }
+ for (int i = 0; i < len; ++i) {
+ fillShape(Array.get(o, i), dim + 1, shape);
+ }
+ }
+
private void throwExceptionIfTypeIsIncompatible(Object o) {
if (isByteBuffer(o)) {
ByteBuffer oBuffer = (ByteBuffer) o;
@@ -121,7 +179,7 @@ final class Tensor {
}
return;
}
- DataType oType = NativeInterpreterWrapper.dataTypeOf(o);
+ DataType oType = dataTypeOf(o);
if (oType != dtype) {
throw new IllegalArgumentException(
String.format(
@@ -130,7 +188,7 @@ final class Tensor {
dtype, o.getClass().getName(), oType));
}
- int[] oShape = NativeInterpreterWrapper.shapeOf(o);
+ int[] oShape = shapeOf(o);
if (!Arrays.equals(oShape, shapeCopy)) {
throw new IllegalArgumentException(
String.format(
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
index 46bdecf443..9c4a5acd79 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java
@@ -432,60 +432,6 @@ public final class NativeInterpreterWrapperTest {
}
@Test
- public void testDataTypeOf() {
- float[] testEmtpyArray = {};
- DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[] testFloatArray = {0.783f, 0.251f};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
- dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray);
- assertThat(dataType).isEqualTo(DataType.FLOAT32);
- try {
- double[] testDoubleArray = {0.783, 0.251};
- NativeInterpreterWrapper.dataTypeOf(testDoubleArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
- }
- try {
- Float[] testBoxedArray = {0.783f, 0.251f};
- NativeInterpreterWrapper.dataTypeOf(testBoxedArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
- }
- }
-
- @Test
- public void testNumDimensions() {
- int scalar = 1;
- assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0);
- int[][] array = {{2, 4}, {1, 9}};
- assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2);
- try {
- int[] emptyArray = {};
- NativeInterpreterWrapper.numDimensions(emptyArray);
- fail();
- } catch (IllegalArgumentException e) {
- assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
- }
- }
-
- @Test
- public void testFillShape() {
- int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
- int num = NativeInterpreterWrapper.numDimensions(array);
- int[] shape = new int[num];
- NativeInterpreterWrapper.fillShape(array, 0, shape);
- assertThat(num).isEqualTo(3);
- assertThat(shape[0]).isEqualTo(2);
- assertThat(shape[1]).isEqualTo(3);
- assertThat(shape[2]).isEqualTo(1);
- }
-
- @Test
public void testGetInferenceLatency() {
NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH);
float[] oneD = {1.23f, 6.54f, 7.81f};
diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
index fe5926f6de..71ef044943 100644
--- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
+++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java
@@ -170,4 +170,58 @@ public final class TensorTest {
assertThat(tensor.getInputShapeIfDifferent(differentShapeInput))
.isEqualTo(new int[] {1, 8, 8, 3});
}
+
+ @Test
+ public void testDataTypeOf() {
+ float[] testEmptyArray = {};
+ DataType dataType = Tensor.dataTypeOf(testEmptyArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[] testFloatArray = {0.783f, 0.251f};
+ dataType = Tensor.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray};
+ dataType = Tensor.dataTypeOf(testFloatArray);
+ assertThat(dataType).isEqualTo(DataType.FLOAT32);
+ try {
+ double[] testDoubleArray = {0.783, 0.251};
+ Tensor.dataTypeOf(testDoubleArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of");
+ }
+ try {
+ Float[] testBoxedArray = {0.783f, 0.251f};
+ Tensor.dataTypeOf(testBoxedArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;");
+ }
+ }
+
+ @Test
+ public void testNumDimensions() {
+ int scalar = 1;
+ assertThat(Tensor.numDimensions(scalar)).isEqualTo(0);
+ int[][] array = {{2, 4}, {1, 9}};
+ assertThat(Tensor.numDimensions(array)).isEqualTo(2);
+ try {
+ int[] emptyArray = {};
+ Tensor.numDimensions(emptyArray);
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertThat(e).hasMessageThat().contains("Array lengths cannot be 0.");
+ }
+ }
+
+ @Test
+ public void testFillShape() {
+ int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}};
+ int num = Tensor.numDimensions(array);
+ int[] shape = new int[num];
+ Tensor.fillShape(array, 0, shape);
+ assertThat(num).isEqualTo(3);
+ assertThat(shape[0]).isEqualTo(2);
+ assertThat(shape[1]).isEqualTo(3);
+ assertThat(shape[2]).isEqualTo(1);
+ }
}