diff options
author | 2017-11-10 12:26:11 -0800 | |
---|---|---|
committer | 2017-11-10 12:26:11 -0800 | |
commit | d0a5d885d61b837018cb931a4d577289acc826fc (patch) | |
tree | dd344e45c4eca857c02746ef50d990a9cd81ea69 /tensorflow/java/src | |
parent | 047d7965d2877d7b55f4cdb0d0abdcd733f266a9 (diff) |
Revert "Branch 175277161"
Diffstat (limited to 'tensorflow/java/src')
12 files changed, 69 insertions, 311 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java index 499757e8cf..2b431eebf5 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java +++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java @@ -43,6 +43,7 @@ final class NativeLibrary { private static final boolean DEBUG = System.getProperty("org.tensorflow.NativeLibrary.DEBUG") != null; private static final String JNI_LIBNAME = "tensorflow_jni"; + private static final String FRAMEWORK_LIBNAME = "tensorflow_framework"; public static void load() { if (isLoaded() || tryLoadLibrary()) { @@ -58,15 +59,12 @@ final class NativeLibrary { } // Native code is not present, perhaps it has been packaged into the .jar file containing this. // Extract the JNI library itself - final String jniLibName = System.mapLibraryName(JNI_LIBNAME); - final String jniResourceName = makeResourceName(jniLibName); + final String jniResourceName = makeResourceName(JNI_LIBNAME); log("jniResourceName: " + jniResourceName); final InputStream jniResource = NativeLibrary.class.getClassLoader().getResourceAsStream(jniResourceName); // Extract the JNI's dependency - final String frameworkLibName = - maybeAdjustForMacOS(System.mapLibraryName("tensorflow_framework")); - final String frameworkResourceName = makeResourceName(frameworkLibName); + final String frameworkResourceName = makeResourceName(FRAMEWORK_LIBNAME); log("frameworkResourceName: " + frameworkResourceName); final InputStream frameworkResource = NativeLibrary.class.getClassLoader().getResourceAsStream(frameworkResourceName); @@ -90,15 +88,12 @@ final class NativeLibrary { tempPath.deleteOnExit(); final String tempDirectory = tempPath.toString(); if (frameworkResource != null) { - extractResource(frameworkResource, frameworkLibName, tempDirectory); + extractResource(frameworkResource, FRAMEWORK_LIBNAME, tempDirectory); } else { - log( - frameworkResourceName - + " not found. This is fine assuming " - + jniResourceName - + " is not built to depend on it."); + log(frameworkResourceName + " not found. This is fine assuming " + jniResourceName + + " is not built to depend on it."); } - System.load(extractResource(jniResource, jniLibName, tempDirectory)); + System.load(extractResource(jniResource, JNI_LIBNAME, tempDirectory)); } catch (IOException e) { throw new UnsatisfiedLinkError( String.format( @@ -126,27 +121,9 @@ final class NativeLibrary { } } - private static String maybeAdjustForMacOS(String libFilename) { - if (!System.getProperty("os.name").contains("OS X")) { - return libFilename; - } - // This is macOS, and the TensorFlow release process might have setup dependencies on - // libtensorflow_framework.so instead of libtensorflow_framework.dylib. Adjust for that. - final ClassLoader cl = NativeLibrary.class.getClassLoader(); - if (cl.getResource(makeResourceName(libFilename)) != null) { - return libFilename; - } - // liftensorflow_framework.dylib not found, try libtensorflow_framework.so - final String suffix = ".dylib"; - if (!libFilename.endsWith(suffix)) { - return libFilename; - } - return libFilename.substring(0, libFilename.length() - suffix.length()) + ".so"; - } - private static String extractResource( InputStream resource, String resourceName, String extractToDirectory) throws IOException { - final File dst = new File(extractToDirectory, resourceName); + final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName)); dst.deleteOnExit(); final String dstPath = dst.toString(); log("extracting native library to: " + dstPath); @@ -180,7 +157,9 @@ final class NativeLibrary { } private static String makeResourceName(String baseName) { - return "org/tensorflow/native/" + String.format("%s-%s/", os(), architecture()) + baseName; + return "org/tensorflow/native/" + + String.format("%s-%s/", os(), architecture()) + + System.mapLibraryName(baseName); } private static long copy(InputStream src, File dstFile) throws IOException { diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java index 9aa92be111..d533c3d480 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java +++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java @@ -77,6 +77,24 @@ public final class Shape { return shape[i]; } + @Override + public int hashCode() { + return Arrays.hashCode(shape); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj instanceof Shape && Arrays.equals(this.shape, ((Shape) obj).shape)) { + return !hasUnknownDimension(); + } + + return super.equals(obj); + } + /** Succinct description of the shape meant for debugging. */ @Override public String toString() { @@ -98,4 +116,18 @@ public final class Shape { } private long[] shape; + + private boolean hasUnknownDimension() { + if (shape == null) { + return true; + } + + for (long dimension : shape) { + if (dimension == -1) { + return true; + } + } + + return false; + } } diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java deleted file mode 100644 index ab34f6aa12..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a boolean. */ -public class TFBool implements TFType { - private TFBool() {} - static { - Types.typeCodes.put(TFBool.class, DataType.BOOL); - } - static { - Types.scalars.put(TFBool.class, false); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java deleted file mode 100644 index 49e5d9f2f3..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 64-bit double precision floating point number. */ -public class TFDouble implements TFType { - private TFDouble() {} - static { - Types.typeCodes.put(TFDouble.class, DataType.DOUBLE); - } - static { - Types.scalars.put(TFDouble.class, 0.0); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java deleted file mode 100644 index 8426ee41f0..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 32-bit single precision floating point number. */ -public class TFFloat implements TFType { - private TFFloat() {} - static { - Types.typeCodes.put(TFFloat.class, DataType.FLOAT); - } - static { - Types.scalars.put(TFFloat.class, 0f); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java deleted file mode 100644 index 3947b6ad09..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 32-bit signed integer. */ -public class TFInt32 implements TFType { - private TFInt32() {} - static { - Types.typeCodes.put(TFInt32.class, DataType.INT32); - } - static { - Types.scalars.put(TFInt32.class, 0); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java deleted file mode 100644 index ccdded8693..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents a 64-bit signed integer. */ -public class TFInt64 implements TFType { - private TFInt64() {} - static { - Types.typeCodes.put(TFInt64.class, DataType.INT64); - } - static { - Types.scalars.put(TFInt64.class, 0L); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java deleted file mode 100644 index e7327e8c57..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents an arbitrary sequence of bytes. */ -public class TFString implements TFType { - private TFString() {} - static { - Types.typeCodes.put(TFString.class, DataType.STRING); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java deleted file mode 100644 index 562953ac9d..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java +++ /dev/null @@ -1,20 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.types; - -/** - * A marker interface for classes representing TensorFlow types. - */ -public interface TFType {} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java deleted file mode 100644 index d7305ca5a8..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. - -package org.tensorflow.types; - -import org.tensorflow.DataType; - -/** Represents an 8-bit unsigned integer. */ -public class TFUInt8 implements TFType { - private TFUInt8() {} - static { - Types.typeCodes.put(TFUInt8.class, DataType.UINT8); - } - static { - Types.scalars.put(TFUInt8.class, (byte)0); - } -} diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java deleted file mode 100644 index 976cd9fd34..0000000000 --- a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.types; - -import java.util.HashMap; -import java.util.Map; -import org.tensorflow.DataType; - -/** - * Utility class for managing the representation of TensorFlow types as Java - * types. For each TensorFlow type (e.g., int32), there is a corresponding Java - * type (e.g., TFInt32) that represents it at compile time and a corresponding - * class object (e.g., TFInt32.class) that represents it at run time. There is - * also an enumeration value in DataType that can be used to represent the - * type, though that should rarely be required. - */ -public class Types { - - private Types() {} // not instantiable - - static final Map<Class<?>, DataType> typeCodes = new HashMap<>(); - - /** Returns the DataType value corresponding to a TensorFlow type class. */ - public static DataType dataType(Class<? extends TFType> c) { - DataType dtype = typeCodes.get(c); - if (dtype == null) { - throw new IllegalArgumentException("" + c + " is not a TensorFlow type."); - } - return dtype; - } - - static final Map<Class<?>, Object> scalars = new HashMap<>(); - - /** Returns the zero value of type described by {@code c}, or null if - * the type (e.g., string) is not numeric and therefore has no zero value. - */ - public static Object zeroValue(Class<? extends TFType> c) { - return scalars.get(c); - } -} diff --git a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java index 3b027700c5..92cc3bd60e 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java @@ -16,6 +16,7 @@ limitations under the License. package org.tensorflow; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import org.junit.Test; import org.junit.runner.RunWith; @@ -77,4 +78,29 @@ public class ShapeTest { assertEquals(5, n.shape().size(1)); } } + + @Test + public void equalsWorksCorrectly() { + assertEquals(Shape.scalar(), Shape.scalar()); + assertEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 3)); + + assertNotEquals(Shape.make(1,2), null); + assertNotEquals(Shape.make(1,2), new Object()); + assertNotEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 4)); + + + assertNotEquals(Shape.unknown(), Shape.unknown()); + assertNotEquals(Shape.make(-1), Shape.make(-1)); + assertNotEquals(Shape.make(1, -1, 3), Shape.make(1, -1, 3)); + } + + @Test + public void hashCodeIsAsExpected() { + assertEquals(Shape.make(1, 2, 3, 4).hashCode(), Shape.make(1, 2, 3, 4).hashCode()); + assertEquals(Shape.scalar().hashCode(), Shape.scalar().hashCode()); + assertEquals(Shape.unknown().hashCode(), Shape.unknown().hashCode()); + + assertNotEquals(Shape.make(1, 2).hashCode(), Shape.make(1, 3).hashCode()); + } } + |