diff options
author | 2018-06-18 12:43:51 -0700 | |
---|---|---|
committer | 2018-06-19 14:50:20 -0700 | |
commit | 577b256460dfca4e7c429437dded48e76715fee7 (patch) | |
tree | 60cae79034b1a5ea59935f87d0b3ea766e3a516b /tensorflow/go | |
parent | ff7e6399443615675a3f1182c4f2e1850008da04 (diff) |
tensorflow/go: add tests for zero length arrays passed to C
Diffstat (limited to 'tensorflow/go')
-rw-r--r-- | tensorflow/go/attrs.go | 36 | ||||
-rw-r--r-- | tensorflow/go/attrs_test.go | 172 | ||||
-rw-r--r-- | tensorflow/go/operation.go | 3 | ||||
-rw-r--r-- | tensorflow/go/operation_test.go | 4 |
4 files changed, 198 insertions, 17 deletions
diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go index bfa60d2aa8..f86c5737bc 100644 --- a/tensorflow/go/attrs.go +++ b/tensorflow/go/attrs.go @@ -33,7 +33,8 @@ func makeCShape(shape []C.int64_t) Shape { return s } -// Attr returns the value of an attribute on op. +// Attr returns the value of an attribute on op. It returns an error if the +// attribute does not exist. func (op *Operation) Attr(name string) (interface{}, error) { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) @@ -55,9 +56,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf switch meta._type { case C.TF_ATTR_STRING: + if meta.list_size == 0 { + return []string(nil), nil + } values := make([]unsafe.Pointer, meta.list_size) lengths := make([]C.size_t, meta.list_size) - storage := make([]C.char, meta.total_size) + // Add one element in case total_size is zero. + storage := make([]C.char, meta.total_size+1) C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c) if err := status.Err(); err != nil { return nil, err @@ -70,6 +75,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return list, nil case C.TF_ATTR_INT: + if meta.list_size == 0 { + return []int64(nil), nil + } list := make([]C.int64_t, meta.list_size) C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -82,6 +90,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_FLOAT: + if meta.list_size == 0 { + return []float32(nil), nil + } list := make([]C.float, meta.list_size) C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -94,6 +105,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_BOOL: + if meta.list_size == 0 { + return []bool(nil), nil + } list := make([]C.uchar, meta.list_size) C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -106,6 +120,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_TYPE: + if meta.list_size == 0 { + return []DataType(nil), nil + } list := make([]C.TF_DataType, meta.list_size) C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -118,6 +135,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_TENSOR: + if meta.list_size == 0 { + return []*Tensor(nil), nil + } list := make([]*C.TF_Tensor, meta.list_size) C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -130,9 +150,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_SHAPE: + if meta.list_size == 0 { + return []Shape(nil), nil + } dims := make([]*C.int64_t, meta.list_size) numDims := make([]C.int, meta.list_size) - storage := make([]C.int64_t, meta.total_size) + // Add one element in case total_size is zero. + storage := make([]C.int64_t, meta.total_size+1) C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c) if err := status.Err(); err != nil { return nil, err @@ -161,6 +185,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte switch meta._type { case C.TF_ATTR_STRING: + if meta.total_size == 0 { + return "", nil + } v := make([]C.char, meta.total_size) C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c) if err := status.Err(); err != nil { @@ -202,6 +229,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte if numDims < 0 { return Shape{}, nil } + if numDims == 0 { + return ScalarShape(), nil + } dims := make([]C.int64_t, numDims) C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c) if err := status.Err(); err != nil { diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go index 18fc0de90a..35b0cb352e 100644 --- a/tensorflow/go/attrs_test.go +++ b/tensorflow/go/attrs_test.go @@ -17,31 +17,175 @@ limitations under the License. package tensorflow import ( + "fmt" "reflect" "testing" ) func TestOperationAttrs(t *testing.T) { - attrs := map[string]interface{}{ - "dtype": Float, + g := NewGraph() + + i := 0 + makeConst := func(v interface{}) Output { + op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v) + i += 1 + if err != nil { + t.Fatal(err) + } + return op } - g := NewGraph() - op, err := g.AddOperation(OpSpec{ - Type: "Placeholder", - Name: "placeholder", - Attrs: attrs, - }) - if err != nil { - t.Fatal(err) + makeTensor := func(v interface{}) *Tensor { + tensor, err := NewTensor(v) + if err != nil { + t.Fatal(err) + } + return tensor } - for key, want := range attrs { - out, err := op.Attr(key) + + cases := []OpSpec{ + { + Name: "type", + Type: "Placeholder", + Attrs: map[string]interface{}{ + "dtype": Float, + }, + }, + { + Name: "list(float)", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{1, 2, 3, 4}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32{0, 1, 2, 3, 4, 5}, + }, + }, + { + Name: "list(float) empty", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32(nil), + }, + }, + { + Name: "list(type),list(shape)", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst(float32(1)), + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Float, Int32}, + "shapes": []Shape{ScalarShape(), MakeShape(1, 1)}, + }, + }, + { + Name: "list(type),list(shape) empty", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Int32}, + "shapes": []Shape(nil), + }, + }, + { + Name: "list(type) empty,string empty,int", + Type: "_XlaSendFromHost", + Input: []Input{ + OutputList([]Output{}), + makeConst(""), + }, + Attrs: map[string]interface{}{ + "Tinputs": []DataType(nil), + "key": "", + "device_ordinal": int64(0), + }, + }, + { + Name: "list(int),int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": []int64{1, 2}, + }, + }, + { + Name: "list(int) empty,int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": ([]int64)(nil), + }, + }, + { + Name: "list(string),type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": []string{"foo", "bar"}, + }, + }, + { + Name: "list(string) empty,type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": ([]string)(nil), + }, + }, + { + Name: "tensor", + Type: "Const", + Attrs: map[string]interface{}{ + "dtype": String, + "value": makeTensor("foo"), + }, + }, + } + + for i, spec := range cases { + op, err := g.AddOperation(spec) if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(out, want) { - t.Fatalf("%q: Got %+v, wanted %+v", key, out, want) + for key, want := range spec.Attrs { + out, err := op.Attr(key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out, want) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want) + } + wantT, ok := want.(*Tensor) + if ok { + wantVal := wantT.Value() + outVal := out.(*Tensor).Value() + if !reflect.DeepEqual(outVal, wantVal) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal) + } + } } } } diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index baaac41f4e..25ec718703 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -131,6 +131,9 @@ func (p Output) canBeAnInput() {} // Consumers returns the inputs that consume this output. func (p Output) Consumers() []Consumer { max := int(C.TF_OperationOutputNumConsumers(p.c())) + if max == 0 { + return nil + } inputs := make([]C.TF_Input, max) n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max)) inputs = inputs[:int(n)] diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 0672e8ecc7..06b65bdfb7 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -222,6 +222,10 @@ func TestOperationConsumers(t *testing.T) { t.Fatalf("%d. Got op name %q, wanted %q", i, got, want) } } + + if len(b.Consumers()) != 0 { + t.Fatalf("expected %+v to have no consumers", b) + } } func forceGC() { |