aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/go/genop/internal/genop.go2
-rw-r--r--tensorflow/go/graph.go35
-rw-r--r--tensorflow/go/op/op_test.go33
-rw-r--r--tensorflow/go/operation_test.go15
-rw-r--r--tensorflow/go/shape.go102
-rw-r--r--tensorflow/go/shape_test.go83
6 files changed, 264 insertions, 6 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go
index 75c111e957..d9ebec0f8c 100644
--- a/tensorflow/go/genop/internal/genop.go
+++ b/tensorflow/go/genop/internal/genop.go
@@ -395,7 +395,7 @@ func goType(tfType string) (string, error) {
case "type":
gotype = "tf.DataType"
case "shape":
- gotype = "[]int64"
+ gotype = "tf.Shape"
case "tensor":
gotype = "tf.Tensor"
case "string":
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 2eb1194610..c0f91ffb30 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -259,13 +259,38 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
if err := status.Err(); err != nil {
return fmt.Errorf("bad value for attribute %q: %v", name, err)
}
+ case Shape:
+ ndims, dims := cshape(value)
+ var dimsp *C.int64_t
+ if ndims > 0 {
+ dimsp = &dims[0]
+ }
+ C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
+ case []Shape:
+ ndims := make([]C.int, len(value))
+ dims := make([][]C.int64_t, len(value))
+ dimsp := make([]*C.int64_t, len(value))
+ for i, s := range value {
+ ndims[i], dims[i] = cshape(s)
+ if ndims[i] > 0 {
+ dimsp[i] = &dims[i][0]
+ }
+ }
+ C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
default:
- // Shapes can be done, but will require that it be
- // distinguishable from []int64. Which is fine, it
- // probably makes sense to define a Shape type anyway,
- // since that should handle partially known shapes as
- // well and hide the special meaning of -1?
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
}
return nil
}
+
+func cshape(s Shape) (C.int, []C.int64_t) {
+ ndims := C.int(s.NumDimensions())
+ if ndims < 0 {
+ return -1, nil
+ }
+ dims := make([]C.int64_t, ndims)
+ for i, s := range s.dims {
+ dims[i] = C.int64_t(s)
+ }
+ return ndims, dims
+}
diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go
new file mode 100644
index 0000000000..eaa27bfcd0
--- /dev/null
+++ b/tensorflow/go/op/op_test.go
@@ -0,0 +1,33 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Tests for the generated code of some operations.
+
+package op
+
+import (
+ "testing"
+
+ tf "github.com/tensorflow/tensorflow/tensorflow/go"
+)
+
+func TestPlaceholder(t *testing.T) {
+ s := NewScope()
+ Placeholder(s.SubScope("x"), tf.Float, PlaceholderShape(tf.MakeShape(-1, 10)))
+ Placeholder(s.SubScope("y"), tf.Float, PlaceholderShape(tf.ScalarShape()))
+ Placeholder(s.SubScope("z"), tf.Float, PlaceholderShape(tf.Shape{}))
+ if _, err := s.Finalize(); err != nil {
+ t.Fatal(err)
+ }
+}
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 8080515ee9..4c4c960448 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -81,6 +81,21 @@ func TestOperationOutputListSize(t *testing.T) {
}
}
+func TestOperationShapeAttribute(t *testing.T) {
+ g := NewGraph()
+ _, err := g.AddOperation(OpSpec{
+ Type: "Placeholder",
+ Attrs: map[string]interface{}{
+ "dtype": Float,
+ "shape": MakeShape(-1, 3),
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ // If and when the API to get attributes is added, check that here.
+}
+
func TestOutputShape(t *testing.T) {
graph := NewGraph()
testdata := []struct {
diff --git a/tensorflow/go/shape.go b/tensorflow/go/shape.go
new file mode 100644
index 0000000000..c48bbf29a3
--- /dev/null
+++ b/tensorflow/go/shape.go
@@ -0,0 +1,102 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "fmt"
+ "strings"
+)
+
+// Shape represents the (possibly partially known) shape of a tensor that will
+// be produced by an operation.
+//
+// The zero-value of a Shape represents a shape with an unknown number of
+// dimensions.
+type Shape struct {
+ dims []int64
+}
+
+// ScalarShape returns a Shape representing a scalar.
+func ScalarShape() Shape {
+ return Shape{dims: make([]int64, 0)}
+}
+
+// MakeShape returns a Shape with the provided size of each dimension.
+//
+// A value of -1 implies that the size of the corresponding dimension is not
+// known.
+func MakeShape(shape ...int64) Shape {
+ cpy := make([]int64, len(shape))
+ copy(cpy, shape)
+ return Shape{dims: cpy}
+}
+
+// NumDimensions returns the number of dimensions represented by s, or -1 if
+// unknown.
+func (s Shape) NumDimensions() int {
+ if s.dims == nil {
+ return -1
+ }
+ return len(s.dims)
+}
+
+// Size returns the size of the dim-th dimension of the shape, or -1 if it
+// is unknown.
+//
+// REQUIRES: 0 <= dim < s.NumDimensions()
+func (s Shape) Size(dim int) int64 {
+ if dim < 0 || dim > s.NumDimensions() {
+ return -1
+ }
+ return s.dims[dim]
+}
+
+// IsFullySpecified returns true iff the size of all the dimensions of s are
+// known.
+func (s Shape) IsFullySpecified() bool {
+ if s.dims == nil {
+ return false
+ }
+ for _, size := range s.dims {
+ if size <= 1 {
+ return false
+ }
+ }
+ return true
+}
+
+// ToSlice returns the (possibly partially known) shape represented by s as a
+// slice, or an error if the number of dimensions is not known.
+func (s Shape) ToSlice() ([]int64, error) {
+ if s.dims == nil {
+ return nil, fmt.Errorf("cannot create a slice for a Shape with an unknown number of dimensions")
+ }
+ cpy := make([]int64, len(s.dims))
+ copy(cpy, s.dims)
+ return cpy, nil
+}
+
+func (s Shape) String() string {
+ if s.dims == nil {
+ return "?"
+ }
+ ret := fmt.Sprint(s.dims)
+ for _, size := range s.dims {
+ if size < 0 {
+ ret = strings.Replace(ret, fmt.Sprint(size), "?", 1)
+ }
+ }
+ return strings.Replace(ret, " ", ", ", -1)
+}
diff --git a/tensorflow/go/shape_test.go b/tensorflow/go/shape_test.go
new file mode 100644
index 0000000000..f8f3d4e94b
--- /dev/null
+++ b/tensorflow/go/shape_test.go
@@ -0,0 +1,83 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestShape(t *testing.T) {
+ tests := []struct {
+ shape Shape
+ slice []int64
+ full bool
+ str string
+ }{
+ {
+ shape: ScalarShape(),
+ slice: make([]int64, 0),
+ full: true,
+ str: "[]",
+ },
+ {
+ shape: MakeShape(-1, 2, -1, 4),
+ slice: []int64{-1, 2, -1, 4},
+ full: false,
+ str: "[?, 2, ?, 4]",
+ },
+ {
+ shape: MakeShape(2, 3),
+ slice: []int64{2, 3},
+ full: true,
+ str: "[2, 3]",
+ },
+ }
+ for _, test := range tests {
+ t.Run(fmt.Sprintf("%#v", test.shape), func(t *testing.T) {
+ if got, want := test.shape.NumDimensions(), len(test.slice); got != want {
+ t.Errorf("Got %v, want %v", got, want)
+ }
+ if gotSlice, err := test.shape.ToSlice(); err != nil || !reflect.DeepEqual(gotSlice, test.slice) {
+ t.Errorf("Got (%#v, %v), want (%#v, nil)", gotSlice, err, test.slice)
+ }
+ if got, want := test.shape.IsFullySpecified(), test.full; got != want {
+ t.Errorf("Got %v, want %v", got, want)
+ }
+ if got, want := test.shape.String(), test.str; got != want {
+ t.Errorf("Got %v, want %v", got, want)
+ }
+ })
+ }
+
+}
+
+func TestZeroShape(t *testing.T) {
+ var s Shape
+ if s.NumDimensions() != -1 {
+ t.Error(s.NumDimensions())
+ }
+ if _, err := s.ToSlice(); err == nil {
+ t.Error("ToSlice() on a Shape of unknown number of dimensions should fail")
+ }
+ if s.IsFullySpecified() {
+ t.Error("Shape of unknown number of dimensions should not be fully specified")
+ }
+ if got, want := s.String(), "?"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+
+}