aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-12-01 12:10:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-01 12:22:59 -0800
commit0ba830748da7defb418e8448e62acd3077f90a52 (patch)
tree3e90101a33bd65c79db4325c64d4d242e7d99324 /tensorflow/go/tensor_test.go
parentc4d507ac75299a6aa2486cb5e6b3a37ff009d4ca (diff)
Go: API functions to serialize/deserialize Tensors.
The current implementation has limitations, in particular it does not support string tensors. But the same API will be sufficient when that limitation has been addressed. See also #6003. The API added here can be used to fill in the tensor_contents field of a TensorProto protocol buffer. Change: 140760816
Diffstat (limited to 'tensorflow/go/tensor_test.go')
-rw-r--r--tensorflow/go/tensor_test.go110
1 files changed, 109 insertions, 1 deletions
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 073da0cc6e..2a3ed416bd 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -15,6 +15,7 @@
package tensorflow
import (
+ "bytes"
"reflect"
"testing"
)
@@ -28,7 +29,6 @@ func TestNewTensor(t *testing.T) {
{nil, int16(5)},
{nil, int32(5)},
{nil, int64(5)},
- {nil, int64(5)},
{nil, uint8(5)},
{nil, uint16(5)},
{nil, float32(5)},
@@ -103,6 +103,114 @@ func TestNewTensor(t *testing.T) {
}
}
+func TestTensorSerialization(t *testing.T) {
+ var tests = []interface{}{
+ int8(5),
+ int16(5),
+ int32(5),
+ int64(5),
+ uint8(5),
+ uint16(5),
+ float32(5),
+ float64(5),
+ complex(float32(5), float32(6)),
+ complex(float64(5), float64(6)),
+ []float64{1},
+ [][]float32{{1, 2}, {3, 4}, {5, 6}},
+ [][][]int8{
+ {{1, 2}, {3, 4}, {5, 6}},
+ {{7, 8}, {9, 10}, {11, 12}},
+ {{0, -1}, {-2, -3}, {-4, -5}},
+ {{-6, -7}, {-8, -9}, {-10, -11}},
+ },
+ }
+ for _, v := range tests {
+ t1, err := NewTensor(v)
+ if err != nil {
+ t.Errorf("(%v): %v", v, err)
+ continue
+ }
+ buf := new(bytes.Buffer)
+ n, err := t1.WriteContentsTo(buf)
+ if err != nil {
+ t.Errorf("(%v): %v", v, err)
+ continue
+ }
+ if n != int64(buf.Len()) {
+ t.Errorf("(%v): WriteContentsTo said it wrote %v bytes, but wrote %v", v, n, buf.Len())
+ }
+ t2, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
+ if err != nil {
+ t.Errorf("(%v): %v", v, err)
+ continue
+ }
+ if buf.Len() != 0 {
+ t.Errorf("(%v): %v bytes written by WriteContentsTo not read by ReadTensor", v, buf.Len())
+ }
+ if got, want := t2.DataType(), t1.DataType(); got != want {
+ t.Errorf("(%v): Got %v, want %v", v, got, want)
+ }
+ if got, want := t2.Shape(), t1.Shape(); !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v): Got %v, want %v", v, got, want)
+ }
+ if got, want := t2.Value(), v; !reflect.DeepEqual(got, want) {
+ t.Errorf("(%v): Got %v, want %v", v, got, want)
+ }
+ }
+}
+
+func TestReadTensorDoesNotReadBeyondContent(t *testing.T) {
+ t1, _ := NewTensor(int8(7))
+ t2, _ := NewTensor(float32(2.718))
+ buf := new(bytes.Buffer)
+ if _, err := t1.WriteContentsTo(buf); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := t2.WriteContentsTo(buf); err != nil {
+ t.Fatal(err)
+ }
+
+ t3, err := ReadTensor(t1.DataType(), t1.Shape(), buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ t4, err := ReadTensor(t2.DataType(), t2.Shape(), buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if v, ok := t3.Value().(int8); !ok || v != 7 {
+ t.Errorf("Got (%v (%T), %v), want (7 (int8), true)", v, v, ok)
+ }
+ if v, ok := t4.Value().(float32); !ok || v != 2.718 {
+ t.Errorf("Got (%v (%T), %v), want (2.718 (float32), true)", v, v, ok)
+ }
+}
+
+func TestTensorSerializationErrors(t *testing.T) {
+ // String tensors cannot be serialized
+ t1, err := NewTensor("abcd")
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := new(bytes.Buffer)
+ if n, err := t1.WriteContentsTo(buf); n != 0 || err == nil || buf.Len() != 0 {
+ t.Errorf("Got (%v, %v, %v) want (0, <non-nil>, 0)", n, err, buf.Len())
+ }
+ // Should fail to read a truncated value.
+ if t1, err = NewTensor(int8(8)); err != nil {
+ t.Fatal(err)
+ }
+ n, err := t1.WriteContentsTo(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ r := bytes.NewReader(buf.Bytes()[:n-1])
+ if _, err = ReadTensor(t1.DataType(), t1.Shape(), r); err == nil {
+ t.Error("ReadTensor should have failed if the tensor content was truncated")
+ }
+}
+
func benchmarkNewTensor(b *testing.B, v interface{}) {
for i := 0; i < b.N; i++ {
if t, err := NewTensor(v); err != nil || t == nil {