aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/example_inception_inference_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-14 13:41:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 14:52:50 -0700
commit11ac472e3f311c1d346c7d71abf721b3367645e6 (patch)
tree2ae134c4867a37b9887af5e195ae47952cba8f0c /tensorflow/go/example_inception_inference_test.go
parentb6d540673cdb9d630df1cefc85865057be8d38b7 (diff)
go: Automatically download the inception model in the example.
This will make it easier to simply copy-paste the example code shown on https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go and run it as a standalone program. Without this change, the reader needs to manually download the model, unzip it, place the files in a specific location on the filesystem etc. Change: 136200192
Diffstat (limited to 'tensorflow/go/example_inception_inference_test.go')
-rw-r--r--tensorflow/go/example_inception_inference_test.go110
1 files changed, 92 insertions, 18 deletions
diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go
index 45a3899620..b58942aefb 100644
--- a/tensorflow/go/example_inception_inference_test.go
+++ b/tensorflow/go/example_inception_inference_test.go
@@ -15,13 +15,18 @@
package tensorflow_test
import (
+ "archive/zip"
"bufio"
+ "flag"
"fmt"
"image"
_ "image/jpeg"
+ "io"
"io/ioutil"
"log"
+ "net/http"
"os"
+ "path/filepath"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
@@ -50,22 +55,19 @@ func Example() {
// - Creates a Session to execute operations on the Graph
// - Converts an image file to a Tensor to provide as input for Graph execution
// - Exectues the graph and prints out the label with the highest probability
- const (
- // Path to a pre-trained inception model.
- // The two files are extracted from a zip archive as so:
- /*
- curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip
- unzip /tmp/inception5h.zip -d /tmp
- */
- modelFile = "/tmp/tensorflow_inception_graph.pb"
- labelsFile = "/tmp/imagenet_comp_graph_label_strings.txt"
-
- // Image file to "recognize".
- testImageFilename = "/tmp/test.jpg"
- )
-
+ modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary")
+ imagefile := flag.String("image", "", "Path of the image to extract labels for")
+ flag.Parse()
+ if *modeldir == "" || *imagefile == "" {
+ flag.Usage()
+ return
+ }
// Load the serialized GraphDef from a file.
- model, err := ioutil.ReadFile(modelFile)
+ modelfile, labelsfile, err := modelFiles(*modeldir)
+ if err != nil {
+ log.Fatal(err)
+ }
+ model, err := ioutil.ReadFile(modelfile)
if err != nil {
log.Fatal(err)
}
@@ -83,11 +85,11 @@ func Example() {
}
defer session.Close()
- // Run inference on testImageFilename.
+ // Run inference on thestImageFilename.
// For multiple images, session.Run() can be called in a loop (and
// concurrently). Furthermore, images can be batched together since the
// model accepts batches of image data as input.
- tensor, err := makeTensorFromImageForInception(testImageFilename)
+ tensor, err := makeTensorFromImageForInception(*imagefile)
if err != nil {
log.Fatal(err)
}
@@ -106,7 +108,7 @@ func Example() {
// labels for each image in the "batch". The batch size was 1.
// Find the most probably label index.
probabilities := output[0].Value().([][]float32)[0]
- printBestLabel(probabilities, labelsFile)
+ printBestLabel(probabilities, labelsfile)
}
func printBestLabel(probabilities []float32, labelsFile string) {
@@ -183,3 +185,75 @@ func makeTensorFromImageForInception(filename string) (*tf.Tensor, error) {
}
return tf.NewTensor(ret)
}
+
+func modelFiles(dir string) (modelfile, labelsfile string, err error) {
+ const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip"
+ var (
+ model = filepath.Join(dir, "tensorflow_inception_graph.pb")
+ labels = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt")
+ zipfile = filepath.Join(dir, "inception5h.zip")
+ )
+ if filesExist(model, labels) == nil {
+ return model, labels, nil
+ }
+ log.Println("Did not find model in", dir, "downloading from", URL)
+ if err := os.MkdirAll(dir, 0755); err != nil {
+ return "", "", err
+ }
+ if err := download(URL, zipfile); err != nil {
+ return "", "", fmt.Errorf("failed to download %v - %v", URL, err)
+ }
+ if err := unzip(dir, zipfile); err != nil {
+ return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err)
+ }
+ os.Remove(zipfile)
+ return model, labels, filesExist(model, labels)
+}
+
+func filesExist(files ...string) error {
+ for _, f := range files {
+ if _, err := os.Stat(f); err != nil {
+ return fmt.Errorf("unable to stat %s: %v", f, err)
+ }
+ }
+ return nil
+}
+
+func download(URL, filename string) error {
+ resp, err := http.Get(URL)
+ if err != nil {
+ return err
+ }
+ defer resp.Body.Close()
+ file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+ _, err = io.Copy(file, resp.Body)
+ return err
+}
+
+func unzip(dir, zipfile string) error {
+ r, err := zip.OpenReader(zipfile)
+ if err != nil {
+ return err
+ }
+ defer r.Close()
+ for _, f := range r.File {
+ src, err := f.Open()
+ if err != nil {
+ return err
+ }
+ log.Println("Extracting", f.Name)
+ dst, err := os.OpenFile(filepath.Join(dir, f.Name), os.O_WRONLY|os.O_CREATE, 0644)
+ if err != nil {
+ return err
+ }
+ if _, err := io.Copy(dst, src); err != nil {
+ return err
+ }
+ dst.Close()
+ }
+ return nil
+}