aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-11-15 00:15:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-15 00:24:28 -0800
commit1f0c5119a0230c5160d45496175b9256f097e144 (patch)
tree107fd56ab6f05ce743d1867893c4dd75bd7a98f4 /tensorflow/go/tensor.go
parent1791aeef881915d75c1d077ae77a205e1bc2d51c (diff)
C API: Do not take ownership of the TF_Tensors
Prior to this change, TF_*Run, TF_SetAttrTensor and TF_SetAttrTensorList took ownership of the TF_Tensor*s of the feeds. This can make performance client languages bothersome when the same Tensor is repeatedly fed into multiple session executions as the memory for the feed tensor would need to be re-allocated and filled in every time. With this change, these functions no longer take ownership of the TF_Tensor*. The changes to the Go API implementation reflect the claimed benefits. (Another step towards #10) Change: 139169388
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r--tensorflow/go/tensor.go102
1 files changed, 45 insertions, 57 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index a12a732258..9b93590b6a 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -22,7 +22,9 @@ import (
"bytes"
"encoding/binary"
"fmt"
+ "io"
"reflect"
+ "runtime"
"unsafe"
)
@@ -55,16 +57,7 @@ const (
// Tensor holds a multi-dimensional array of elements of a single data type.
type Tensor struct {
- // We create TF_Tensor on demand rather than keep a handle to C.TF_Tensor
- // because many functions, such as Session.Run() and Operations take
- // ownership of the C.TF_Tensor. Translating on-demand provides for a safe
- // API.
- //
- // A memcpy is required because cgo rules prohibit us from maintaining
- // a pointer to Go memory.
- // call: https://golang.org/cmd/cgo/
- buf *bytes.Buffer
- dt DataType
+ c *C.TF_Tensor
shape []int64
}
@@ -77,37 +70,50 @@ func NewTensor(value interface{}) (*Tensor, error) {
if err != nil {
return nil, err
}
- t := &Tensor{buf: bytes.NewBuffer(nil), dt: dataType, shape: make([]int64, dims)}
- if err = encodeTensor(t.buf, t.shape, val); err != nil {
+ // TODO(ashankar): Remove the bytes.Buffer and endcode directly into
+ // C-memory, avoiding the memcpy and cutting down memory usage in half.
+ shape := make([]int64, dims)
+ buf := new(bytes.Buffer)
+ if err := encodeTensor(buf, shape, val); err != nil {
return nil, err
}
+ var shapePtr *C.int64_t
+ if len(shape) > 0 {
+ shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
+ }
+ t := &Tensor{
+ c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(buf.Len())),
+ shape: shape,
+ }
+ runtime.SetFinalizer(t, (*Tensor).finalize)
+ if buf.Len() > 0 {
+ slice := buf.Bytes() // https://github.com/golang/go/issues/14210
+ C.memcpy(C.TF_TensorData(t.c), unsafe.Pointer(&slice[0]), C.size_t(buf.Len()))
+ }
return t, nil
}
-// newTensorFromC converts from a C.TF_Tensor to a Tensor.
-func newTensorFromC(ct *C.TF_Tensor) *Tensor {
- t := &Tensor{dt: DataType(C.TF_TensorType(ct))}
- numDims := int(C.TF_NumDims(ct))
- for i := 0; i < numDims; i++ {
- t.shape = append(t.shape, int64(C.TF_Dim(ct, C.int(i))))
+// newTensorFromC takes ownership of c and returns the owning Tensor.
+func newTensorFromC(c *C.TF_Tensor) *Tensor {
+ var shape []int64
+ if ndims := int(C.TF_NumDims(c)); ndims > 0 {
+ shape = make([]int64, ndims)
}
- b := make([]byte, int(C.TF_TensorByteSize(ct)))
- if len(b) > 0 {
- C.memcpy(unsafe.Pointer(&b[0]), C.TF_TensorData(ct), C.size_t(len(b)))
+ for i := range shape {
+ shape[i] = int64(C.TF_Dim(c, C.int(i)))
}
- t.buf = bytes.NewBuffer(b)
+ t := &Tensor{c: c, shape: shape}
+ runtime.SetFinalizer(t, (*Tensor).finalize)
return t
}
+func (t *Tensor) finalize() { C.TF_DeleteTensor(t.c) }
+
// DataType returns the scalar datatype of the Tensor.
-func (t *Tensor) DataType() DataType {
- return t.dt
-}
+func (t *Tensor) DataType() DataType { return DataType(C.TF_TensorType(t.c)) }
// Shape returns the shape of the Tensor.
-func (t *Tensor) Shape() []int64 {
- return t.shape
-}
+func (t *Tensor) Shape() []int64 { return t.shape }
// Value converts the Tensor to a Go value. For now, not all Tensor types are
// supported, and this function may panic if it encounters an unsupported
@@ -123,34 +129,16 @@ func (t *Tensor) Value() interface{} {
panic(err)
}
val := reflect.New(typ)
- if err := decodeTensor(t.buf, t.Shape(), typ, val); err != nil {
+ // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
+ cbytes := C.TF_TensorData(t.c)
+ length := int(C.TF_TensorByteSize(t.c))
+ slice := (*[1 << 30]byte)(unsafe.Pointer(cbytes))[:length:length]
+ if err := decodeTensor(bytes.NewReader(slice), t.Shape(), typ, val); err != nil {
panic(err)
}
return reflect.Indirect(val).Interface()
}
-// c converts the Tensor to a *C.TF_Tensor. Callers must take ownership of
-// the *C.TF_Tensor, either by passing ownership to the C API or explicitly
-// calling C.TF_DeleteTensor() on it.
-func (t *Tensor) c() *C.TF_Tensor {
- var shapePtr *C.int64_t
- if len(t.shape) > 0 {
- shapePtr = (*C.int64_t)(unsafe.Pointer(&t.shape[0]))
- }
- tensor := C.TF_AllocateTensor(C.TF_DataType(t.dt), shapePtr, C.int(len(t.shape)), C.size_t(t.buf.Len()))
- if t.buf.Len() > 0 {
- slice := t.buf.Bytes() // https://github.com/golang/go/issues/14210
- C.memcpy(C.TF_TensorData(tensor), unsafe.Pointer(&slice[0]), C.size_t(t.buf.Len()))
- }
- return tensor
-}
-
-// deleteCTensor only exists to delete C.TF_Tensors in tests. go test doesn't
-// support cgo.
-func deleteCTensor(ct *C.TF_Tensor) {
- C.TF_DeleteTensor(ct)
-}
-
var types = []struct {
typ reflect.Type
dataType C.TF_DataType
@@ -206,10 +194,10 @@ func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
// encodeTensor writes v to the specified buffer using the format specified in
// c_api.h
-func encodeTensor(buf *bytes.Buffer, shape []int64, v reflect.Value) error {
+func encodeTensor(w io.Writer, shape []int64, v reflect.Value) error {
switch v.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
- if err := binary.Write(buf, nativeEndian, v.Interface()); err != nil {
+ if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
}
@@ -227,7 +215,7 @@ func encodeTensor(buf *bytes.Buffer, shape []int64, v reflect.Value) error {
shape[0] = int64(v.Len())
for i := 0; i < v.Len(); i++ {
- err := encodeTensor(buf, shape[1:], v.Index(i))
+ err := encodeTensor(w, shape[1:], v.Index(i))
if err != nil {
return err
}
@@ -241,10 +229,10 @@ func encodeTensor(buf *bytes.Buffer, shape []int64, v reflect.Value) error {
// decodeTensor decodes the Tensor from the buffer to ptr using the format
// specified in c_api.h
-func decodeTensor(buf *bytes.Buffer, shape []int64, typ reflect.Type, ptr reflect.Value) error {
+func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
- if err := binary.Read(buf, nativeEndian, ptr.Interface()); err != nil {
+ if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}
@@ -252,7 +240,7 @@ func decodeTensor(buf *bytes.Buffer, shape []int64, typ reflect.Type, ptr reflec
val := reflect.Indirect(ptr)
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
for i := 0; i < val.Len(); i++ {
- if err := decodeTensor(buf, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
+ if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
return err
}
}