about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
author: Asim Shankar <ashankar@google.com>, 2016-11-21 11:24:53 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2016-11-21 11:44:14 -0800
commit: c80be35454a418c18b3fd57614bfcb5265274c33 (patch)
tree: 975b50379b742f7913eb6d2c14ab8fdfc552d22d /tensorflow/go/tensor.go
parent: 074acf38d83bf4be1e3fe2bb813d4bf32b97c2ac (diff)
Go: Support for String tensors.
Use this support to simplify the Inception example, which can now use the DecodeJpeg op. Also fix a bug in the generated op functions: a TensorFlow "int" is a Go "int64". Another step in #10. Change: 139809489
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r-- tensorflow/go/tensor.go 160
1 files changed, 130 insertions, 30 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 12ec9c85fb..f755e9d4f8 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -14,6 +14,7 @@
package tensorflow
+// #include <stdlib.h>
// #include <string.h>
// #include "tensorflow/c/c_api.h"
import "C"
@@ -70,11 +71,13 @@ func NewTensor(value interface{}) (*Tensor, error) {
if err != nil {
return nil, err
}
+ nflattened := numElements(shape)
+ nbytes := typeOf(dataType, nil).Size() * uintptr(nflattened)
if dataType == String {
- // TODO(ashankar): Handle this
- return nil, fmt.Errorf("String Tensors are not currently supported")
+ // TF_STRING tensors are encoded as an array of 8-byte offsets
+ // followed by string data. See c_api.h.
+ nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
}
- nbytes := byteSizeOf(dataType, shape)
var shapePtr *C.int64_t
if len(shape) > 0 {
shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
@@ -86,11 +89,21 @@ func NewTensor(value interface{}) (*Tensor, error) {
runtime.SetFinalizer(t, (*Tensor).finalize)
raw := tensorData(t.c)
buf := bytes.NewBuffer(raw[:0:len(raw)])
- if err := encodeTensor(buf, val); err != nil {
- return nil, err
- }
- if uintptr(buf.Len()) != nbytes {
- return nil, fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v bytes, version %v", dataType, shape, nbytes, buf.Len(), Version())
+ if dataType != String {
+ if err := encodeTensor(buf, val); err != nil {
+ return nil, err
+ }
+ if uintptr(buf.Len()) != nbytes {
+ return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
+ }
+ } else {
+ e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()}
+ if e.encode(reflect.ValueOf(value)); err != nil {
+ return nil, err
+ }
+ if int64(buf.Len()) != nflattened*8 {
+ return nil, bug("invalid offset encoding for TF_STRING tensor with shape %v (got %v, want %v)", shape, buf.Len(), nflattened*8)
+ }
}
return t, nil
}
@@ -126,13 +139,19 @@ func (t *Tensor) Shape() []int64 { return t.shape }
// Tensor(int64, 0): int64
// Tensor(float64, 3): [][][]float64
func (t *Tensor) Value() interface{} {
- typ, err := typeOf(t.DataType(), t.Shape())
- if err != nil {
- panic(err)
- }
+ typ := typeOf(t.DataType(), t.Shape()) // typeOf now panics itself on an unsupported DataType, so there is no error to check here
val := reflect.New(typ) // pointer to a zero value of the Go type; the decoders below fill it in
- if err := decodeTensor(bytes.NewReader(tensorData(t.c)), t.Shape(), typ, val); err != nil {
- panic(err)
+ raw := tensorData(t.c) // byte slice obtained from the C tensor (tensorData is defined elsewhere in this file)
+ if t.DataType() != String { // every non-String type goes through the generic decodeTensor
+ if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil {
+ panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err)) // a decode failure is reported as a bug-report error and panics
+ }
+ } else { // String tensors: an array of 8-byte offsets (one per element) followed by the string data
+ nflattened := numElements(t.Shape()) // total number of strings, hence the number of 8-byte offsets
+ d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()} // split raw into the offset table and the string data that follows it
+ if err := d.decode(val, t.Shape()); err != nil {
+ panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err)) // same panic-with-bug-report policy as the non-String branch
+ }
}
return reflect.Indirect(val).Interface() // dereference the pointer created by reflect.New and hand back the decoded value
}
@@ -194,7 +213,7 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
}
// typeOf converts from a DataType and Shape to the equivalent Go type.
-func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
+func typeOf(dt DataType, shape []int64) reflect.Type {
var ret reflect.Type
for _, t := range types {
if dt == DataType(t.dataType) {
@@ -203,32 +222,39 @@ func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
}
}
if ret == nil {
- return nil, fmt.Errorf("DataType %v unsupported", dt)
+ panic(bug("DataType %v is not supported", dt))
}
for _ = range shape {
ret = reflect.SliceOf(ret)
}
- return ret, nil
+ return ret
}
-// byteSizeOf returns the size (in bytes) of the raw encoding of a tensor with
-// the given shape and DataType. Only meant for non-String tensors.
-func byteSizeOf(dt DataType, shape []int64) uintptr {
- var size uintptr
- for _, t := range types {
- if DataType(t.dataType) == dt {
- size = t.typ.Size()
- break
- }
- }
+func numElements(shape []int64) int64 { // numElements returns the product of all dimensions in shape.
+ n := int64(1) // empty product: an empty shape yields 1
for _, d := range shape {
- size *= uintptr(d)
+ n *= d // plain int64 multiplication; no overflow check is performed
+ }
+ return n
+}
+
+// byteSizeOfEncodedStrings returns the sum, over every string contained in val, of the size that TF_StringEncodedSize reports for that string's length. The 8-byte-per-element offset table is not included; NewTensor adds it separately.
+// val MUST be a string, or a (possibly nested) array or slice of strings; any other value reaches the reflection path below, where Len or Index panics.
+func byteSizeOfEncodedStrings(val interface{}) uintptr {
+ if s, ok := val.(string); ok { // base case: a single string
+ return uintptr(C.TF_StringEncodedSize(C.size_t(len(s))))
+ }
+ // Otherwise val must be an array or slice: recurse into each element and add up the sizes.
+ var size uintptr
+ v := reflect.ValueOf(val)
+ for i := 0; i < v.Len(); i++ {
+ size += byteSizeOfEncodedStrings(v.Index(i).Interface()) // Interface() re-boxes the element so that the type assertion above applies to it
}
return size
}
// encodeTensor writes v to the specified buffer using the format specified in
-// c_api.h.
+// c_api.h. Use stringEncoder for String tensors.
func encodeTensor(w io.Writer, v reflect.Value) error {
switch v.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
@@ -262,7 +288,7 @@ func encodeTensor(w io.Writer, v reflect.Value) error {
}
// decodeTensor decodes the Tensor from the buffer to ptr using the format
-// specified in c_api.h
+// specified in c_api.h. Use stringDecoder for String tensors.
func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
@@ -285,6 +311,80 @@ func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Valu
return nil
}
+// stringEncoder writes the contents of a String tensor: for every string, one
+// 8-byte offset is written to offsets and the encoded string bytes are placed
+// in data at that offset. See c_api.h for the TF_STRING layout.
+type stringEncoder struct {
+	offsets io.Writer // receives one native-endian uint64 offset per string
+	data    []byte    // destination buffer for the encoded string bytes
+	offset  uint64    // index into data at which the next string is encoded
+	status  *status   // receives the outcome of each TF_StringEncode call
+}
+
+// encode writes v to the encoder. v must be of string kind, or an array/slice
+// (possibly nested) whose leaves are of string kind. For each string the
+// current offset is written to e.offsets, the string is encoded into e.data
+// with TF_StringEncode, and e.offset advances by the value that call returns.
+//
+// It returns the first error encountered: a failed offset write, a non-OK
+// status from TF_StringEncode, or an exhausted data buffer.
+func (e *stringEncoder) encode(v reflect.Value) error {
+	if v.Kind() != reflect.String {
+		// Container: encode the elements in order.
+		for i := 0; i < v.Len(); i++ {
+			if err := e.encode(v.Index(i)); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+	// Taking &e.data[e.offset] below would panic with an index-out-of-range
+	// error if the buffer were already full (for example because its size was
+	// miscalculated). Report that as an error instead.
+	if e.offset >= uint64(len(e.data)) {
+		return fmt.Errorf("no space left to encode string at offset %d of a %d byte buffer", e.offset, len(e.data))
+	}
+	if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
+		return err
+	}
+	// v.String() rather than v.Interface().(string): the type assertion panics
+	// for named string types, while String() works for any value of string
+	// kind.
+	s := v.String()
+	src := C.CString(s)
+	defer C.free(unsafe.Pointer(src))
+	var (
+		srcLen = C.size_t(len(s))
+		dst    = (*C.char)(unsafe.Pointer(&e.data[e.offset]))
+		dstLen = C.size_t(uint64(len(e.data)) - e.offset)
+	)
+	e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c))
+	return e.status.Err()
+}
+
+// stringDecoder reads the contents of a String tensor: offsets yields one
+// 8-byte offset per string, and each offset locates that string's encoded
+// bytes within data. See c_api.h for the TF_STRING layout.
+type stringDecoder struct {
+	offsets io.Reader // source of one native-endian uint64 offset per string
+	data    []byte    // encoded string bytes that the offsets index into
+	status  *status   // receives the outcome of each TF_StringDecode call
+}
+
+// decode fills the value that ptr points to with the strings described by
+// shape. With an empty shape, ptr must hold a *string: the next offset is read
+// and the string at that offset is decoded with TF_StringDecode. Otherwise a
+// slice of length shape[0] is allocated and every element is decoded
+// recursively with shape[1:].
+//
+// It returns an error if an offset cannot be read, if an offset does not fall
+// inside d.data, if TF_StringDecode reports a non-OK status, or if the decoded
+// length cannot be passed to C.GoStringN.
+func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error {
+	if len(shape) > 0 {
+		val := reflect.Indirect(ptr)
+		val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0])))
+		for i := 0; i < val.Len(); i++ {
+			if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+	var offset uint64
+	if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil {
+		return err
+	}
+	// Validate the offset BEFORE it is used to index d.data. Indexing first
+	// would panic on a bad offset and the error below could never be
+	// returned. offset == len(d.data) is rejected as well, since
+	// &d.data[len(d.data)] is out of range too.
+	if offset >= uint64(len(d.data)) {
+		return fmt.Errorf("invalid offsets in String Tensor")
+	}
+	var (
+		src    = (*C.char)(unsafe.Pointer(&d.data[offset]))
+		srcLen = C.size_t(uint64(len(d.data)) - offset)
+		dst    *C.char
+		dstLen C.size_t
+	)
+	C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c)
+	if err := d.status.Err(); err != nil {
+		return err
+	}
+	// C.GoStringN takes a C.int length; refuse lengths that would overflow it
+	// rather than silently converting to a wrong (possibly negative) value.
+	if uint64(dstLen) > 1<<31-1 {
+		return fmt.Errorf("string of %d bytes in String Tensor is too large to decode", uint64(dstLen))
+	}
+	s := ptr.Interface().(*string)
+	*s = C.GoStringN(dst, C.int(dstLen))
+	return nil
+}
+
+func bug(format string, args ...interface{}) error {
+ return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
+}
+
// nativeEndian is the byte order for the local platform. Used to send back and
// forth Tensors with the C API. We test for endianness at runtime because
// some architectures can be booted into different endian modes.