diff options
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/go/genop/internal/genop.go | 2 | ||||
-rw-r--r-- | tensorflow/go/graph.go | 35 | ||||
-rw-r--r-- | tensorflow/go/op/op_test.go | 33 | ||||
-rw-r--r-- | tensorflow/go/operation_test.go | 15 | ||||
-rw-r--r-- | tensorflow/go/shape.go | 102 | ||||
-rw-r--r-- | tensorflow/go/shape_test.go | 83 |
6 files changed, 264 insertions, 6 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go index 75c111e957..d9ebec0f8c 100644 --- a/tensorflow/go/genop/internal/genop.go +++ b/tensorflow/go/genop/internal/genop.go @@ -395,7 +395,7 @@ func goType(tfType string) (string, error) { case "type": gotype = "tf.DataType" case "shape": - gotype = "[]int64" + gotype = "tf.Shape" case "tensor": gotype = "tf.Tensor" case "string": diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 2eb1194610..c0f91ffb30 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -259,13 +259,38 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } + case Shape: + ndims, dims := cshape(value) + var dimsp *C.int64_t + if ndims > 0 { + dimsp = &dims[0] + } + C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) + case []Shape: + ndims := make([]C.int, len(value)) + dims := make([][]C.int64_t, len(value)) + dimsp := make([]*C.int64_t, len(value)) + for i, s := range value { + ndims[i], dims[i] = cshape(s) + if ndims[i] > 0 { + dimsp[i] = &dims[i][0] + } + } + C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) default: - // Shapes can be done, but will require that it be - // distinguishable from []int64. Which is fine, it - // probably makes sense to define a Shape type anyway, - // since that should handle partially known shapes as - // well and hide the special meaning of -1? return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } return nil } + +func cshape(s Shape) (C.int, []C.int64_t) { + ndims := C.int(s.NumDimensions()) + if ndims < 0 { + return -1, nil + } + dims := make([]C.int64_t, ndims) + for i, s := range s.dims { + dims[i] = C.int64_t(s) + } + return ndims, dims +} diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go new file mode 100644 index 0000000000..eaa27bfcd0 --- /dev/null +++ b/tensorflow/go/op/op_test.go @@ -0,0 +1,33 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tests for the generated code of some operations. + +package op + +import ( + "testing" + + tf "github.com/tensorflow/tensorflow/tensorflow/go" +) + +func TestPlaceholder(t *testing.T) { + s := NewScope() + Placeholder(s.SubScope("x"), tf.Float, PlaceholderShape(tf.MakeShape(-1, 10))) + Placeholder(s.SubScope("y"), tf.Float, PlaceholderShape(tf.ScalarShape())) + Placeholder(s.SubScope("z"), tf.Float, PlaceholderShape(tf.Shape{})) + if _, err := s.Finalize(); err != nil { + t.Fatal(err) + } +} diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 8080515ee9..4c4c960448 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -81,6 +81,21 @@ func TestOperationOutputListSize(t *testing.T) { } } +func TestOperationShapeAttribute(t *testing.T) { + g := NewGraph() + _, err := g.AddOperation(OpSpec{ + Type: "Placeholder", + Attrs: map[string]interface{}{ + "dtype": Float, + "shape": MakeShape(-1, 3), + }, + }) + if err != nil { + t.Fatal(err) + } + // If and when the API to get attributes is added, check that here. +} + func TestOutputShape(t *testing.T) { graph := NewGraph() testdata := []struct { diff --git a/tensorflow/go/shape.go b/tensorflow/go/shape.go new file mode 100644 index 0000000000..c48bbf29a3 --- /dev/null +++ b/tensorflow/go/shape.go @@ -0,0 +1,102 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "fmt" + "strings" +) + +// Shape represents the (possibly partially known) shape of a tensor that will +// be produced by an operation. +// +// The zero-value of a Shape represents a shape with an unknown number of +// dimensions. +type Shape struct { + dims []int64 +} + +// ScalarShape returns a Shape representing a scalar. +func ScalarShape() Shape { + return Shape{dims: make([]int64, 0)} +} + +// MakeShape returns a Shape with the provided size of each dimension. +// +// A value of -1 implies that the size of the corresponding dimension is not +// known. +func MakeShape(shape ...int64) Shape { + cpy := make([]int64, len(shape)) + copy(cpy, shape) + return Shape{dims: cpy} +} + +// NumDimensions returns the number of dimensions represented by s, or -1 if +// unknown. +func (s Shape) NumDimensions() int { + if s.dims == nil { + return -1 + } + return len(s.dims) +} + +// Size returns the size of the dim-th dimension of the shape, or -1 if it +// is unknown. +// +// REQUIRES: 0 <= dim < s.NumDimensions() +func (s Shape) Size(dim int) int64 { + if dim < 0 || dim > s.NumDimensions() { + return -1 + } + return s.dims[dim] +} + +// IsFullySpecified returns true iff the size of all the dimensions of s are +// known. +func (s Shape) IsFullySpecified() bool { + if s.dims == nil { + return false + } + for _, size := range s.dims { + if size <= 1 { + return false + } + } + return true +} + +// ToSlice returns the (possibly partially known) shape represented by s as a +// slice, or an error if the number of dimensions is not known. +func (s Shape) ToSlice() ([]int64, error) { + if s.dims == nil { + return nil, fmt.Errorf("cannot create a slice for a Shape with an unknown number of dimensions") + } + cpy := make([]int64, len(s.dims)) + copy(cpy, s.dims) + return cpy, nil +} + +func (s Shape) String() string { + if s.dims == nil { + return "?" + } + ret := fmt.Sprint(s.dims) + for _, size := range s.dims { + if size < 0 { + ret = strings.Replace(ret, fmt.Sprint(size), "?", 1) + } + } + return strings.Replace(ret, " ", ", ", -1) +} diff --git a/tensorflow/go/shape_test.go b/tensorflow/go/shape_test.go new file mode 100644 index 0000000000..f8f3d4e94b --- /dev/null +++ b/tensorflow/go/shape_test.go @@ -0,0 +1,83 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "fmt" + "reflect" + "testing" +) + +func TestShape(t *testing.T) { + tests := []struct { + shape Shape + slice []int64 + full bool + str string + }{ + { + shape: ScalarShape(), + slice: make([]int64, 0), + full: true, + str: "[]", + }, + { + shape: MakeShape(-1, 2, -1, 4), + slice: []int64{-1, 2, -1, 4}, + full: false, + str: "[?, 2, ?, 4]", + }, + { + shape: MakeShape(2, 3), + slice: []int64{2, 3}, + full: true, + str: "[2, 3]", + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%#v", test.shape), func(t *testing.T) { + if got, want := test.shape.NumDimensions(), len(test.slice); got != want { + t.Errorf("Got %v, want %v", got, want) + } + if gotSlice, err := test.shape.ToSlice(); err != nil || !reflect.DeepEqual(gotSlice, test.slice) { + t.Errorf("Got (%#v, %v), want (%#v, nil)", gotSlice, err, test.slice) + } + if got, want := test.shape.IsFullySpecified(), test.full; got != want { + t.Errorf("Got %v, want %v", got, want) + } + if got, want := test.shape.String(), test.str; got != want { + t.Errorf("Got %v, want %v", got, want) + } + }) + } + +} + +func TestZeroShape(t *testing.T) { + var s Shape + if s.NumDimensions() != -1 { + t.Error(s.NumDimensions()) + } + if _, err := s.ToSlice(); err == nil { + t.Error("ToSlice() on a Shape of unknown number of dimensions should fail") + } + if s.IsFullySpecified() { + t.Error("Shape of unknown number of dimensions should not be fully specified") + } + if got, want := s.String(), "?"; got != want { + t.Errorf("Got %q, want %q", got, want) + } + +} |