aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go
diff options
context:
space:
mode:
authorGravatar Tristan Rice <rice@fn.lc>2018-06-18 12:43:51 -0700
committerGravatar Tristan Rice <rice@fn.lc>2018-06-19 14:50:20 -0700
commit577b256460dfca4e7c429437dded48e76715fee7 (patch)
tree60cae79034b1a5ea59935f87d0b3ea766e3a516b /tensorflow/go
parentff7e6399443615675a3f1182c4f2e1850008da04 (diff)
tensorflow/go: add tests for zero length arrays passed to C
Diffstat (limited to 'tensorflow/go')
-rw-r--r--tensorflow/go/attrs.go36
-rw-r--r--tensorflow/go/attrs_test.go172
-rw-r--r--tensorflow/go/operation.go3
-rw-r--r--tensorflow/go/operation_test.go4
4 files changed, 198 insertions, 17 deletions
diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go
index bfa60d2aa8..f86c5737bc 100644
--- a/tensorflow/go/attrs.go
+++ b/tensorflow/go/attrs.go
@@ -33,7 +33,8 @@ func makeCShape(shape []C.int64_t) Shape {
return s
}
-// Attr returns the value of an attribute on op.
+// Attr returns the value of an attribute on op. It returns an error if the
+// attribute does not exist.
func (op *Operation) Attr(name string) (interface{}, error) {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
@@ -55,9 +56,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
switch meta._type {
case C.TF_ATTR_STRING:
+ if meta.list_size == 0 {
+ return []string(nil), nil
+ }
values := make([]unsafe.Pointer, meta.list_size)
lengths := make([]C.size_t, meta.list_size)
- storage := make([]C.char, meta.total_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.char, meta.total_size+1)
C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
@@ -70,6 +75,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return list, nil
case C.TF_ATTR_INT:
+ if meta.list_size == 0 {
+ return []int64(nil), nil
+ }
list := make([]C.int64_t, meta.list_size)
C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -82,6 +90,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_FLOAT:
+ if meta.list_size == 0 {
+ return []float32(nil), nil
+ }
list := make([]C.float, meta.list_size)
C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -94,6 +105,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_BOOL:
+ if meta.list_size == 0 {
+ return []bool(nil), nil
+ }
list := make([]C.uchar, meta.list_size)
C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -106,6 +120,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_TYPE:
+ if meta.list_size == 0 {
+ return []DataType(nil), nil
+ }
list := make([]C.TF_DataType, meta.list_size)
C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -118,6 +135,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_TENSOR:
+ if meta.list_size == 0 {
+ return []*Tensor(nil), nil
+ }
list := make([]*C.TF_Tensor, meta.list_size)
C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -130,9 +150,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_SHAPE:
+ if meta.list_size == 0 {
+ return []Shape(nil), nil
+ }
dims := make([]*C.int64_t, meta.list_size)
numDims := make([]C.int, meta.list_size)
- storage := make([]C.int64_t, meta.total_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.int64_t, meta.total_size+1)
C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
@@ -161,6 +185,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte
switch meta._type {
case C.TF_ATTR_STRING:
+ if meta.total_size == 0 {
+ return "", nil
+ }
v := make([]C.char, meta.total_size)
C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c)
if err := status.Err(); err != nil {
@@ -202,6 +229,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte
if numDims < 0 {
return Shape{}, nil
}
+ if numDims == 0 {
+ return ScalarShape(), nil
+ }
dims := make([]C.int64_t, numDims)
C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c)
if err := status.Err(); err != nil {
diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go
index 18fc0de90a..35b0cb352e 100644
--- a/tensorflow/go/attrs_test.go
+++ b/tensorflow/go/attrs_test.go
@@ -17,31 +17,175 @@ limitations under the License.
package tensorflow
import (
+ "fmt"
"reflect"
"testing"
)
func TestOperationAttrs(t *testing.T) {
- attrs := map[string]interface{}{
- "dtype": Float,
+ g := NewGraph()
+
+ i := 0
+ makeConst := func(v interface{}) Output {
+ op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
+ i += 1
+ if err != nil {
+ t.Fatal(err)
+ }
+ return op
}
- g := NewGraph()
- op, err := g.AddOperation(OpSpec{
- Type: "Placeholder",
- Name: "placeholder",
- Attrs: attrs,
- })
- if err != nil {
- t.Fatal(err)
+ makeTensor := func(v interface{}) *Tensor {
+ tensor, err := NewTensor(v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return tensor
}
- for key, want := range attrs {
- out, err := op.Attr(key)
+
+ cases := []OpSpec{
+ {
+ Name: "type",
+ Type: "Placeholder",
+ Attrs: map[string]interface{}{
+ "dtype": Float,
+ },
+ },
+ {
+ Name: "list(float)",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{1, 2, 3, 4}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32{0, 1, 2, 3, 4, 5},
+ },
+ },
+ {
+ Name: "list(float) empty",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32(nil),
+ },
+ },
+ {
+ Name: "list(type),list(shape)",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst(float32(1)),
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Float, Int32},
+ "shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
+ },
+ },
+ {
+ Name: "list(type),list(shape) empty",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Int32},
+ "shapes": []Shape(nil),
+ },
+ },
+ {
+ Name: "list(type) empty,string empty,int",
+ Type: "_XlaSendFromHost",
+ Input: []Input{
+ OutputList([]Output{}),
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "Tinputs": []DataType(nil),
+ "key": "",
+ "device_ordinal": int64(0),
+ },
+ },
+ {
+ Name: "list(int),int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": []int64{1, 2},
+ },
+ },
+ {
+ Name: "list(int) empty,int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": ([]int64)(nil),
+ },
+ },
+ {
+ Name: "list(string),type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": []string{"foo", "bar"},
+ },
+ },
+ {
+ Name: "list(string) empty,type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": ([]string)(nil),
+ },
+ },
+ {
+ Name: "tensor",
+ Type: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": String,
+ "value": makeTensor("foo"),
+ },
+ },
+ }
+
+ for i, spec := range cases {
+ op, err := g.AddOperation(spec)
if err != nil {
t.Fatal(err)
}
- if !reflect.DeepEqual(out, want) {
- t.Fatalf("%q: Got %+v, wanted %+v", key, out, want)
+ for key, want := range spec.Attrs {
+ out, err := op.Attr(key)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(out, want) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
+ }
+ wantT, ok := want.(*Tensor)
+ if ok {
+ wantVal := wantT.Value()
+ outVal := out.(*Tensor).Value()
+ if !reflect.DeepEqual(outVal, wantVal) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
+ }
+ }
}
}
}
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index baaac41f4e..25ec718703 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -131,6 +131,9 @@ func (p Output) canBeAnInput() {}
// Consumers returns the inputs that consume this output.
func (p Output) Consumers() []Consumer {
max := int(C.TF_OperationOutputNumConsumers(p.c()))
+ if max == 0 {
+ return nil
+ }
inputs := make([]C.TF_Input, max)
n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
inputs = inputs[:int(n)]
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 0672e8ecc7..06b65bdfb7 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -222,6 +222,10 @@ func TestOperationConsumers(t *testing.T) {
t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
}
}
+
+ if len(b.Consumers()) != 0 {
+ t.Fatalf("expected %+v to have no consumers", b)
+ }
}
func forceGC() {