diff options
author | Asim Shankar <ashankar@google.com> | 2016-12-01 12:10:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-12-01 12:22:59 -0800 |
commit | 0ba830748da7defb418e8448e62acd3077f90a52 (patch) | |
tree | 3e90101a33bd65c79db4325c64d4d242e7d99324 /tensorflow/go/tensor_test.go | |
parent | c4d507ac75299a6aa2486cb5e6b3a37ff009d4ca (diff) |
Go: API functions to serialize/deserialize Tensors.
The current implementation has limitations, in particular
it does not support string tensors. But the same API
will be sufficient when that limitation has been
addressed.
See also #6003. The API added here can be used to
fill in the tensor_contents field of a TensorProto
protocol buffer.
Change: 140760816
Diffstat (limited to 'tensorflow/go/tensor_test.go')
-rw-r--r-- | tensorflow/go/tensor_test.go | 110 |
1 files changed, 109 insertions, 1 deletions
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 073da0cc6e..2a3ed416bd 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -15,6 +15,7 @@ package tensorflow import ( + "bytes" "reflect" "testing" ) @@ -28,7 +29,6 @@ func TestNewTensor(t *testing.T) { {nil, int16(5)}, {nil, int32(5)}, {nil, int64(5)}, - {nil, int64(5)}, {nil, uint8(5)}, {nil, uint16(5)}, {nil, float32(5)}, @@ -103,6 +103,114 @@ func TestNewTensor(t *testing.T) { } } +func TestTensorSerialization(t *testing.T) { + var tests = []interface{}{ + int8(5), + int16(5), + int32(5), + int64(5), + uint8(5), + uint16(5), + float32(5), + float64(5), + complex(float32(5), float32(6)), + complex(float64(5), float64(6)), + []float64{1}, + [][]float32{{1, 2}, {3, 4}, {5, 6}}, + [][][]int8{ + {{1, 2}, {3, 4}, {5, 6}}, + {{7, 8}, {9, 10}, {11, 12}}, + {{0, -1}, {-2, -3}, {-4, -5}}, + {{-6, -7}, {-8, -9}, {-10, -11}}, + }, + } + for _, v := range tests { + t1, err := NewTensor(v) + if err != nil { + t.Errorf("(%v): %v", v, err) + continue + } + buf := new(bytes.Buffer) + n, err := t1.WriteContentsTo(buf) + if err != nil { + t.Errorf("(%v): %v", v, err) + continue + } + if n != int64(buf.Len()) { + t.Errorf("(%v): WriteContentsTo said it wrote %v bytes, but wrote %v", v, n, buf.Len()) + } + t2, err := ReadTensor(t1.DataType(), t1.Shape(), buf) + if err != nil { + t.Errorf("(%v): %v", v, err) + continue + } + if buf.Len() != 0 { + t.Errorf("(%v): %v bytes written by WriteContentsTo not read by ReadTensor", v, buf.Len()) + } + if got, want := t2.DataType(), t1.DataType(); got != want { + t.Errorf("(%v): Got %v, want %v", v, got, want) + } + if got, want := t2.Shape(), t1.Shape(); !reflect.DeepEqual(got, want) { + t.Errorf("(%v): Got %v, want %v", v, got, want) + } + if got, want := t2.Value(), v; !reflect.DeepEqual(got, want) { + t.Errorf("(%v): Got %v, want %v", v, got, want) + } + } +} + +func TestReadTensorDoesNotReadBeyondContent(t *testing.T) { + t1, _ := NewTensor(int8(7)) + t2, _ := NewTensor(float32(2.718)) + buf := new(bytes.Buffer) + if _, err := t1.WriteContentsTo(buf); err != nil { + t.Fatal(err) + } + if _, err := t2.WriteContentsTo(buf); err != nil { + t.Fatal(err) + } + + t3, err := ReadTensor(t1.DataType(), t1.Shape(), buf) + if err != nil { + t.Fatal(err) + } + t4, err := ReadTensor(t2.DataType(), t2.Shape(), buf) + if err != nil { + t.Fatal(err) + } + + if v, ok := t3.Value().(int8); !ok || v != 7 { + t.Errorf("Got (%v (%T), %v), want (7 (int8), true)", v, v, ok) + } + if v, ok := t4.Value().(float32); !ok || v != 2.718 { + t.Errorf("Got (%v (%T), %v), want (2.718 (float32), true)", v, v, ok) + } +} + +func TestTensorSerializationErrors(t *testing.T) { + // String tensors cannot be serialized + t1, err := NewTensor("abcd") + if err != nil { + t.Fatal(err) + } + buf := new(bytes.Buffer) + if n, err := t1.WriteContentsTo(buf); n != 0 || err == nil || buf.Len() != 0 { + t.Errorf("Got (%v, %v, %v) want (0, <non-nil>, 0)", n, err, buf.Len()) + } + // Should fail to read a truncated value. + if t1, err = NewTensor(int8(8)); err != nil { + t.Fatal(err) + } + n, err := t1.WriteContentsTo(buf) + if err != nil { + t.Fatal(err) + } + r := bytes.NewReader(buf.Bytes()[:n-1]) + if _, err = ReadTensor(t1.DataType(), t1.Shape(), r); err == nil { + t.Error("ReadTensor should have failed if the tensor content was truncated") + } +} + func benchmarkNewTensor(b *testing.B, v interface{}) { for i := 0; i < b.N; i++ { if t, err := NewTensor(v); err != nil || t == nil { |