 tensorflow/go/example_inception_inference_test.go |  65
 tensorflow/go/genop/internal/genop.go             |   2
 tensorflow/go/genop/internal/genop_test.go        |   2
 tensorflow/go/session_test.go                     |  40
 tensorflow/go/tensor.go                           | 160
 tensorflow/go/tensor_test.go                      |   8
 6 files changed, 200 insertions(+), 77 deletions(-)
diff --git a/tensorflow/go/example_inception_inference_test.go b/tensorflow/go/example_inception_inference_test.go
index 09c7004468..51ad652e78 100644
--- a/tensorflow/go/example_inception_inference_test.go
+++ b/tensorflow/go/example_inception_inference_test.go
@@ -19,8 +19,6 @@ import (
"bufio"
"flag"
"fmt"
- "image"
- _ "image/jpeg"
"io"
"io/ioutil"
"log"
@@ -36,6 +34,8 @@ func Example() {
// An example for using the TensorFlow Go API for image recognition
// using a pre-trained inception model (http://arxiv.org/abs/1512.00567).
//
+ // Sample usage: <program> -dir=/tmp/modeldir -image=/path/to/some/jpeg
+ //
// The pre-trained model takes input in the form of a 4-dimensional
// tensor with shape [ BATCH_SIZE, IMAGE_HEIGHT, IMAGE_WIDTH, 3 ],
// where:
@@ -63,7 +63,7 @@ func Example() {
// form suitable for the model (for example, resizing the image)
// - Creates and executes a Session to obtain a Tensor in this normalized form.
modeldir := flag.String("dir", "", "Directory containing the trained model files. The directory will be created and the model downloaded into it if necessary")
- imagefile := flag.String("image", "", "Path of the image to extract labels for")
+ imagefile := flag.String("image", "", "Path of a JPEG image to extract labels for")
flag.Parse()
if *modeldir == "" || *imagefile == "" {
flag.Usage()
@@ -92,10 +92,10 @@ func Example() {
}
defer session.Close()
- // Run inference on thestImageFilename.
+ // Run inference on *imagefile.
// For multiple images, session.Run() can be called in a loop (and
- // concurrently). Furthermore, images can be batched together since the
- // model accepts batches of image data as input.
+ // concurrently). Alternatively, images can be batched since the model
+ // accepts batches of image data as input.
tensor, err := makeTensorFromImage(*imagefile)
if err != nil {
log.Fatal(err)
@@ -125,8 +125,8 @@ func printBestLabel(probabilities []float32, labelsFile string) {
bestIdx = i
}
}
- // Found a best match, now read the string from the labelsFile where
- // there is one line per label.
+ // Found the best match. Read the string from labelsFile, which
+ // contains one line per label.
file, err := os.Open(labelsFile)
if err != nil {
log.Fatal(err)
@@ -143,34 +143,14 @@ func printBestLabel(probabilities []float32, labelsFile string) {
fmt.Printf("BEST MATCH: (%2.0f%% likely) %s\n", probabilities[bestIdx]*100.0, labels[bestIdx])
}
-// Conver the image in filename to a Tensor suitable as input to the Inception model.
+// Convert the image in filename to a Tensor suitable as input to the Inception model.
func makeTensorFromImage(filename string) (*tf.Tensor, error) {
- // Load the pixels from the file
- file, err := os.Open(filename)
- if err != nil {
- return nil, err
- }
- img, _, err := image.Decode(file)
- file.Close()
+ bytes, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
- // Represent the image as [H][W][B,G,R]byte
- contents := make([][][3]byte, img.Bounds().Size().Y)
- for y := 0; y < len(contents); y++ {
- contents[y] = make([][3]byte, img.Bounds().Size().X)
- for x := 0; x < len(contents[y]); x++ {
- px := x + img.Bounds().Min.X
- py := y + img.Bounds().Min.Y
- r, g, b, _ := img.At(px, py).RGBA()
- // image.Image uses 16-bits for each color.
- // We want 8-bits.
- contents[y][x][0] = byte(b >> 8)
- contents[y][x][1] = byte(g >> 8)
- contents[y][x][2] = byte(r >> 8)
- }
- }
- tensor, err := tf.NewTensor(contents)
+ // DecodeJpeg uses a scalar String-valued tensor as input.
+ tensor, err := tf.NewTensor(string(bytes))
if err != nil {
return nil, err
}
@@ -199,12 +179,9 @@ func makeTensorFromImage(filename string) (*tf.Tensor, error) {
// specific normalized format (a particular image size, shape of the input tensor,
// normalized pixel values etc.).
//
-// This function constructs a graph of TensorFlow operations which takes as input
-// the raw pixel values of an image in the form of a Tensor of shape [Height, Width, 3]
-// and returns a tensor suitable for input to the inception model.
-//
-// T[y][x] is the (Blue, Green, Red) values of the pixel at position (x, y) in the image,
-// with each color value represented as a single byte.
+// This function constructs a graph of TensorFlow operations which takes as
+// input a JPEG-encoded string and returns a tensor suitable as input to the
+// inception model.
func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output, err error) {
// Some constants specific to the pre-trained model at:
// https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
@@ -212,27 +189,25 @@ func constructGraphToNormalizeImage() (graph *tf.Graph, input, output tf.Output,
// - The model was trained with images scaled to 224x224 pixels.
// - The colors, represented as R, G, B in 1 byte each, were converted to
// float using (value - Mean)/Scale.
- //
- // If using a different pre-trained model, the values will have to be adjusted.
const (
H, W = 224, 224
Mean = float32(117)
Scale = float32(1)
)
- // - input is a 3D tensor of shape [Height, Width, Colors=3], where
- // each pixel is represented as a triplet of 1-byte colors
- // - ResizeBilinear (and the inception model) takes a 4D tensor of shape
+ // - input is a String-Tensor, where the string is the JPEG-encoded image.
+ // - The inception model takes a 4D tensor of shape
// [BatchSize, Height, Width, Colors=3], where each pixel is
// represented as a triplet of floats
// - Apply normalization on each pixel and use ExpandDims to make
// this single image be a "batch" of size 1 for ResizeBilinear.
s := op.NewScope()
- input = op.Placeholder(s, tf.Uint8)
+ input = op.Placeholder(s, tf.String)
output = op.Div(s,
op.Sub(s,
op.ResizeBilinear(s,
op.ExpandDims(s,
- op.Cast(s, input, tf.Float),
+ op.Cast(s,
+ op.DecodeJpeg(s, input, op.DecodeJpegChannels(3)), tf.Float),
op.Const(s.SubScope("make_batch"), int32(0))),
op.Const(s.SubScope("size"), []int32{H, W})),
op.Const(s.SubScope("mean"), Mean)),
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go
index 5d5aa26992..75c111e957 100644
--- a/tensorflow/go/genop/internal/genop.go
+++ b/tensorflow/go/genop/internal/genop.go
@@ -387,7 +387,7 @@ func goType(tfType string) (string, error) {
var gotype string
switch tfType {
case "int":
- gotype = "int"
+ gotype = "int64"
case "float":
gotype = "float32"
case "bool":
diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go
index b3bcd9db05..c3057e9119 100644
--- a/tensorflow/go/genop/internal/genop_test.go
+++ b/tensorflow/go/genop/internal/genop_test.go
@@ -188,7 +188,7 @@ type DecodeJpegAttr func(optionalAttr)
//
// value: Number of color channels for the decoded image.
// If not specified, defaults to i:0
-func DecodeJpegChannels(value int) DecodeJpegAttr {
+func DecodeJpegChannels(value int64) DecodeJpegAttr {
return func(m optionalAttr) {
m["channels"] = value
}
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index 0d3660995b..14ecca402b 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -119,6 +119,46 @@ func TestSessionRunConcat(t *testing.T) {
}
}
+func TestSessionWithStringTensors(t *testing.T) {
+ // Construct the graph:
+ //   AsString(StringToHashBucketFast("PleaseHashMe"))
+ // This would be much prettier using the op package, but in this package
+ // graphs are constructed from first principles (see the sketch after this
+ // file's diff).
+ var (
+ g = NewGraph()
+ feed, _ = Const(g, "input", "PleaseHashMe")
+ hash, _ = g.AddOperation(OpSpec{
+ Type: "StringToHashBucketFast",
+ Input: []Input{feed},
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(1 << 32),
+ },
+ })
+ str, _ = g.AddOperation(OpSpec{
+ Type: "AsString",
+ Input: []Input{hash.Output(0)},
+ })
+ )
+ s, err := NewSession(g, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ output, err := s.Run(nil, []Output{str.Output(0)}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(output) != 1 {
+ t.Fatal(len(output))
+ }
+ got, ok := output[0].Value().(string)
+ if !ok {
+ t.Fatalf("Got %T, wanted string", output[0].Value())
+ }
+ if want := "1027741475"; got != want {
+ t.Fatalf("Got %q, want %q", got, want)
+ }
+}
+
func TestConcurrency(t *testing.T) {
tensor, err := NewTensor(int64(1))
if err != nil {
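
The comment in TestSessionWithStringTensors notes that the graph would read better with the op package. A hedged sketch of that version follows, assuming the generated wrappers include StringToHashBucketFast and AsString and that op.Const now accepts strings via the new NewTensor support; the int64 num_buckets argument also reflects the genop "int" to int64 mapping above.

// Sketch only: the same AsString(StringToHashBucketFast(...)) graph built
// with the generated op wrappers instead of raw OpSpecs.
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	feed := op.Const(s, "PleaseHashMe")
	// num_buckets is an "int" attribute, hence int64 after the genop change.
	str := op.AsString(s, op.StringToHashBucketFast(s, feed, int64(1<<32)))
	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()
	out, err := sess.Run(nil, []tf.Output{str}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out[0].Value().(string)) // "1027741475" per the test above
}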
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 12ec9c85fb..f755e9d4f8 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -14,6 +14,7 @@
package tensorflow
+// #include <stdlib.h>
// #include <string.h>
// #include "tensorflow/c/c_api.h"
import "C"
@@ -70,11 +71,13 @@ func NewTensor(value interface{}) (*Tensor, error) {
if err != nil {
return nil, err
}
+ nflattened := numElements(shape)
+ nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
if dataType == String {
- // TODO(ashankar): Handle this
- return nil, fmt.Errorf("String Tensors are not currently supported")
+ // TF_STRING tensors are encoded as an array of 8-byte offsets
+ // followed by string data. See c_api.h.
+ nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
}
- nbytes := byteSizeOf(dataType, shape)
var shapePtr *C.int64_t
if len(shape) > 0 {
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
@@ -86,11 +89,21 @@ func NewTensor(value interface{}) (*Tensor, error) {
runtime.SetFinalizer(t, (*Tensor).finalize)
raw := tensorData(t.c)
buf := bytes.NewBuffer(raw[:0:len(raw)])
- if err := encodeTensor(buf, val); err != nil {
- return nil, err
- }
- if uintptr(buf.Len()) != nbytes {
- return nil, fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v bytes, version %v", dataType, shape, nbytes, buf.Len(), Version())
+ if dataType != String {
+ if err := encodeTensor(buf, val); err != nil {
+ return nil, err
+ }
+ if uintptr(buf.Len()) != nbytes {
+ return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
+ }
+ } else {
+ e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()}
+ if err := e.encode(reflect.ValueOf(value)); err != nil {
+ return nil, err
+ }
+ if int64(buf.Len()) != nflattened*8 {
+ return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
+ }
}
return t, nil
}
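
To make the size computation above concrete, here is a small worked sketch of the TF_STRING layout (an assumption-laden illustration, not part of the change: it takes each string to be shorter than 128 bytes, so the varint length prefix written by TF_StringEncode is a single byte).

// Sketch only: byte budget for a TF_STRING tensor holding []string{"ab", "cde"}.
package main

import "fmt"

func main() {
	value := []string{"ab", "cde"}
	// Layout per c_api.h: one uint64 offset per element (into the data
	// section), followed by the encoded strings packed back to back.
	nbytes := 8 * len(value)
	offset := 0
	for i, s := range value {
		encoded := 1 + len(s) // 1-byte varint length prefix for short strings
		fmt.Printf("element %d: offset %d, encoded size %d\n", i, offset, encoded)
		offset += encoded
		nbytes += encoded
	}
	fmt.Println("total buffer size:", nbytes) // 16 + 3 + 4 = 23 bytes
}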
@@ -126,13 +139,19 @@ func (t *Tensor) Shape() []int64 { return t.shape }
// Tensor(int64, 0): int64
// Tensor(float64, 3): [][][]float64
func (t *Tensor) Value() interface{} {
- typ, err := typeOf(t.DataType(), t.Shape())
- if err != nil {
- panic(err)
- }
+ typ := typeOf(t.DataType(), t.Shape())
val := reflect.New(typ)
- if err := decodeTensor(bytes.NewReader(tensorData(t.c)), t.Shape(), typ, val); err != nil {
- panic(err)
+ raw := tensorData(t.c)
+ if t.DataType() != String {
+ if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil {
+ panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err))
+ }
+ } else {
+ nflattened := numElements(t.Shape())
+ d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()}
+ if err := d.decode(val, t.Shape()); err != nil {
+ panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err))
+ }
}
return reflect.Indirect(val).Interface()
}
@@ -194,7 +213,7 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
}
// typeOf converts from a DataType and Shape to the equivalent Go type.
-func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
+func typeOf(dt DataType, shape []int64) reflect.Type {
var ret reflect.Type
for _, t := range types {
if dt == DataType(t.dataType) {
@@ -203,32 +222,39 @@ func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
}
}
if ret == nil {
- return nil, fmt.Errorf("DataType %v unsupported", dt)
+ panic(bug("DataType %v is not supported", dt))
}
for _ = range shape {
ret = reflect.SliceOf(ret)
}
- return ret, nil
+ return ret
}
-// byteSizeOf returns the size (in bytes) of the raw encoding of a tensor with
-// the given shape and DataType. Only meant for non-String tensors.
-func byteSizeOf(dt DataType, shape []int64) uintptr {
- var size uintptr
- for _, t := range types {
- if DataType(t.dataType) == dt {
- size = t.typ.Size()
- break
- }
- }
+func numElements(shape []int64) int64 {
+ n := int64(1)
for _, d := range shape {
- size *= uintptr(d)
+ n *= d
+ }
+ return n
+}
+
+// byteSizeOfEncodedStrings returns the size of the encoded strings in val.
+// val MUST be a string, or a container (array/slice etc.) of strings.
+func byteSizeOfEncodedStrings(val interface{}) uintptr {
+ if s, ok := val.(string); ok {
+ return uintptr(C.TF_StringEncodedSize(C.size_t(len(s))))
+ }
+ // Otherwise must be an array or slice.
+ var size uintptr
+ v := reflect.ValueOf(val)
+ for i := 0; i < v.Len(); i++ {
+ size += byteSizeOfEncodedStrings(v.Index(i).Interface())
}
return size
}
// encodeTensor writes v to the specified buffer using the format specified in
-// c_api.h.
+// c_api.h. Use stringEncoder for String tensors.
func encodeTensor(w io.Writer, v reflect.Value) error {
switch v.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
@@ -262,7 +288,7 @@ func encodeTensor(w io.Writer, v reflect.Value) error {
}
// decodeTensor decodes the Tensor from the buffer to ptr using the format
-// specified in c_api.h
+// specified in c_api.h. Use stringDecoder for String tensors.
func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
@@ -285,6 +311,80 @@ func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Valu
return nil
}
+type stringEncoder struct {
+ offsets io.Writer
+ data []byte
+ offset uint64
+ status *status
+}
+
+func (e *stringEncoder) encode(v reflect.Value) error {
+ if v.Kind() == reflect.String {
+ if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
+ return err
+ }
+ var (
+ s = v.Interface().(string)
+ src = C.CString(s)
+ srcLen = C.size_t(len(s))
+ dst = (*C.char)(unsafe.Pointer(&e.data[e.offset]))
+ dstLen = C.size_t(uint64(len(e.data)) - e.offset)
+ )
+ e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c))
+ C.free(unsafe.Pointer(src))
+ return e.status.Err()
+ }
+ for i := 0; i < v.Len(); i++ {
+ if err := e.encode(v.Index(i)); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+type stringDecoder struct {
+ offsets io.Reader
+ data []byte
+ status *status
+}
+
+func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error {
+ if len(shape) == 0 {
+ var offset uint64
+ if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil {
+ return err
+ }
+ // Validate the offset before it is used to index d.data.
+ if offset >= uint64(len(d.data)) {
+ return fmt.Errorf("invalid offsets in String Tensor")
+ }
+ var (
+ src = (*C.char)(unsafe.Pointer(&d.data[offset]))
+ srcLen = C.size_t(len(d.data)) - C.size_t(offset)
+ dst *C.char
+ dstLen C.size_t
+ )
+ C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c)
+ if err := d.status.Err(); err != nil {
+ return err
+ }
+ s := ptr.Interface().(*string)
+ *s = C.GoStringN(dst, C.int(dstLen))
+ return nil
+ }
+ val := reflect.Indirect(ptr)
+ val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0])))
+ for i := 0; i < val.Len(); i++ {
+ if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func bug(format string, args ...interface{}) error {
+ return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
+}
+
// nativeEndian is the byte order for the local platform. Used to send back and
// forth Tensors with the C API. We test for endianness at runtime because
// some architectures can be booted into different endian modes.
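
End to end, the new paths in NewTensor and Value combine into a simple round trip. A minimal sketch, using only the API shown in this diff:

// Sketch only: create a String tensor and read it back through Value().
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func main() {
	t, err := tf.NewTensor([][]string{{"hello", "world"}, {"tensor", "flow"}})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(t.DataType() == tf.String) // true
	fmt.Println(t.Shape())                 // [2 2]
	fmt.Println(t.Value().([][]string))    // [[hello world] [tensor flow]]
}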
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index d5f3f74bfd..073da0cc6e 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -35,8 +35,11 @@ func TestNewTensor(t *testing.T) {
{nil, float64(5)},
{nil, complex(float32(5), float32(6))},
{nil, complex(float64(5), float64(6))},
+ {nil, "a string"},
{[]int64{1}, []float64{1}},
{[]int64{1}, [1]float64{1}},
+ {[]int64{2}, []string{"string", "slice"}},
+ {[]int64{2}, [2]string{"string", "array"}},
{[]int64{3, 2}, [][]float64{{1, 2}, {3, 4}, {5, 6}}},
{[]int64{2, 3}, [2][3]float64{{1, 2, 3}, {3, 4, 6}}},
{[]int64{4, 3, 2}, [][][]float64{
@@ -46,6 +49,11 @@ func TestNewTensor(t *testing.T) {
{{-6, -7}, {-8, -9}, {-10, -11}},
}},
{[]int64{2, 0}, [][]int64{{}, {}}},
+ {[]int64{2, 2}, [][]string{{"row0col0", "row0,col1"}, {"row1col0", "row1,col1"}}},
+ {[]int64{2, 3}, [2][3]string{
+ {"row0col0", "row0,col1", "row0,col2"},
+ {"row1col0", "row1,col1", "row1,col2"},
+ }},
}
var errorTests = []interface{}{