author    Asim Shankar <ashankar@google.com>    2016-09-28 10:37:31 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>    2016-09-28 11:48:45 -0700
commit    7cacfdf03872c36cfc3f43e75e8c342234160e3a (patch)
tree      0856b1decb0a1f36a68450c9a5d9b825de999c06 /tensorflow/go/example_inception_inference_test.go
parent    b410f52f6616e281d8274317938a258605ad993a (diff)
go: Add an example.
Add an example (that will appear in Go doc) for a real use of the Go TensorFlow APIs - using a pre-defined image recognition model for inference. While at it, a couple of minor tweaks (a minimal sketch of these follows below):
- NewSession now accepts 'nil' for options, as the documentation says it does
- Convenience accessors for the Outputs of an Operation
- Ability to extract (possibly partial) shapes from an Output

Another step towards #10
Change: 134560938
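A minimal sketch of those three tweaks from the caller's side, written in the package-local style of the example file added below. The helper name is hypothetical, and the exact accessor for extracting a shape from an Output is not shown in this diff, so it is left as a commented-out assumption:

package tensorflow

import "log"

// sketchAPITweaks is a hypothetical helper illustrating the tweaks listed
// above; graphDef is a serialized GraphDef.
func sketchAPITweaks(graphDef []byte) {
	graph := NewGraph()
	if err := graph.Import(graphDef, ""); err != nil {
		log.Fatal(err)
	}
	// NewSession now accepts nil for its options argument, as documented.
	session, err := NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer session.Close()
	// Convenience accessor for the i-th output of an operation.
	out := graph.Operation("input").Output(0)
	// The commit also adds extraction of (possibly partial) shapes from an
	// Output; the accessor name below is an assumption, so it stays commented out.
	// shape, err := out.Shape()
	_ = out
}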
Diffstat (limited to 'tensorflow/go/example_inception_inference_test.go')
-rw-r--r--  tensorflow/go/example_inception_inference_test.go  183
1 files changed, 183 insertions, 0 deletions
diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go
new file mode 100644
index 0000000000..abed215d02
--- /dev/null
+++ b/tensorflow/go/example_inception_inference_test.go
@@ -0,0 +1,183 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "bufio"
+ "fmt"
+ "image"
+ _ "image/jpeg"
+ "io/ioutil"
+ "log"
+ "os"
+)
+
+func Example() {
+ // An example of using the TensorFlow Go API for image recognition
+ // with a pre-trained inception model (http://arxiv.org/abs/1512.00567).
+ //
+ // The pre-trained model takes input in the form of a 4-dimensional
+ // tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ],
+ // where:
+ // - BATCH_SIZE allows for inference of multiple images in one pass through the graph
+ // - IMAGE_HEIGHT is the height of the images on which the model was trained
+ // - IMAGE_WIDTH is the width of the images on which the model was trained
+ // - 3 is the number of color channels: the (R, G, B) values of each pixel, represented as floats.
+ //
+ // And produces as output a vector with shape [ NUM_LABELS ].
+ // output[i] is the probability that the input image was recognized as
+ // having the i-th label.
+ //
+ // A separate file contains a list of string labels corresponding to the
+ // integer indices of the output.
+ //
+ // This example:
+ // - Loads the serialized representation of the pre-trained model into a Graph
+ // - Creates a Session to execute operations on the Graph
+ // - Converts an image file to a Tensor to provide as input for Graph execution
+ // - Executes the graph and prints out the label with the highest probability
+ const (
+ // Path to a pre-trained inception model.
+ // The two files are extracted from a zip archive as follows:
+ /*
+ curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip
+ unzip /tmp/inception5h.zip -d /tmp
+ */
+ modelFile = "/tmp/tensorflow_inception_graph.pb"
+ labelsFile = "/tmp/imagenet_comp_graph_label_strings.txt"
+
+ // Image file to "recognize".
+ testImageFilename = "/tmp/test.jpg"
+ )
+
+ // Load the serialized GraphDef from a file.
+ model, err := ioutil.ReadFile(modelFile)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Construct an in-memory graph from the serialized form.
+ graph := NewGraph()
+ if err := graph.Import(model, ""); err != nil {
+ log.Fatal(err)
+ }
+
+ // Create a session for inference over graph.
+ session, err := NewSession(graph, nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer session.Close()
+
+ // Run inference on testImageFilename.
+ // For multiple images, session.Run() can be called in a loop (and
+ // concurrently). Furthermore, images can be batched together since the
+ // model accepts batches of image data as input.
+ tensor, err := makeTensorFromImageForInception(testImageFilename)
+ if err != nil {
+ log.Fatal(err)
+ }
+ output, err := session.Run(
+ map[Output]*Tensor{
+ graph.Operation("input").Output(0): tensor,
+ },
+ []Output{
+ graph.Operation("output").Output(0),
+ },
+ nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+ // output[0].Value() is a vector containing probabilities of
+ // labels for each image in the "batch". The batch size was 1.
+ // Find the most probable label index.
+ probabilities := output[0].Value().([][]float32)[0]
+ printBestLabel(probabilities, labelsFile)
+}
+
+func printBestLabel(probabilities []float32, labelsFile string) {
+ bestIdx := 0
+ for i, p := range probabilities {
+ if p > probabilities[bestIdx] {
+ bestIdx = i
+ }
+ }
+ // Found a best match, now read the string from the labelsFile where
+ // there is one line per label.
+ file, err := os.Open(labelsFile)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer file.Close()
+ scanner := bufio.NewScanner(file)
+ var labels []string
+ for scanner.Scan() {
+ labels = append(labels, scanner.Text())
+ }
+ if err := scanner.Err(); err != nil {
+ log.Printf("ERROR: failed to read %s: %v", labelsFile, err)
+ }
+ fmt.Printf("BEST MATCH: (%2.0f%% likely) %s\n", probabilities[bestIdx]*100.0, labels[bestIdx])
+}
+
+// Given an image stored in filename, returns a Tensor which is suitable for
+// providing the image data to the pre-defined model.
+func makeTensorFromImageForInception(filename string) (*Tensor, error) {
+ const (
+ // Some constants specific to the pre-trained model at:
+ // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
+ //
+ // - The model was trained with images scaled to 224x224 pixels.
+ // - The colors, represented as R, G, B in 1 byte each, were converted to
+ // float using (value - Mean)/Std.
+ //
+ // If using a different pre-trained model, the values will have to be adjusted.
+ H, W = 224, 224
+ Mean = 117
+ Std = float32(1)
+ )
+ file, err := os.Open(filename)
+ if err != nil {
+ return nil, err
+ }
+ defer file.Close()
+ img, _, err := image.Decode(file)
+ if err != nil {
+ return nil, err
+ }
+ sz := img.Bounds().Size()
+ if sz.X != W || sz.Y != H {
+ return nil, fmt.Errorf("input image is required to be %dx%d pixels, was %dx%d", W, H, sz.X, sz.Y)
+ }
+ // 4-dimensional input:
+ // - 1st dimension: Batch size (the model takes a batch of images as
+ // input, here the "batch size" is 1)
+ // - 2nd dimension: Rows of the image
+ // - 3rd dimension: Columns of the image
+ // - 4th dimension: Colors of the pixel as (B, G, R)
+ // Thus, the shape is [1, 224, 224, 3]
+ var ret [1][H][W][3]float32
+ for y := 0; y < H; y++ {
+ for x := 0; x < W; x++ {
+ px := x + img.Bounds().Min.X
+ py := y + img.Bounds().Min.Y
+ r, g, b, _ := img.At(px, py).RGBA()
+ ret[0][y][x][0] = float32((int(b>>8) - Mean)) / Std
+ ret[0][y][x][1] = float32((int(g>>8) - Mean)) / Std
+ ret[0][y][x][2] = float32((int(r>>8) - Mean)) / Std
+ }
+ }
+ return NewTensor(ret)
+}
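As the comment above session.Run in the example notes, several images can be batched into a single run since the model's first dimension is the batch size. A minimal sketch of that, in the same package-local style: batchedInference is a hypothetical helper, and passing a slice of fixed-size image arrays to NewTensor is assumed to behave like the [1][H][W][3]float32 array used in the example.

// batchedInference runs the graph once over a batch of pre-scaled images.
// batch has shape [N, 224, 224, 3]: one 224x224 image per entry, with pixel
// values already converted to float the same way as in
// makeTensorFromImageForInception.
func batchedInference(session *Session, graph *Graph, batch [][224][224][3]float32) ([][]float32, error) {
	tensor, err := NewTensor(batch)
	if err != nil {
		return nil, err
	}
	output, err := session.Run(
		map[Output]*Tensor{
			graph.Operation("input").Output(0): tensor,
		},
		[]Output{
			graph.Operation("output").Output(0),
		},
		nil)
	if err != nil {
		return nil, err
	}
	// One probability vector (length NUM_LABELS) per image in the batch.
	return output[0].Value().([][]float32), nil
}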