diff options
author | Asim Shankar <ashankar@google.com> | 2016-10-14 13:41:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-14 14:52:50 -0700 |
commit | 11ac472e3f311c1d346c7d71abf721b3367645e6 (patch) | |
tree | 2ae134c4867a37b9887af5e195ae47952cba8f0c /tensorflow/go/example_inception_inference_test.go | |
parent | b6d540673cdb9d630df1cefc85865057be8d38b7 (diff) |
go: Automatically download the inception model in the example.
This will make it easier to simply copy-paste the example
code shown on
https://godoc.org/github.com/tensorflow/tensorflow/tensorflow/go
and run it as a standalone program.
Without this change, the reader needs to manually download the model,
unzip it, place the files in a specific location on the filesystem etc.
Change: 136200192
Diffstat (limited to 'tensorflow/go/example_inception_inference_test.go')
-rw-r--r-- | tensorflow/go/example_inception_inference_test.go | 110 |
1 files changed, 92 insertions, 18 deletions
diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go index 45a3899620..b58942aefb 100644 --- a/tensorflow/go/example_inception_inference_test.go +++ b/tensorflow/go/example_inception_inference_test.go @@ -15,13 +15,18 @@ package tensorflow_test import ( + "archive/zip" "bufio" + "flag" "fmt" "image" _ "image/jpeg" + "io" "io/ioutil" "log" + "net/http" "os" + "path/filepath" tf "github.com/tensorflow/tensorflow/tensorflow/go" ) @@ -50,22 +55,19 @@ func Example() { // - Creates a Session to execute operations on the Graph // - Converts an image file to a Tensor to provide as input for Graph execution // - Exectues the graph and prints out the label with the highest probability - const ( - // Path to a pre-trained inception model. - // The two files are extracted from a zip archive as so: - /* - curl -L https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip -o /tmp/inception5h.zip - unzip /tmp/inception5h.zip -d /tmp - */ - modelFile = "/tmp/tensorflow_inception_graph.pb" - labelsFile = "/tmp/imagenet_comp_graph_label_strings.txt" - - // Image file to "recognize". - testImageFilename = "/tmp/test.jpg" - ) - + modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary") + imagefile := flag.String("image", "", "Path of the image to extract labels for") + flag.Parse() + if *modeldir == "" || *imagefile == "" { + flag.Usage() + return + } // Load the serialized GraphDef from a file. - model, err := ioutil.ReadFile(modelFile) + modelfile, labelsfile, err := modelFiles(*modeldir) + if err != nil { + log.Fatal(err) + } + model, err := ioutil.ReadFile(modelfile) if err != nil { log.Fatal(err) } @@ -83,11 +85,11 @@ func Example() { } defer session.Close() - // Run inference on testImageFilename. + // Run inference on thestImageFilename. // For multiple images, session.Run() can be called in a loop (and // concurrently). Furthermore, images can be batched together since the // model accepts batches of image data as input. - tensor, err := makeTensorFromImageForInception(testImageFilename) + tensor, err := makeTensorFromImageForInception(*imagefile) if err != nil { log.Fatal(err) } @@ -106,7 +108,7 @@ func Example() { // labels for each image in the "batch". The batch size was 1. // Find the most probably label index. probabilities := output[0].Value().([][]float32)[0] - printBestLabel(probabilities, labelsFile) + printBestLabel(probabilities, labelsfile) } func printBestLabel(probabilities []float32, labelsFile string) { @@ -183,3 +185,75 @@ func makeTensorFromImageForInception(filename string) (*tf.Tensor, error) { } return tf.NewTensor(ret) } + +func modelFiles(dir string) (modelfile, labelsfile string, err error) { + const URL = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip" + var ( + model = filepath.Join(dir, "tensorflow_inception_graph.pb") + labels = filepath.Join(dir, "imagenet_comp_graph_label_strings.txt") + zipfile = filepath.Join(dir, "inception5h.zip") + ) + if filesExist(model, labels) == nil { + return model, labels, nil + } + log.Println("Did not find model in", dir, "downloading from", URL) + if err := os.MkdirAll(dir, 0755); err != nil { + return "", "", err + } + if err := download(URL, zipfile); err != nil { + return "", "", fmt.Errorf("failed to download %v - %v", URL, err) + } + if err := unzip(dir, zipfile); err != nil { + return "", "", fmt.Errorf("failed to extract contents from model archive: %v", err) + } + os.Remove(zipfile) + return model, labels, filesExist(model, labels) +} + +func filesExist(files ...string) error { + for _, f := range files { + if _, err := os.Stat(f); err != nil { + return fmt.Errorf("unable to stat %s: %v", f, err) + } + } + return nil +} + +func download(URL, filename string) error { + resp, err := http.Get(URL) + if err != nil { + return err + } + defer resp.Body.Close() + file, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + return err + } + defer file.Close() + _, err = io.Copy(file, resp.Body) + return err +} + +func unzip(dir, zipfile string) error { + r, err := zip.OpenReader(zipfile) + if err != nil { + return err + } + defer r.Close() + for _, f := range r.File { + src, err := f.Open() + if err != nil { + return err + } + log.Println("Extracting", f.Name) + dst, err := os.OpenFile(filepath.Join(dir, f.Name), os.O_WRONLY|os.O_CREATE, 0644) + if err != nil { + return err + } + if _, err := io.Copy(dst, src); err != nil { + return err + } + dst.Close() + } + return nil +} |