// Copyright 2016 The TensorFlow Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package tensorflow_test import ( "archive/zip" "bufio" "flag" "fmt" "io" "io/ioutil" "log" "net/http" "os" "path/filepath" "github.com/tensorflow/tensorflow/tensorflow/go/op" tf "github.com/tensorflow/tensorflow/tensorflow/go" ) func Example() { // An example for using the TensorFlow Go API for image recognition // using a pre-trained inception model (http://arxiv.org/abs/1512.00567). // // Sample usage: -dir=/tmp/modeldir -image=/path/to/some/jpeg // // The pre-trained model takes input in the form of a 4-dimensional // tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ], // where: // - BATCH_SIZE allows for inference of multiple images in one pass through the graph // - IMAGE_HEIGHT is the height of the images on which the model was trained // - IMAGE_WIDTH is the width of the images on which the model was trained // - 3 is the (R, G, B) values of the pixel colors represented as a float. // // And produces as output a vector with shape [ NUM_LABELS ]. // output[i] is the probability that the input image was recognized as // having the i-th label. // // A separate file contains a list of string labels corresponding to the // integer indices of the output. // // This example: // - Loads the serialized representation of the pre-trained model into a Graph // - Creates a Session to execute operations on the Graph // - Converts an image file to a Tensor to provide as input to a Session run // - Executes the Session and prints out the label with the highest probability // // To convert an image file to a Tensor suitable for input to the Inception model, // this example: // - Constructs another TensorFlow graph to normalize the image into a // form suitable for the model (for example, resizing the image) // - Creates an executes a Session to obtain a Tensor in this normalized form. modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary") imagefile := flag.String("image", "", "Path of a JPEG-image to extract labels for") flag.Parse() if *modeldir == "" || *imagefile == "" { flag.Usage() return } // Load the serialized GraphDef from a file. modelfile, labelsfile, err := modelFiles(*modeldir) if err != nil { log.Fatal(err) } model, err := ioutil.ReadFile(modelfile) if err != nil { log.Fatal(err) } // Construct an in-memory graph from the serialized form. graph := tf.NewGraph() if err := graph.Import(model, ""); err != nil { log.Fatal(err) } // Create a session for inference over graph. session, err := tf.NewSession(graph, nil) if err != nil { log.Fatal(err) } defer session.Close() // Run inference on *imageFile. // For multiple images, session.Run() can be called in a loop (and // concurrently). Alternatively, images can be batched since the model // accepts batches of image data as input. tensor, err := makeTensorFromImage(*imagefile) if err != nil { log.Fatal(err) } output, err := session.Run( map[tf.Output]*tf.Tensor{ graph.Operation("input").Output(0): tensor, }, []tf.Output{ graph.Operation("output").Output(0), }, nil) if err != nil { log.Fatal(err) } // output[0].Value() is a vector containing probabilities of // labels for each image in the "batch". The batch size was 1. // Find the most probably label index. probabilities := output[0].Value().([][]float32)[0] printBestLabel(probabilities, labelsfile) } func printBestLabel(probabilities []float32, labelsFile string) { bestIdx := 0 for i, p := range probabilities { if p > probabilities[bestIdx] { bestIdx = i } } // Found the best match. Read the string from labelsFile, which // contains one line per label. file, err := os.Open(labelsFile) if err != nil { log.Fatal(err) } defer file.Close() scanner := bufio.NewScanner(file) var labels []string for scanner.Scan() { labels = append(labels, scanner.Text()) } if err := scanner.Err(); err != nil { log.Printf("ERROR: failed to read %s: %v", labelsFile, err) } fmt.Printf("BEST MATCH: (%2.0f%% likely) %s\n", probabilities[bestIdx]*100.0, labels[bestIdx]) } // Convert the image in filename to a Tensor suitable as input to the Inception model. func makeTensorFromImage(filename string) (*tf.Tensor, error) { bytes, err := ioutil.ReadFile(filename) if err != nil { return nil, err } // DecodeJpeg uses a scalar String-valued tensor as input. tensor, err := tf.NewTensor(string(bytes)) if err != nil { return nil, err } // Construct a graph to normalize the image graph, input, output, err := constructGraphToNormalizeImage() if err != nil { return nil, err } // Execute that graph to normalize this one image session, err := tf.NewSession(graph, nil) if err != nil { return nil, err } defer session.Close() normalized, err := session.Run( map[tf.Output]*tf.Tensor{input: tensor}, []tf.Output{output}, nil) if err != nil { return nil, err } return normalized[0], nil } // The inception model takes as input the image described by a Tensor in a very // specific normalized format (a particular image size, shape of the input tensor, // normalized pixel values etc.). // // This function constructs a graph of TensorFlow operations which takes as // input a JPEG-encoded string and returns a tensor suitable as input to the // inception model. func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output, err error) { // Some constants specific to the pre-trained model at: // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip // // - The model was trained after with images scaled to 224x224 pixels. // - The colors, represented as R, G, B in 1-byte each were converted to // float using (value - Mean)/Scale. const ( H, W = 224, 224 Mean = float32(117) Scale = float32(1) ) // - input is a String-Tensor, where the string the JPEG-encoded image. // - The inception model takes a 4D tensor of shape // [BatchSize, Height, Width, Colors=3], where each pixel is // represented as a triplet of floats // - Apply normalization on each pixel and use ExpandDims to make // this single image be a "batch" of size 1 for ResizeBilinear. s := op.NewScope() input = op.Placeholder(s, tf.String) output = op.Div(s, op.Sub(s, op.ResizeBilinear(s, op.ExpandDims(s, op.Cast(s, op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)), tf.Float), op.Const(s.SubScope("make_batch"), int32(0))), op.Const(s.SubScope("size"), []int32{H, W})), op.Const(s.SubScope("mean"), Mean)), op.Const(s.SubScope("scale"), Scale)) graph, err = s.Finalize() return graph, input, output, err } func modelFiles(dir string) (modelfile, labelsfile string, err error) { const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" var ( model = filepath.Join(dir, "tensorflow_inception_graph.pb") labels = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt") zipfile = filepath.Join(dir, "inception5h.zip") ) if filesExist(model, labels) == nil { return model, labels, nil } log.Println("Did not find model in", dir, "downloading from", URL) if err := os.MkdirAll(dir, 0755); err != nil { return "", "", err } if err := download(URL, zipfile); err != nil { return "", "", fmt.Errorf("failed to download %v - %v", URL, err) } if err := unzip(dir, zipfile); err != nil { return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err) } os.Remove(zipfile) return model, labels, filesExist(model, labels) } func filesExist(files ...string) error { for _, f := range files { if _, err := os.Stat(f); err != nil { return fmt.Errorf("unable to stat %s: %v", f, err) } } return nil } func download(URL, filename string) error { resp, err := http.Get(URL) if err != nil { return err } defer resp.Body.Close() file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644) if err != nil { return err } defer file.Close() _, err = io.Copy(file, resp.Body) return err } func unzip(dir, zipfile string) error { r, err := zip.OpenReader(zipfile) if err != nil { return err } defer r.Close() for _, f := range r.File { src, err := f.Open() if err != nil { return err } log.Println("Extracting", f.Name) dst, err := os.OpenFile(filepath.Join(dir, f.Name), os.O_WRONLY|os.O_CREATE, 0644) if err != nil { return err } if _, err := io.Copy(dst, src); err != nil { return err } dst.Close() } return nil }