diff options
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r-- | tensorflow/go/tensor.go | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 36a74c0081..1326a95278 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -101,7 +101,7 @@ func NewTensor(value interface{}) (*Tensor, error) { return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) } } else { - e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()} + e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()} if err := e.encode(reflect.ValueOf(value), shape); err != nil { return nil, err } @@ -207,6 +207,9 @@ func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) { func tensorData(c *C.TF_Tensor) []byte { // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices cbytes := C.TF_TensorData(c) + if cbytes == nil { + return nil + } length := int(C.TF_TensorByteSize(c)) slice := (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length] return slice @@ -310,7 +313,7 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { if err := w.WriteByte(b); err != nil { return err } - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { return err } @@ -349,7 +352,7 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. return err } ptr.Elem().SetBool(b == 1) - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { return err } |