aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor_test.go
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-18 12:36:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-18 12:39:21 -0700
commit07359dda7ff03d8a7b0d62f75e6c93fb22151a18 (patch)
treea4433fc72df34b760f4f1be4d3f9c3ec3502e825 /tensorflow/go/tensor_test.go
parent34c45c23e21929bd13b6a9cb92c62c1e7cbba8a5 (diff)
fix ReadTensor not reading the full contents of reader
PiperOrigin-RevId: 201040414
Diffstat (limited to 'tensorflow/go/tensor_test.go')
-rw-r--r--tensorflow/go/tensor_test.go49
1 files changed, 49 insertions, 0 deletions
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
index 793c36dd4d..dc533cd3e1 100644
--- a/tensorflow/go/tensor_test.go
+++ b/tensorflow/go/tensor_test.go
@@ -18,6 +18,7 @@ package tensorflow
import (
"bytes"
+ "io"
"reflect"
"testing"
)
@@ -226,6 +227,54 @@ func TestTensorSerializationErrors(t *testing.T) {
}
}
+func TestReadTensorReadAll(t *testing.T) {
+ // Get the bytes of a tensor.
+ a := []float32{1.1, 1.2, 1.3}
+ ats, err := NewTensor(a)
+ if err != nil {
+ t.Fatal(err)
+ }
+ abuf := new(bytes.Buffer)
+ if _, err := ats.WriteContentsTo(abuf); err != nil {
+ t.Fatal(err)
+ }
+
+ // Get the bytes of another tensor.
+ b := []float32{1.1, 1.2, 1.3}
+ bts, err := NewTensor(b)
+ if err != nil {
+ t.Fatal(err)
+ }
+ bbuf := new(bytes.Buffer)
+ if _, err := bts.WriteContentsTo(bbuf); err != nil {
+ t.Fatal(err)
+ }
+
+ // Check that ReadTensor reads all bytes of both tensors, when the situation
+ // requires one than reads.
+ abbuf := io.MultiReader(abuf, bbuf)
+ abts, err := ReadTensor(Float, []int64{2, 3}, abbuf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ abtsf32 := abts.Value().([][]float32)
+ expected := [][]float32{a, b}
+
+ if len(abtsf32) != 2 {
+ t.Fatalf("first dimension %d is not 2", len(abtsf32))
+ }
+ for i := 0; i < 2; i++ {
+ if len(abtsf32[i]) != 3 {
+ t.Fatalf("second dimension %d is not 3", len(abtsf32[i]))
+ }
+ for j := 0; j < 3; j++ {
+ if abtsf32[i][j] != expected[i][j] {
+ t.Errorf("value at %d %d not equal %f %f", i, j, abtsf32[i][j], expected[i][j])
+ }
+ }
+ }
+}
+
func benchmarkNewTensor(b *testing.B, v interface{}) {
for i := 0; i < b.N; i++ {
if t, err := NewTensor(v); err != nil || t == nil {