diff options
author | Tristan Rice <rice@fn.lc> | 2018-06-18 12:43:51 -0700 |
---|---|---|
committer | Tristan Rice <rice@fn.lc> | 2018-06-19 14:50:20 -0700 |
commit | 577b256460dfca4e7c429437dded48e76715fee7 (patch) | |
tree | 60cae79034b1a5ea59935f87d0b3ea766e3a516b /tensorflow/go/attrs.go | |
parent | ff7e6399443615675a3f1182c4f2e1850008da04 (diff) |
tensorflow/go: add tests for zero length arrays passed to C
Diffstat (limited to 'tensorflow/go/attrs.go')
-rw-r--r-- | tensorflow/go/attrs.go | 36 |
1 files changed, 33 insertions, 3 deletions
diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go index bfa60d2aa8..f86c5737bc 100644 --- a/tensorflow/go/attrs.go +++ b/tensorflow/go/attrs.go @@ -33,7 +33,8 @@ func makeCShape(shape []C.int64_t) Shape { return s } -// Attr returns the value of an attribute on op. +// Attr returns the value of an attribute on op. It returns an error if the +// attribute does not exist. func (op *Operation) Attr(name string) (interface{}, error) { cname := C.CString(name) defer C.free(unsafe.Pointer(cname)) @@ -55,9 +56,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf switch meta._type { case C.TF_ATTR_STRING: + if meta.list_size == 0 { + return []string(nil), nil + } values := make([]unsafe.Pointer, meta.list_size) lengths := make([]C.size_t, meta.list_size) - storage := make([]C.char, meta.total_size) + // Add one element in case total_size is zero. + storage := make([]C.char, meta.total_size+1) C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c) if err := status.Err(); err != nil { return nil, err @@ -70,6 +75,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return list, nil case C.TF_ATTR_INT: + if meta.list_size == 0 { + return []int64(nil), nil + } list := make([]C.int64_t, meta.list_size) C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -82,6 +90,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_FLOAT: + if meta.list_size == 0 { + return []float32(nil), nil + } list := make([]C.float, meta.list_size) C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -94,6 +105,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_BOOL: + if meta.list_size == 0 { + return []bool(nil), nil + } list := make([]C.uchar, meta.list_size) C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -106,6 +120,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_TYPE: + if meta.list_size == 0 { + return []DataType(nil), nil + } list := make([]C.TF_DataType, meta.list_size) C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -118,6 +135,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_TENSOR: + if meta.list_size == 0 { + return []*Tensor(nil), nil + } list := make([]*C.TF_Tensor, meta.list_size) C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c) if err := status.Err(); err != nil { @@ -130,9 +150,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf return vals, nil case C.TF_ATTR_SHAPE: + if meta.list_size == 0 { + return []Shape(nil), nil + } dims := make([]*C.int64_t, meta.list_size) numDims := make([]C.int, meta.list_size) - storage := make([]C.int64_t, meta.total_size) + // Add one element in case total_size is zero. + storage := make([]C.int64_t, meta.total_size+1) C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c) if err := status.Err(); err != nil { return nil, err @@ -161,6 +185,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte switch meta._type { case C.TF_ATTR_STRING: + if meta.total_size == 0 { + return "", nil + } v := make([]C.char, meta.total_size) C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c) if err := status.Err(); err != nil { @@ -202,6 +229,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte if numDims < 0 { return Shape{}, nil } + if numDims == 0 { + return ScalarShape(), nil + } dims := make([]C.int64_t, numDims) C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c) if err := status.Err(); err != nil { |