aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/examples/label_image/README.md12
-rw-r--r--tensorflow/java/README.md8
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/examples/BUILD6
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/examples/Example.java29
-rw-r--r--tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java208
5 files changed, 223 insertions, 40 deletions
diff --git a/tensorflow/examples/label_image/README.md b/tensorflow/examples/label_image/README.md
index e427ff7845..62385312b6 100644
--- a/tensorflow/examples/label_image/README.md
+++ b/tensorflow/examples/label_image/README.md
@@ -1,7 +1,10 @@
# TensorFlow C++ Image Recognition Demo
This example shows how you can load a pre-trained TensorFlow network and use it
-to recognize objects in images.
+to recognize objects in images in C++. For Java see the [Java
+README](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java),
+and for Go see the [godoc
+example](https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go#ex-package).
## Description
@@ -10,9 +13,9 @@ in on the command line.
## To build/install/run
-The TensorFlow `GraphDef` that contains the model definition and weights
-is not packaged in the repo because of its size. Instead, you must
-first download the file to the `data` directory in the source tree:
+The TensorFlow `GraphDef` that contains the model definition and weights is not
+packaged in the repo because of its size. Instead, you must first download the
+file to the `data` directory in the source tree:
```bash
$ wget https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip -O tensorflow/examples/label_image/data/inception_dec_2015.zip
@@ -49,6 +52,7 @@ I tensorflow/examples/label_image/main.cc:207] academic gown (896): 0.0232407
I tensorflow/examples/label_image/main.cc:207] bow tie (817): 0.0157355
I tensorflow/examples/label_image/main.cc:207] bolo tie (940): 0.0145023
```
+
In this case, we're using the default image of Admiral Grace Hopper, and you can
see the network correctly spots she's wearing a military uniform, with a high
score of 0.6.
diff --git a/tensorflow/java/README.md b/tensorflow/java/README.md
index d9bee5e342..1eea76c48a 100644
--- a/tensorflow/java/README.md
+++ b/tensorflow/java/README.md
@@ -40,7 +40,7 @@ bazel build -c opt \
//tensorflow/java:libtensorflow-jni
```
-## Example Usage
+## Example
### With bazel
@@ -48,7 +48,7 @@ Add a dependency on `//tensorflow/java:tensorflow` to the `java_binary` or
`java_library` rule. For example:
```sh
-bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
+bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:label_image
```
### With `javac`
@@ -58,7 +58,7 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
```sh
javac \
-cp ../../bazel-bin/tensorflow/java/libtensorflow.jar \
- ./src/main/java/org/tensorflow/examples/Example.java
+ ./src/main/java/org/tensorflow/examples/LabelImage.java
```
- Make `libtensorflow.jar` and `libtensorflow-jni.so`
@@ -68,5 +68,5 @@ bazel run -c opt //tensorflow/java/src/main/java/org/tensorflow/examples:example
java \
-Djava.library.path=../../bazel-bin/tensorflow/java \
-cp ../../bazel-bin/tensorflow/java/libtensorflow.jar:./src/main/java \
- org.tensorflow.examples.Example
+ org.tensorflow.examples.LabelImage
```
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
index 529287a038..5f9aefef4c 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/BUILD
@@ -6,9 +6,9 @@ package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
java_binary(
- name = "example",
- srcs = ["Example.java"],
- main_class = "org.tensorflow.examples.Example",
+ name = "label_image",
+ srcs = ["LabelImage.java"],
+ main_class = "org.tensorflow.examples.LabelImage",
deps = ["//tensorflow/java:tensorflow"],
)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java b/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
deleted file mode 100644
index 630632087a..0000000000
--- a/tensorflow/java/src/main/java/org/tensorflow/examples/Example.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-package org.tensorflow.examples;
-
-import org.tensorflow.TensorFlow;
-
-/**
- * Sample usage of the TensorFlow Java library.
- *
- * <p>This sample should become more useful as functionality is added to the API.
- */
-public class Example {
- public static void main(String[] args) {
- System.out.println("TensorFlow version: " + TensorFlow.version());
- }
-}
diff --git a/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
new file mode 100644
index 0000000000..740248a29b
--- /dev/null
+++ b/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java
@@ -0,0 +1,208 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+package org.tensorflow.examples;
+
+import java.io.IOException;
+import java.io.PrintStream;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import org.tensorflow.DataType;
+import org.tensorflow.Graph;
+import org.tensorflow.Output;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.TensorFlow;
+
+/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
+public class LabelImage {
+ private static void printUsage(PrintStream s) {
+ final String url =
+ "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
+ s.println(
+ "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)");
+ s.println("to label JPEG images.");
+ s.println("TensorFlow version: " + TensorFlow.version());
+ s.println();
+ s.println("Usage: label_image <model dir> <image file>");
+ s.println();
+ s.println("Where:");
+ s.println("<model dir> is a directory containing the unzipped contents of the inception model");
+ s.println(" (from " + url + ")");
+ s.println("<image file> is the path to a JPEG image file");
+ }
+
+ public static void main(String[] args) {
+ if (args.length != 2) {
+ printUsage(System.err);
+ System.exit(1);
+ }
+ String modelDir = args[0];
+ String imageFile = args[1];
+
+ byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
+ List<String> labels =
+ readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
+ byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
+
+ try (Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
+ float[] labelProbabilities = executeInceptionGraph(graphDef, image);
+ int bestLabelIdx = maxIndex(labelProbabilities);
+ System.out.println(
+ String.format(
+ "BEST MATCH: %s (%.2f%% likely)",
+ labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
+ }
+ }
+
+ private static Tensor constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
+ try (Graph g = new Graph()) {
+ GraphBuilder b = new GraphBuilder(g);
+ // Some constants specific to the pre-trained model at:
+ // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
+ //
+ // - The model was trained with images scaled to 224x224 pixels.
+ // - The colors, represented as R, G, B in 1-byte each were converted to
+ // float using (value - Mean)/Scale.
+ final int H = 224;
+ final int W = 224;
+ final float mean = 117f;
+ final float scale = 1f;
+
+ // Since the graph is being constructed once per execution here, we can use a constant for the
+ // input image. If the graph were to be re-used for multiple input images, a placeholder would
+ // have been more appropriate.
+ final Output input = b.constant("input", imageBytes);
+ final Output output =
+ b.div(
+ b.sub(
+ b.resizeBilinear(
+ b.expandDims(
+ b.cast(b.decodeJpeg(input, 3), DataType.FLOAT),
+ b.constant("make_batch", 0)),
+ b.constant("size", new int[] {H, W})),
+ b.constant("mean", mean)),
+ b.constant("scale", scale));
+ try (Session s = new Session(g)) {
+ return s.runner().fetch(output.op().name()).run().get(0);
+ }
+ }
+ }
+
+ private static float[] executeInceptionGraph(byte[] graphDef, Tensor image) {
+ try (Graph g = new Graph()) {
+ g.importGraphDef(graphDef);
+ try (Session s = new Session(g);
+ Tensor result = s.runner().feed("input", image).fetch("output").run().get(0)) {
+ final long[] rshape = result.shape();
+ if (result.numDimensions() != 2 || rshape[0] != 1) {
+ throw new RuntimeException(
+ String.format(
+ "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
+ Arrays.toString(rshape)));
+ }
+ int nlabels = (int) rshape[1];
+ return result.copyTo(new float[1][nlabels])[0];
+ }
+ }
+ }
+
+ private static int maxIndex(float[] probabilities) {
+ int best = 0;
+ for (int i = 1; i < probabilities.length; ++i) {
+ if (probabilities[i] > probabilities[best]) {
+ best = i;
+ }
+ }
+ return best;
+ }
+
+ private static byte[] readAllBytesOrExit(Path path) {
+ try {
+ return Files.readAllBytes(path);
+ } catch (IOException e) {
+ System.err.println("Failed to read [" + path + "]: " + e.getMessage());
+ System.exit(1);
+ }
+ return null;
+ }
+
+ private static List<String> readAllLinesOrExit(Path path) {
+ try {
+ return Files.readAllLines(path, Charset.forName("UTF-8"));
+ } catch (IOException e) {
+ System.err.println("Failed to read [" + path + "]: " + e.getMessage());
+ System.exit(0);
+ }
+ return null;
+ }
+
+ // In the fullness of time, equivalents of the methods of this class should be auto-generated from
+ // the OpDefs linked into libtensorflow-jni.so. That would match what is done in other languages
+ // like Python, C++ and Go.
+ static class GraphBuilder {
+ GraphBuilder(Graph g) {
+ this.g = g;
+ }
+
+ Output div(Output x, Output y) {
+ return binaryOp("Div", x, y);
+ }
+
+ Output sub(Output x, Output y) {
+ return binaryOp("Sub", x, y);
+ }
+
+ Output resizeBilinear(Output images, Output size) {
+ return binaryOp("ResizeBilinear", images, size);
+ }
+
+ Output expandDims(Output input, Output dim) {
+ return binaryOp("ExpandDims", input, dim);
+ }
+
+ Output cast(Output value, DataType dtype) {
+ return g.opBuilder("Cast", "Cast").addInput(value).setAttr("DstT", dtype).build().output(0);
+ }
+
+ Output decodeJpeg(Output contents, long channels) {
+ return g.opBuilder("DecodeJpeg", "DecodeJpeg")
+ .addInput(contents)
+ .setAttr("channels", channels)
+ .build()
+ .output(0);
+ }
+
+ Output constant(String name, Object value) {
+ try (Tensor t = Tensor.create(value)) {
+ return g.opBuilder("Const", name)
+ .setAttr("dtype", t.dataType())
+ .setAttr("value", t)
+ .build()
+ .output(0);
+ }
+ }
+
+ private Output binaryOp(String type, Output in1, Output in2) {
+ return g.opBuilder(type, type).addInput(in1).addInput(in2).build().output(0);
+ }
+
+ private Graph g;
+ }
+}