diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-18 12:36:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-18 12:39:21 -0700 |
commit | 07359dda7ff03d8a7b0d62f75e6c93fb22151a18 (patch) | |
tree | a4433fc72df34b760f4f1be4d3f9c3ec3502e825 /tensorflow/go/tensor_test.go | |
parent | 34c45c23e21929bd13b6a9cb92c62c1e7cbba8a5 (diff) |
fix ReadTensor not reading the full contents of reader
PiperOrigin-RevId: 201040414
Diffstat (limited to 'tensorflow/go/tensor_test.go')
-rw-r--r-- | tensorflow/go/tensor_test.go | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go index 793c36dd4d..dc533cd3e1 100644 --- a/tensorflow/go/tensor_test.go +++ b/tensorflow/go/tensor_test.go @@ -18,6 +18,7 @@ package tensorflow import ( "bytes" + "io" "reflect" "testing" ) @@ -226,6 +227,54 @@ func TestTensorSerializationErrors(t *testing.T) { } } +func TestReadTensorReadAll(t *testing.T) { + // Get the bytes of a tensor. + a := []float32{1.1, 1.2, 1.3} + ats, err := NewTensor(a) + if err != nil { + t.Fatal(err) + } + abuf := new(bytes.Buffer) + if _, err := ats.WriteContentsTo(abuf); err != nil { + t.Fatal(err) + } + + // Get the bytes of another tensor. + b := []float32{1.1, 1.2, 1.3} + bts, err := NewTensor(b) + if err != nil { + t.Fatal(err) + } + bbuf := new(bytes.Buffer) + if _, err := bts.WriteContentsTo(bbuf); err != nil { + t.Fatal(err) + } + + // Check that ReadTensor reads all bytes of both tensors, when the situation + // requires one than reads. + abbuf := io.MultiReader(abuf, bbuf) + abts, err := ReadTensor(Float, []int64{2, 3}, abbuf) + if err != nil { + t.Fatal(err) + } + abtsf32 := abts.Value().([][]float32) + expected := [][]float32{a, b} + + if len(abtsf32) != 2 { + t.Fatalf("first dimension %d is not 2", len(abtsf32)) + } + for i := 0; i < 2; i++ { + if len(abtsf32[i]) != 3 { + t.Fatalf("second dimension %d is not 3", len(abtsf32[i])) + } + for j := 0; j < 3; j++ { + if abtsf32[i][j] != expected[i][j] { + t.Errorf("value at %d %d not equal %f %f", i, j, abtsf32[i][j], expected[i][j]) + } + } + } +} + func benchmarkNewTensor(b *testing.B, v interface{}) { for i := 0; i < b.N; i++ { if t, err := NewTensor(v); err != nil || t == nil { |