aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/java/src
diff options
context:
space:
mode:
authorGravatar Martin Wicke <martin.wicke@gmail.com>2017-11-10 12:26:11 -0800
committerGravatar GitHub <noreply@github.com>2017-11-10 12:26:11 -0800
commitd0a5d885d61b837018cb931a4d577289acc826fc (patch)
treedd344e45c4eca857c02746ef50d990a9cd81ea69 /tensorflow/java/src
parent047d7965d2877d7b55f4cdb0d0abdcd733f266a9 (diff)
Revert "Branch 175277161"
Diffstat (limited to 'tensorflow/java/src')
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java43
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/Shape.java32
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFString.java27
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFType.java20
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java30
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/types/Types.java52
-rw-r--r--tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java26
12 files changed, 69 insertions, 311 deletions
diff --git a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
index 499757e8cf..2b431eebf5 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/NativeLibrary.java
@@ -43,6 +43,7 @@ final class NativeLibrary {
private static final boolean DEBUG =
System.getProperty("org.tensorflow.NativeLibrary.DEBUG") != null;
private static final String JNI_LIBNAME = "tensorflow_jni";
+ private static final String FRAMEWORK_LIBNAME = "tensorflow_framework";
public static void load() {
if (isLoaded() || tryLoadLibrary()) {
@@ -58,15 +59,12 @@ final class NativeLibrary {
}
// Native code is not present, perhaps it has been packaged into the .jar file containing this.
// Extract the JNI library itself
- final String jniLibName = System.mapLibraryName(JNI_LIBNAME);
- final String jniResourceName = makeResourceName(jniLibName);
+ final String jniResourceName = makeResourceName(JNI_LIBNAME);
log("jniResourceName: " + jniResourceName);
final InputStream jniResource =
NativeLibrary.class.getClassLoader().getResourceAsStream(jniResourceName);
// Extract the JNI's dependency
- final String frameworkLibName =
- maybeAdjustForMacOS(System.mapLibraryName("tensorflow_framework"));
- final String frameworkResourceName = makeResourceName(frameworkLibName);
+ final String frameworkResourceName = makeResourceName(FRAMEWORK_LIBNAME);
log("frameworkResourceName: " + frameworkResourceName);
final InputStream frameworkResource =
NativeLibrary.class.getClassLoader().getResourceAsStream(frameworkResourceName);
@@ -90,15 +88,12 @@ final class NativeLibrary {
tempPath.deleteOnExit();
final String tempDirectory = tempPath.toString();
if (frameworkResource != null) {
- extractResource(frameworkResource, frameworkLibName, tempDirectory);
+ extractResource(frameworkResource, FRAMEWORK_LIBNAME, tempDirectory);
} else {
- log(
- frameworkResourceName
- + " not found. This is fine assuming "
- + jniResourceName
- + " is not built to depend on it.");
+ log(frameworkResourceName + " not found. This is fine assuming " + jniResourceName
+ + " is not built to depend on it.");
}
- System.load(extractResource(jniResource, jniLibName, tempDirectory));
+ System.load(extractResource(jniResource, JNI_LIBNAME, tempDirectory));
} catch (IOException e) {
throw new UnsatisfiedLinkError(
String.format(
@@ -126,27 +121,9 @@ final class NativeLibrary {
}
}
- private static String maybeAdjustForMacOS(String libFilename) {
- if (!System.getProperty("os.name").contains("OS X")) {
- return libFilename;
- }
- // This is macOS, and the TensorFlow release process might have setup dependencies on
- // libtensorflow_framework.so instead of libtensorflow_framework.dylib. Adjust for that.
- final ClassLoader cl = NativeLibrary.class.getClassLoader();
- if (cl.getResource(makeResourceName(libFilename)) != null) {
- return libFilename;
- }
- // liftensorflow_framework.dylib not found, try libtensorflow_framework.so
- final String suffix = ".dylib";
- if (!libFilename.endsWith(suffix)) {
- return libFilename;
- }
- return libFilename.substring(0, libFilename.length() - suffix.length()) + ".so";
- }
-
private static String extractResource(
InputStream resource, String resourceName, String extractToDirectory) throws IOException {
- final File dst = new File(extractToDirectory, resourceName);
+ final File dst = new File(extractToDirectory, System.mapLibraryName(resourceName));
dst.deleteOnExit();
final String dstPath = dst.toString();
log("extracting native library to: " + dstPath);
@@ -180,7 +157,9 @@ final class NativeLibrary {
}
private static String makeResourceName(String baseName) {
- return "org/tensorflow/native/" + String.format("%s-%s/", os(), architecture()) + baseName;
+ return "org/tensorflow/native/"
+ + String.format("%s-%s/", os(), architecture())
+ + System.mapLibraryName(baseName);
}
private static long copy(InputStream src, File dstFile) throws IOException {
diff --git a/tensorflow/java/src/main/java/org/tensorflow/Shape.java b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
index 9aa92be111..d533c3d480 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/Shape.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/Shape.java
@@ -77,6 +77,24 @@ public final class Shape {
return shape[i];
}
+ @Override
+ public int hashCode() {
+ return Arrays.hashCode(shape);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+
+ if (obj instanceof Shape && Arrays.equals(this.shape, ((Shape) obj).shape)) {
+ return !hasUnknownDimension();
+ }
+
+ return super.equals(obj);
+ }
+
/** Succinct description of the shape meant for debugging. */
@Override
public String toString() {
@@ -98,4 +116,18 @@ public final class Shape {
}
private long[] shape;
+
+ private boolean hasUnknownDimension() {
+ if (shape == null) {
+ return true;
+ }
+
+ for (long dimension : shape) {
+ if (dimension == -1) {
+ return true;
+ }
+ }
+
+ return false;
+ }
}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
deleted file mode 100644
index ab34f6aa12..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFBool.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a boolean. */
-public class TFBool implements TFType {
- private TFBool() {}
- static {
- Types.typeCodes.put(TFBool.class, DataType.BOOL);
- }
- static {
- Types.scalars.put(TFBool.class, false);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
deleted file mode 100644
index 49e5d9f2f3..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 64-bit double precision floating point number. */
-public class TFDouble implements TFType {
- private TFDouble() {}
- static {
- Types.typeCodes.put(TFDouble.class, DataType.DOUBLE);
- }
- static {
- Types.scalars.put(TFDouble.class, 0.0);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
deleted file mode 100644
index 8426ee41f0..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFFloat.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 32-bit single precision floating point number. */
-public class TFFloat implements TFType {
- private TFFloat() {}
- static {
- Types.typeCodes.put(TFFloat.class, DataType.FLOAT);
- }
- static {
- Types.scalars.put(TFFloat.class, 0f);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
deleted file mode 100644
index 3947b6ad09..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt32.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 32-bit signed integer. */
-public class TFInt32 implements TFType {
- private TFInt32() {}
- static {
- Types.typeCodes.put(TFInt32.class, DataType.INT32);
- }
- static {
- Types.scalars.put(TFInt32.class, 0);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
deleted file mode 100644
index ccdded8693..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFInt64.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents a 64-bit signed integer. */
-public class TFInt64 implements TFType {
- private TFInt64() {}
- static {
- Types.typeCodes.put(TFInt64.class, DataType.INT64);
- }
- static {
- Types.scalars.put(TFInt64.class, 0L);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
deleted file mode 100644
index e7327e8c57..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFString.java
+++ /dev/null
@@ -1,27 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents an arbitrary sequence of bytes. */
-public class TFString implements TFType {
- private TFString() {}
- static {
- Types.typeCodes.put(TFString.class, DataType.STRING);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
deleted file mode 100644
index 562953ac9d..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFType.java
+++ /dev/null
@@ -1,20 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-package org.tensorflow.types;
-
-/**
- * A marker interface for classes representing TensorFlow types.
- */
-public interface TFType {}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java b/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
deleted file mode 100644
index d7305ca5a8..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/TFUInt8.java
+++ /dev/null
@@ -1,30 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// GENERATED FILE. To update, edit tftypes.pl instead.
-
-package org.tensorflow.types;
-
-import org.tensorflow.DataType;
-
-/** Represents an 8-bit unsigned integer. */
-public class TFUInt8 implements TFType {
- private TFUInt8() {}
- static {
- Types.typeCodes.put(TFUInt8.class, DataType.UINT8);
- }
- static {
- Types.scalars.put(TFUInt8.class, (byte)0);
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java b/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
deleted file mode 100644
index 976cd9fd34..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/types/Types.java
+++ /dev/null
@@ -1,52 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-package org.tensorflow.types;
-
-import java.util.HashMap;
-import java.util.Map;
-import org.tensorflow.DataType;
-
-/**
- * Utility class for managing the representation of TensorFlow types as Java
- * types. For each TensorFlow type (e.g., int32), there is a corresponding Java
- * type (e.g., TFInt32) that represents it at compile time and a corresponding
- * class object (e.g., TFInt32.class) that represents it at run time. There is
- * also an enumeration value in DataType that can be used to represent the
- * type, though that should rarely be required.
- */
-public class Types {
-
- private Types() {} // not instantiable
-
- static final Map<Class<?>, DataType> typeCodes = new HashMap<>();
-
- /** Returns the DataType value corresponding to a TensorFlow type class. */
- public static DataType dataType(Class<? extends TFType> c) {
- DataType dtype = typeCodes.get(c);
- if (dtype == null) {
- throw new IllegalArgumentException("" + c + " is not a TensorFlow type.");
- }
- return dtype;
- }
-
- static final Map<Class<?>, Object> scalars = new HashMap<>();
-
- /** Returns the zero value of type described by {@code c}, or null if
- * the type (e.g., string) is not numeric and therefore has no zero value.
- */
- public static Object zeroValue(Class<? extends TFType> c) {
- return scalars.get(c);
- }
-}
diff --git a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
index 3b027700c5..92cc3bd60e 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/ShapeTest.java
@@ -16,6 +16,7 @@ limitations under the License.
package org.tensorflow;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -77,4 +78,29 @@ public class ShapeTest {
assertEquals(5, n.shape().size(1));
}
}
+
+ @Test
+ public void equalsWorksCorrectly() {
+ assertEquals(Shape.scalar(), Shape.scalar());
+ assertEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 3));
+
+ assertNotEquals(Shape.make(1,2), null);
+ assertNotEquals(Shape.make(1,2), new Object());
+ assertNotEquals(Shape.make(1, 2, 3), Shape.make(1, 2, 4));
+
+
+ assertNotEquals(Shape.unknown(), Shape.unknown());
+ assertNotEquals(Shape.make(-1), Shape.make(-1));
+ assertNotEquals(Shape.make(1, -1, 3), Shape.make(1, -1, 3));
+ }
+
+ @Test
+ public void hashCodeIsAsExpected() {
+ assertEquals(Shape.make(1, 2, 3, 4).hashCode(), Shape.make(1, 2, 3, 4).hashCode());
+ assertEquals(Shape.scalar().hashCode(), Shape.scalar().hashCode());
+ assertEquals(Shape.unknown().hashCode(), Shape.unknown().hashCode());
+
+ assertNotEquals(Shape.make(1, 2).hashCode(), Shape.make(1, 3).hashCode());
+ }
}
+