aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2016-08-23 09:01:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 10:04:53 -0700
commit783c52edeb3c676937dbb97ed0d40958015050d6 (patch)
tree80c74954f68dad26a6e76a1c0edcb979d4d1804c /tensorflow/go/tensor.go
parent096069687c52e16eaa18c1db6e7bbf2737639257 (diff)
Initial version of the Go API. The API is subject to change.
Remaining work to do: - Generated ops. - Generated protocol buffers. - A few calls requiring protocol buffers aren't in this change. Change: 131066649
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r--tensorflow/go/tensor.go259
1 files changed, 259 insertions, 0 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
new file mode 100644
index 0000000000..e364e80f86
--- /dev/null
+++ b/tensorflow/go/tensor.go
@@ -0,0 +1,259 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include <string.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// DataType holds the type for a scalar value. E.g., one slot in a tensor.
+// The values here are identical to corresponding values in types.proto.
+type DataType C.TF_DataType
+
+// Tensor holds a multi-dimensional array of elements of a single data type.
+type Tensor struct {
+ // We create TF_Tensor on demand rather than keep a handle to C.TF_Tensor
+ // because many functions, such as Session.Run() and Operations take
+ // ownership of the C.TF_Tensor. Translating on-demand provides for a safe
+ // API.
+ //
+ // A memcpy is required because cgo rules prohibit us from maintaining
+ // a pointer to Go memory.
+ // call: https://golang.org/cmd/cgo/
+ buf *bytes.Buffer
+ dt DataType
+ shape []int64
+}
+
+// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
+// slices, and arrays. Every element of a slice must have the same length so
+// that the resulting Tensor has a valid shape.
+func NewTensor(value interface{}) (*Tensor, error) {
+ val := reflect.ValueOf(value)
+ rank, dataType, err := rankAndDataTypeOf(val.Type())
+ if err != nil {
+ return nil, err
+ }
+ t := &Tensor{buf: bytes.NewBuffer(nil), dt: dataType, shape: make([]int64, rank)}
+ if err = encodeTensor(t.buf, t.shape, val); err != nil {
+ return nil, err
+ }
+ return t, nil
+}
+
+// newTensorFromC converts from a C.TF_Tensor to a Tensor.
+func newTensorFromC(ct *C.TF_Tensor) *Tensor {
+ t := &Tensor{dt: DataType(C.TF_TensorType(ct))}
+ numDims := int(C.TF_NumDims(ct))
+ for i := 0; i < numDims; i++ {
+ t.shape = append(t.shape, int64(C.TF_Dim(ct, C.int(i))))
+ }
+ b := make([]byte, int(C.TF_TensorByteSize(ct)))
+ if len(b) > 0 {
+ C.memcpy(unsafe.Pointer(&b[0]), C.TF_TensorData(ct), C.size_t(len(b)))
+ }
+ t.buf = bytes.NewBuffer(b)
+ return t
+}
+
+// DataType returns the scalar datatype of the Tensor.
+func (t *Tensor) DataType() DataType {
+ return t.dt
+}
+
+// Shape returns the shape of the Tensor.
+func (t *Tensor) Shape() []int64 {
+ return t.shape
+}
+
+// Value converts the Tensor to a Go value. For now, not all Tensor types are
+// supported, and this function may panic if it encounters an unsupported
+// DataType.
+//
+// The type of the output depends on the Tensor type and rank. For example:
+// Tensor(int64, 0): int64
+// Tensor(float64, 3): [][][]float64
+func (t *Tensor) Value() interface{} {
+ typ, err := typeOf(t.DataType(), t.Shape())
+ if err != nil {
+ panic(err)
+ }
+ val := reflect.New(typ)
+ if err := decodeTensor(t.buf, t.Shape(), typ, val); err != nil {
+ panic(err)
+ }
+ return reflect.Indirect(val).Interface()
+}
+
+// c converts the Tensor to a *C.TF_Tensor. Callers must take ownership of
+// the *C.TF_Tensor, either by passing ownership to the C API or explicitly
+// calling C.TF_DeleteTensor() on it.
+func (t *Tensor) c() *C.TF_Tensor {
+ var shapePtr *C.int64_t
+ if len(t.shape) > 0 {
+ shapePtr = (*C.int64_t)(unsafe.Pointer(&t.shape[0]))
+ }
+ tensor := C.TF_AllocateTensor(C.TF_DataType(t.dt), shapePtr, C.int(len(t.shape)), C.size_t(t.buf.Len()))
+ if t.buf.Len() > 0 {
+ slice := t.buf.Bytes() // https://github.com/golang/go/issues/14210
+ C.memcpy(C.TF_TensorData(tensor), unsafe.Pointer(&slice[0]), C.size_t(t.buf.Len()))
+ }
+ return tensor
+}
+
+// deleteCTensor only exists to delete C.TF_Tensors in tests. go test doesn't
+// support cgo.
+func deleteCTensor(ct *C.TF_Tensor) {
+ C.TF_DeleteTensor(ct)
+}
+
+var types = []struct {
+ typ reflect.Type
+ dataType C.TF_DataType
+}{
+ {reflect.TypeOf(float32(0)), C.TF_FLOAT},
+ {reflect.TypeOf(float64(0)), C.TF_DOUBLE},
+ {reflect.TypeOf(int32(0)), C.TF_INT32},
+ {reflect.TypeOf(uint8(0)), C.TF_UINT8},
+ {reflect.TypeOf(int16(0)), C.TF_INT16},
+ {reflect.TypeOf(int8(0)), C.TF_INT8},
+ {reflect.TypeOf(""), C.TF_STRING},
+ {reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64},
+ {reflect.TypeOf(int64(0)), C.TF_INT64},
+ {reflect.TypeOf(false), C.TF_BOOL},
+ {reflect.TypeOf(uint16(0)), C.TF_UINT16},
+ {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128},
+}
+
+// rankAndDataTypeOf returns the data type and rank of a Go type for use when
+// encoding. We fetch them separately from encoding to support 0-sized
+// dimensions.
+func rankAndDataTypeOf(typ reflect.Type) (int, DataType, error) {
+ rank := 0
+ elem := typ
+ for ; elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice; elem = elem.Elem() {
+ rank++
+ }
+ for _, t := range types {
+ if elem.Kind() == t.typ.Kind() {
+ return rank, DataType(t.dataType), nil
+ }
+ }
+ return 0, DataType(0), fmt.Errorf("unsupported type %v", typ)
+}
+
+// typeOf converts from a DataType and Shape to the equivalent Go type.
+func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
+ var ret reflect.Type
+ for _, t := range types {
+ if dt == DataType(t.dataType) {
+ ret = t.typ
+ break
+ }
+ }
+ if ret == nil {
+ return nil, fmt.Errorf("DataType %v unsupported", dt)
+ }
+ for _ = range shape {
+ ret = reflect.SliceOf(ret)
+ }
+ return ret, nil
+}
+
+// encodeTensor writes v to the specified buffer using the format specified in
+// c_api.h
+func encodeTensor(buf *bytes.Buffer, shape []int64, v reflect.Value) error {
+ switch v.Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ if err := binary.Write(buf, nativeEndian, v.Interface()); err != nil {
+ return err
+ }
+
+ case reflect.Array, reflect.Slice:
+ // If slice elements are slices, verify that all of them have the same size.
+ // Go's type system makes that guarantee for arrays.
+ if v.Len() > 0 && v.Type().Elem().Kind() == reflect.Slice {
+ expected := v.Index(0).Len()
+ for i := 1; i < v.Len(); i++ {
+ if v.Index(i).Len() != expected {
+ return fmt.Errorf("mismatched slice lengths: %d and %d", v.Index(i).Len(), expected)
+ }
+ }
+ }
+
+ shape[0] = int64(v.Len())
+ for i := 0; i < v.Len(); i++ {
+ err := encodeTensor(buf, shape[1:], v.Index(i))
+ if err != nil {
+ return err
+ }
+ }
+
+ default:
+ return fmt.Errorf("unsupported type %v", v.Type())
+ }
+ return nil
+}
+
+// decodeTensor decodes the Tensor from the buffer to ptr using the format
+// specified in c_api.h
+func decodeTensor(buf *bytes.Buffer, shape []int64, typ reflect.Type, ptr reflect.Value) error {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ if err := binary.Read(buf, nativeEndian, ptr.Interface()); err != nil {
+ return err
+ }
+
+ case reflect.Slice:
+ val := reflect.Indirect(ptr)
+ val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
+ for i := 0; i < val.Len(); i++ {
+ if err := decodeTensor(buf, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
+ return err
+ }
+ }
+
+ default:
+ return fmt.Errorf("unsupported type %v", typ)
+ }
+ return nil
+}
+
+// nativeEndian is the byte order for the local platform. Used to send back and
+// forth Tensors with the C API. We test for endianness at runtime because
+// some architectures can be booted into different endian modes.
+var nativeEndian binary.ByteOrder
+
+func init() {
+ buf := [2]byte{}
+ *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
+
+ switch buf {
+ case [2]byte{0xCD, 0xAB}:
+ nativeEndian = binary.LittleEndian
+ case [2]byte{0xAB, 0xCD}:
+ nativeEndian = binary.BigEndian
+ default:
+ panic("Could not determine native endianness.")
+ }
+}