diff options
author | 2017-01-24 16:13:44 -0800 | |
---|---|---|
committer | 2017-01-24 16:26:57 -0800 | |
commit | c995ee586cfcce29a10b6e05140f1cc7c6c13a16 (patch) | |
tree | 526cf540971cde16e63356245659a2aee44e635b /tensorflow/go/operation_test.go | |
parent | 4ac5ed3184873d530dc249a74794a55229e85e0f (diff) |
Go: Output.Shape now returns a Shape object.
Output.Shape may be only partially known, hence the recently introduced Shape
type is a more appropriate return value.
Change: 145481782
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r-- | tensorflow/go/operation_test.go | 17 |
1 files changed, 7 insertions, 10 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 4c4c960448..32e2989179 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -124,16 +124,13 @@ func TestOutputShape(t *testing.T) { if err != nil { t.Fatal(err) } - shape, err := c.Shape() - if err != nil { - t.Fatal(err) - } - if got, want := len(shape), len(test.Shape); got != want { + shape := c.Shape() + if got, want := shape.NumDimensions(), len(test.Shape); got != want { t.Fatalf("Got a shape with %d dimensions, want %d", got, want) } - for i := 0; i < len(shape); i++ { - if got, want := shape[i], test.Shape[i]; got != want { - t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(shape)) + for i := 0; i < len(test.Shape); i++ { + if got, want := shape.Size(i), test.Shape[i]; got != want { + t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(test.Shape)) } } }) @@ -147,8 +144,8 @@ func TestOutputShape(t *testing.T) { if err != nil { t.Fatal(err) } - if shape, err := placeholder.Shape(); err == nil { - t.Errorf("Got shape %v, wanted error", shape) + if shape := placeholder.Shape(); shape.NumDimensions() != -1 { + t.Errorf("Got shape %v, wanted an unknown number of dimensions", shape) } } |