aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r--tensorflow/go/tensor.go9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index 36a74c0081..1326a95278 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -101,7 +101,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
}
} else {
- e := stringEncoder{offsets: buf, data: raw[nflattened*8 : len(raw)], status: newStatus()}
+ e := stringEncoder{offsets: buf, data: raw[nflattened*8:], status: newStatus()}
if err := e.encode(reflect.ValueOf(value), shape); err != nil {
return nil, err
}
@@ -207,6 +207,9 @@ func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
func tensorData(c *C.TF_Tensor) []byte {
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
cbytes := C.TF_TensorData(c)
+ if cbytes == nil {
+ return nil
+ }
length := int(C.TF_TensorByteSize(c))
slice := (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length]
return slice
@@ -310,7 +313,7 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if err := w.WriteByte(b); err != nil {
return err
}
- case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
}
@@ -349,7 +352,7 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
return err
}
ptr.Elem().SetBool(b == 1)
- case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}