diff options
author | Asim Shankar <ashankar@google.com> | 2016-12-01 12:10:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-01 12:22:59 -0800 |
commit | 0ba830748da7defb418e8448e62acd3077f90a52 (patch) | |
tree | 3e90101a33bd65c79db4325c64d4d242e7d99324 /tensorflow/go/tensor.go | |
parent | c4d507ac75299a6aa2486cb5e6b3a37ff009d4ca (diff) |
Go: API functions to serialize/deserialize Tensors.
The current implementation has limitations, in particular
it does not support string tensors. But the same API
will be sufficient when that limitation has been
addressed.
See also #6003. The API added here can be used to
fill in the tensor_contents field of a TensorProto
protocol buffer.
Change: 140760816
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r-- | tensorflow/go/tensor.go | 61 |
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index f755e9d4f8..c50e5e30aa 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -108,6 +108,35 @@ func NewTensor(value interface{}) (*Tensor, error) { return t, nil } +// ReadTensor constructs a Tensor with the provided type and shape from the +// serialized tensor contents in r. +// +// See also WriteContentsTo. +func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) { + if err := isTensorSerializable(dataType); err != nil { + return nil, err + } + nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape)) + var shapePtr *C.int64_t + if len(shape) > 0 { + shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0])) + } + t := &Tensor{ + c: C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)), + shape: shape, + } + runtime.SetFinalizer(t, (*Tensor).finalize) + raw := tensorData(t.c) + n, err := r.Read(raw) + if err != nil { + return nil, err + } + if uintptr(n) != nbytes { + return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n) + } + return t, nil +} + // newTensorFromC takes ownership of c and returns the owning Tensor. func newTensorFromC(c *C.TF_Tensor) *Tensor { var shape []int64 @@ -156,6 +185,21 @@ func (t *Tensor) Value() interface{} { return reflect.Indirect(val).Interface() } +// WriteContentsTo writes the serialized contents of t to w. +// +// Returns the number of bytes written. See ReadTensor for +// reconstructing a Tensor from the serialized form. +// +// WARNING: WriteContentsTo is not comprehensive and will fail +// if t.DataType() is non-numeric (e.g., String). See +// https://github.com/tensorflow/tensorflow/issues/6003. +func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) { + if err := isTensorSerializable(t.DataType()); err != nil { + return 0, err + } + return io.Copy(w, bytes.NewReader(tensorData(t.c))) +} + func tensorData(c *C.TF_Tensor) []byte { // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices cbytes := C.TF_TensorData(c) @@ -385,6 +429,23 @@ func bug(format string, args ...interface{}) error { return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...)) } +func isTensorSerializable(dataType DataType) error { + // For numeric types, the serialized Tensor matches the in-memory + // representation. See the implementation of Tensor::AsProtoContent in + // https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc + // + // The more appropriate way to be in sync with Tensor::AsProtoContent + // would be to have the TensorFlow C library export functions for + // serialization and deserialization of Tensors. Till then capitalize + // on knowledge of the implementation for numeric types. + switch dataType { + case Float, Double, Int32, Uint8, Int16, Int8, Complex, Int64, Bool, Quint8, Qint32, Bfloat16, Qint16, Quint16, Uint16, Complex128, Half: + return nil + default: + return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType) + } +} + // nativeEndian is the byte order for the local platform. Used to send back and // forth Tensors with the C API. We test for endianness at runtime because // some architectures can be booted into different endian modes. |