diff options
author | 2016-11-21 11:24:53 -0800 | |
---|---|---|
committer | 2016-11-21 11:44:14 -0800 | |
commit | c80be35454a418c18b3fd57614bfcb5265274c33 (patch) | |
tree | 975b50379b742f7913eb6d2c14ab8fdfc552d22d /tensorflow/go/tensor.go | |
parent | 074acf38d83bf4be1e3fe2bb813d4bf32b97c2ac (diff) |
Go: Support for String tensors.
This support is also used to simplify the Inception example, as it can now
use the DecodeJpeg op.
Also fixed a bug in generated op functions: a TensorFlow "int"
corresponds to a Go "int64".
Another step in #10
Change: 139809489
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r-- | tensorflow/go/tensor.go | 160 |
1 files changed, 130 insertions, 30 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 12ec9c85fb..f755e9d4f8 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -14,6 +14,7 @@ package tensorflow +// #include <stdlib.h> // #include <string.h> // #include "tensorflow/c/c_api.h" import "C" @@ -70,11 +71,13 @@ func NewTensor(value interface{}) (*Tensor, error) { if err != nil { return nil, err } + nflattened := numElements(shape) + nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened) if dataType == String { - // TODO(ashankar): Handle this - return nil, fmt.Errorf("String Tensors are not currently supported") + // TF_STRING tensors are encoded as an array of 8-byte offsets + // followed by string data. See c_api.h. + nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value) } - nbytes := byteSizeOf(dataType, shape) var shapePtr *C.int64_t if len(shape) > 0 { shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) @@ -86,11 +89,21 @@ func NewTensor(value interface{}) (*Tensor, error) { runtime.SetFinalizer(t, (*Tensor).finalize) raw := tensorData(t.c) buf := bytes.NewBuffer(raw[:0:len(raw)]) - if err := encodeTensor(buf, val); err != nil { - return nil, err - } - if uintptr(buf.Len()) != nbytes { - return nil, fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v bytes, version %v", dataType, shape, nbytes, buf.Len(), Version()) + if dataType != String { + if err := encodeTensor(buf, val); err != nil { + return nil, err + } + if uintptr(buf.Len()) != nbytes { + return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) + } + } else { + e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()} + if e.encode(reflect.ValueOf(value)); err != nil { + return nil, err + } + if 
int64(buf.Len()) != nflattened*8 { + return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8) + } } return t, nil } @@ -126,13 +139,19 @@ func (t *Tensor) Shape() []int64 { return t.shape } // Tensor(int64, 0): int64 // Tensor(float64, 3): [][][]float64 func (t *Tensor) Value() interface{} { - typ, err := typeOf(t.DataType(), t.Shape()) - if err != nil { - panic(err) - } + typ := typeOf(t.DataType(), t.Shape()) val := reflect.New(typ) - if err := decodeTensor(bytes.NewReader(tensorData(t.c)), t.Shape(), typ, val); err != nil { - panic(err) + raw := tensorData(t.c) + if t.DataType() != String { + if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil { + panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err)) + } + } else { + nflattened := numElements(t.Shape()) + d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()} + if err := d.decode(val, t.Shape()); err != nil { + panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err)) + } } return reflect.Indirect(val).Interface() } @@ -194,7 +213,7 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro } // typeOf converts from a DataType and Shape to the equivalent Go type. -func typeOf(dt DataType, shape []int64) (reflect.Type, error) { +func typeOf(dt DataType, shape []int64) reflect.Type { var ret reflect.Type for _, t := range types { if dt == DataType(t.dataType) { @@ -203,32 +222,39 @@ func typeOf(dt DataType, shape []int64) (reflect.Type, error) { } } if ret == nil { - return nil, fmt.Errorf("DataType %v unsupported", dt) + panic(bug("DataType %v is not supported", dt)) } for _ = range shape { ret = reflect.SliceOf(ret) } - return ret, nil + return ret } -// byteSizeOf returns the size (in bytes) of the raw encoding of a tensor with -// the given shape and DataType. 
Only meant for non-String tensors. -func byteSizeOf(dt DataType, shape []int64) uintptr { - var size uintptr - for _, t := range types { - if DataType(t.dataType) == dt { - size = t.typ.Size() - break - } - } +func numElements(shape []int64) int64 { + n := int64(1) for _, d := range shape { - size *= uintptr(d) + n *= d + } + return n +} + +// byteSizeOfEncodedStrings returns the size of the encoded strings in val. +// val MUST be a string, or a container (array/slice etc.) of strings. +func byteSizeOfEncodedStrings(val interface{}) uintptr { + if s, ok := val.(string); ok { + return uintptr(C.TF_StringEncodedSize(C.size_t(len(s)))) + } + // Otherwise must be an array or slice. + var size uintptr + v := reflect.ValueOf(val) + for i := 0; i < v.Len(); i++ { + size += byteSizeOfEncodedStrings(v.Index(i).Interface()) } return size } // encodeTensor writes v to the specified buffer using the format specified in -// c_api.h. +// c_api.h. Use stringEncoder for String tensors. func encodeTensor(w io.Writer, v reflect.Value) error { switch v.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: @@ -262,7 +288,7 @@ func encodeTensor(w io.Writer, v reflect.Value) error { } // decodeTensor decodes the Tensor from the buffer to ptr using the format -// specified in c_api.h +// specified in c_api.h. Use stringDecoder for String tensors. 
func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { switch typ.Kind() { case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: @@ -285,6 +311,80 @@ func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Valu return nil } +type stringEncoder struct { + offsets io.Writer + data []byte + offset uint64 + status *status +} + +func (e *stringEncoder) encode(v reflect.Value) error { + if v.Kind() == reflect.String { + if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil { + return err + } + var ( + s = v.Interface().(string) + src = C.CString(s) + srcLen = C.size_t(len(s)) + dst = (*C.char)(unsafe.Pointer(&e.data[e.offset])) + dstLen = C.size_t(uint64(len(e.data)) - e.offset) + ) + e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c)) + C.free(unsafe.Pointer(src)) + return e.status.Err() + } + for i := 0; i < v.Len(); i++ { + if err := e.encode(v.Index(i)); err != nil { + return err + } + } + return nil +} + +type stringDecoder struct { + offsets io.Reader + data []byte + status *status +} + +func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error { + if len(shape) == 0 { + var offset uint64 + if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil { + return err + } + var ( + src = (*C.char)(unsafe.Pointer(&d.data[offset])) + srcLen = C.size_t(len(d.data)) - C.size_t(offset) + dst *C.char + dstLen C.size_t + ) + if offset > uint64(len(d.data)) { + return fmt.Errorf("invalid offsets in String Tensor") + } + C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c) + if err := d.status.Err(); err != nil { + return err + } + s := ptr.Interface().(*string) + *s = C.GoStringN(dst, C.int(dstLen)) + return nil + } + val := reflect.Indirect(ptr) + val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0]))) + for i := 0; i < 
val.Len(); i++ { + if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil { + return err + } + } + return nil +} + +func bug(format string, args ...interface{}) error { + return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...)) +} + // nativeEndian is the byte order for the local platform. Used to send back and // forth Tensors with the C API. We test for endianness at runtime because // some architectures can be booted into different endian modes. |