diff options
author | Jared Duke <jdduke@google.com> | 2018-07-10 12:38:34 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 12:42:15 -0700 |
commit | 75c114c8db0f2f4e810ea80aaabac2a710c2c22e (patch) | |
tree | 60ff45ba575ef996f53f0688ea4398ebc3d63eb5 /tensorflow/contrib/lite/java | |
parent | ca1b54a83ae352c41bb285f0a6ecace20f706ac1 (diff) |
More Tensor Java class refactoring
PiperOrigin-RevId: 203993466
Diffstat (limited to 'tensorflow/contrib/lite/java')
4 files changed, 115 insertions, 113 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 072cb26bb2..767a220f8c 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -15,7 +15,6 @@ limitations under the License. package org.tensorflow.lite; -import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.MappedByteBuffer; @@ -205,61 +204,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { } } - /** Returns the type of the data. */ - static DataType dataTypeOf(Object o) { - if (o != null) { - Class<?> c = o.getClass(); - while (c.isArray()) { - c = c.getComponentType(); - } - if (float.class.equals(c)) { - return DataType.FLOAT32; - } else if (int.class.equals(c)) { - return DataType.INT32; - } else if (byte.class.equals(c)) { - return DataType.UINT8; - } else if (long.class.equals(c)) { - return DataType.INT64; - } - } - throw new IllegalArgumentException( - "DataType error: cannot resolve DataType of " + o.getClass().getName()); - } - - /** Returns the shape of an object as an int array. */ - static int[] shapeOf(Object o) { - int size = numDimensions(o); - int[] dimensions = new int[size]; - fillShape(o, 0, dimensions); - return dimensions; - } - - static int numDimensions(Object o) { - if (o == null || !o.getClass().isArray()) { - return 0; - } - if (Array.getLength(o) == 0) { - throw new IllegalArgumentException("Array lengths cannot be 0."); - } - return 1 + numDimensions(Array.get(o, 0)); - } - - static void fillShape(Object o, int dim, int[] shape) { - if (shape == null || dim == shape.length) { - return; - } - final int len = Array.getLength(o); - if (shape[dim] == 0) { - shape[dim] = len; - } else if (shape[dim] != len) { - throw new IllegalArgumentException( - String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); - } - for (int i = 0; i < len; ++i) { - fillShape(Array.get(o, i), dim + 1, shape); - } - } - /** * Gets the last inference duration in nanoseconds. It returns null if there is no previous * inference run or the last inference run failed. diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java index 2c74c82417..2403570c52 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Tensor.java @@ -15,6 +15,7 @@ limitations under the License. package org.tensorflow.lite; +import java.lang.reflect.Array; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; @@ -102,13 +103,70 @@ final class Tensor { if (isByteBuffer(input)) { return null; } - int[] inputShape = NativeInterpreterWrapper.shapeOf(input); + int[] inputShape = shapeOf(input); if (Arrays.equals(shapeCopy, inputShape)) { return null; } return inputShape; } + /** Returns the type of the data. */ + static DataType dataTypeOf(Object o) { + if (o != null) { + Class<?> c = o.getClass(); + while (c.isArray()) { + c = c.getComponentType(); + } + if (float.class.equals(c)) { + return DataType.FLOAT32; + } else if (int.class.equals(c)) { + return DataType.INT32; + } else if (byte.class.equals(c)) { + return DataType.UINT8; + } else if (long.class.equals(c)) { + return DataType.INT64; + } + } + throw new IllegalArgumentException( + "DataType error: cannot resolve DataType of " + o.getClass().getName()); + } + + /** Returns the shape of an object as an int array. */ + static int[] shapeOf(Object o) { + int size = numDimensions(o); + int[] dimensions = new int[size]; + fillShape(o, 0, dimensions); + return dimensions; + } + + /** Returns the number of dimensions of a multi-dimensional array, otherwise 0. */ + static int numDimensions(Object o) { + if (o == null || !o.getClass().isArray()) { + return 0; + } + if (Array.getLength(o) == 0) { + throw new IllegalArgumentException("Array lengths cannot be 0."); + } + return 1 + numDimensions(Array.get(o, 0)); + } + + /** Recursively populates the shape dimensions for a given (multi-dimensional) array. */ + static void fillShape(Object o, int dim, int[] shape) { + if (shape == null || dim == shape.length) { + return; + } + final int len = Array.getLength(o); + if (shape[dim] == 0) { + shape[dim] = len; + } else if (shape[dim] != len) { + throw new IllegalArgumentException( + String.format("Mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim)); + } + for (int i = 0; i < len; ++i) { + fillShape(Array.get(o, i), dim + 1, shape); + } + } + private void throwExceptionIfTypeIsIncompatible(Object o) { if (isByteBuffer(o)) { ByteBuffer oBuffer = (ByteBuffer) o; @@ -121,7 +179,7 @@ final class Tensor { } return; } - DataType oType = NativeInterpreterWrapper.dataTypeOf(o); + DataType oType = dataTypeOf(o); if (oType != dtype) { throw new IllegalArgumentException( String.format( @@ -130,7 +188,7 @@ final class Tensor { dtype, o.getClass().getName(), oType)); } - int[] oShape = NativeInterpreterWrapper.shapeOf(o); + int[] oShape = shapeOf(o); if (!Arrays.equals(oShape, shapeCopy)) { throw new IllegalArgumentException( String.format( diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 46bdecf443..9c4a5acd79 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -432,60 +432,6 @@ public final class NativeInterpreterWrapperTest { } @Test - public void testDataTypeOf() { - float[] testEmtpyArray = {}; - DataType dataType = NativeInterpreterWrapper.dataTypeOf(testEmtpyArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - float[] testFloatArray = {0.783f, 0.251f}; - dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; - dataType = NativeInterpreterWrapper.dataTypeOf(testFloatArray); - assertThat(dataType).isEqualTo(DataType.FLOAT32); - try { - double[] testDoubleArray = {0.783, 0.251}; - NativeInterpreterWrapper.dataTypeOf(testDoubleArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); - } - try { - Float[] testBoxedArray = {0.783f, 0.251f}; - NativeInterpreterWrapper.dataTypeOf(testBoxedArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); - } - } - - @Test - public void testNumDimensions() { - int scalar = 1; - assertThat(NativeInterpreterWrapper.numDimensions(scalar)).isEqualTo(0); - int[][] array = {{2, 4}, {1, 9}}; - assertThat(NativeInterpreterWrapper.numDimensions(array)).isEqualTo(2); - try { - int[] emptyArray = {}; - NativeInterpreterWrapper.numDimensions(emptyArray); - fail(); - } catch (IllegalArgumentException e) { - assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); - } - } - - @Test - public void testFillShape() { - int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; - int num = NativeInterpreterWrapper.numDimensions(array); - int[] shape = new int[num]; - NativeInterpreterWrapper.fillShape(array, 0, shape); - assertThat(num).isEqualTo(3); - assertThat(shape[0]).isEqualTo(2); - assertThat(shape[1]).isEqualTo(3); - assertThat(shape[2]).isEqualTo(1); - } - - @Test public void testGetInferenceLatency() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); float[] oneD = {1.23f, 6.54f, 7.81f}; diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java index fe5926f6de..71ef044943 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/TensorTest.java @@ -170,4 +170,58 @@ public final class TensorTest { assertThat(tensor.getInputShapeIfDifferent(differentShapeInput)) .isEqualTo(new int[] {1, 8, 8, 3}); } + + @Test + public void testDataTypeOf() { + float[] testEmptyArray = {}; + DataType dataType = Tensor.dataTypeOf(testEmptyArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[] testFloatArray = {0.783f, 0.251f}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + float[][] testMultiDimArray = {testFloatArray, testFloatArray, testFloatArray}; + dataType = Tensor.dataTypeOf(testFloatArray); + assertThat(dataType).isEqualTo(DataType.FLOAT32); + try { + double[] testDoubleArray = {0.783, 0.251}; + Tensor.dataTypeOf(testDoubleArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of"); + } + try { + Float[] testBoxedArray = {0.783f, 0.251f}; + Tensor.dataTypeOf(testBoxedArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("cannot resolve DataType of [Ljava.lang.Float;"); + } + } + + @Test + public void testNumDimensions() { + int scalar = 1; + assertThat(Tensor.numDimensions(scalar)).isEqualTo(0); + int[][] array = {{2, 4}, {1, 9}}; + assertThat(Tensor.numDimensions(array)).isEqualTo(2); + try { + int[] emptyArray = {}; + Tensor.numDimensions(emptyArray); + fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("Array lengths cannot be 0."); + } + } + + @Test + public void testFillShape() { + int[][][] array = {{{23}, {14}, {87}}, {{12}, {42}, {31}}}; + int num = Tensor.numDimensions(array); + int[] shape = new int[num]; + Tensor.fillShape(array, 0, shape); + assertThat(num).isEqualTo(3); + assertThat(shape[0]).isEqualTo(2); + assertThat(shape[1]).isEqualTo(3); + assertThat(shape[2]).isEqualTo(1); + } } |