aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-12-01 12:10:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-01 12:22:59 -0800
commit0ba830748da7defb418e8448e62acd3077f90a52 (patch)
tree3e90101a33bd65c79db4325c64d4d242e7d99324 /tensorflow/go/tensor.go
parentc4d507ac75299a6aa2486cb5e6b3a37ff009d4ca (diff)
Go: API functions to serialize/deserialize Tensors.
The current implementation has limitations, in particular it does not support string tensors. But the same API will be sufficient when that limitation has been addressed. See also #6003. The API added here can be used to fill in the tensor_contents field of a TensorProto protocol buffer. Change: 140760816
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r--tensorflow/go/tensor.go61
1 files changed, 61 insertions, 0 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index f755e9d4f8..c50e5e30aa 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -108,6 +108,35 @@ func NewTensor(value interface{}) (*Tensor, error) {
return t, nil
}
+// ReadTensor constructs a Tensor with the provided type and shape from the
+// serialized tensor contents in r.
+//
+// The number of bytes consumed from r is the size computed from dataType and
+// shape. The data may arrive across any number of Read calls on r; an error
+// is returned if r ends before the tensor has been completely filled, or if
+// dataType is not one that isTensorSerializable accepts.
+//
+// See also WriteContentsTo.
+func ReadTensor(dataType DataType, shape []int64, r io.Reader) (*Tensor, error) {
+	if err := isTensorSerializable(dataType); err != nil {
+		return nil, err
+	}
+	nbytes := typeOf(dataType, nil).Size() * uintptr(numElements(shape))
+	// A scalar has an empty shape; there is no first element to point at, so
+	// the C call receives a nil dims pointer together with a length of zero.
+	var shapePtr *C.int64_t
+	if len(shape) > 0 {
+		shapePtr = (*C.int64_t)(unsafe.Pointer(&shape[0]))
+	}
+	t := &Tensor{
+		c:     C.TF_AllocateTensor(C.TF_DataType(dataType), shapePtr, C.int(len(shape)), C.size_t(nbytes)),
+		shape: shape,
+	}
+	// The finalizer is installed before reading so that it still runs for a
+	// tensor that is abandoned because of a read error below.
+	runtime.SetFinalizer(t, (*Tensor).finalize)
+	raw := tensorData(t.c)
+	// A single Read call may legally return fewer bytes than requested with a
+	// nil error (network connections, pipes, chunked readers), and some
+	// readers report io.EOF for a zero-length buffer. io.ReadFull keeps
+	// reading until raw is full and treats an empty raw as success.
+	n, err := io.ReadFull(r, raw)
+	if err != nil && err != io.ErrUnexpectedEOF {
+		// Includes io.EOF when r had no data at all.
+		return nil, err
+	}
+	if err != nil || uintptr(n) != nbytes {
+		return nil, fmt.Errorf("expected serialized tensor to be %v bytes, read %v", nbytes, n)
+	}
+	return t, nil
+}
+
// newTensorFromC takes ownership of c and returns the owning Tensor.
func newTensorFromC(c *C.TF_Tensor) *Tensor {
var shape []int64
@@ -156,6 +185,21 @@ func (t *Tensor) Value() interface{} {
return reflect.Indirect(val).Interface()
}
+// WriteContentsTo streams the serialized contents of t into w and reports
+// how many bytes were written.
+//
+// The output is the form ReadTensor expects when rebuilding a Tensor.
+//
+// WARNING: WriteContentsTo is not comprehensive. It returns an error,
+// without writing anything, when t.DataType() is non-numeric (e.g.,
+// String). See https://github.com/tensorflow/tensorflow/issues/6003.
+func (t *Tensor) WriteContentsTo(w io.Writer) (int64, error) {
+	err := isTensorSerializable(t.DataType())
+	if err != nil {
+		return 0, err
+	}
+	contents := bytes.NewReader(tensorData(t.c))
+	return io.Copy(w, contents)
+}
+
func tensorData(c *C.TF_Tensor) []byte {
// See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
cbytes := C.TF_TensorData(c)
@@ -385,6 +429,23 @@ func bug(format string, args ...interface{}) error {
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
}
+// isTensorSerializable returns nil when tensors of dataType can be
+// serialized by this package, and a descriptive error otherwise.
+func isTensorSerializable(dataType DataType) error {
+	// Only numeric types are accepted: for them the serialized form is the
+	// same as the in-memory representation. See the implementation of
+	// Tensor::AsProtoContent in
+	// https://www.tensorflow.org/code/tensorflow/core/framework/tensor.cc
+	//
+	// Ideally the TensorFlow C library would export serialization and
+	// deserialization functions so this code could not drift from
+	// Tensor::AsProtoContent. Until it does, rely on knowledge of that
+	// implementation for numeric types.
+	switch dataType {
+	case Half, Bfloat16, Float, Double,
+		Int8, Int16, Int32, Int64,
+		Uint8, Uint16,
+		Complex, Complex128,
+		Bool,
+		Qint16, Qint32, Quint8, Quint16:
+		return nil
+	}
+	return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
+}
+
// nativeEndian is the byte order for the local platform. Used to send back and
// forth Tensors with the C API. We test for endianness at runtime because
// some architectures can be booted into different endian modes.