aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/attrs.go
diff options
context:
space:
mode:
authorGravatar Tristan Rice <rice@fn.lc>2018-06-18 12:43:51 -0700
committerGravatar Tristan Rice <rice@fn.lc>2018-06-19 14:50:20 -0700
commit577b256460dfca4e7c429437dded48e76715fee7 (patch)
tree60cae79034b1a5ea59935f87d0b3ea766e3a516b /tensorflow/go/attrs.go
parentff7e6399443615675a3f1182c4f2e1850008da04 (diff)
tensorflow/go: add tests for zero length arrays passed to C
Diffstat (limited to 'tensorflow/go/attrs.go')
-rw-r--r--tensorflow/go/attrs.go36
1 files changed, 33 insertions, 3 deletions
diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go
index bfa60d2aa8..f86c5737bc 100644
--- a/tensorflow/go/attrs.go
+++ b/tensorflow/go/attrs.go
@@ -33,7 +33,8 @@ func makeCShape(shape []C.int64_t) Shape {
return s
}
-// Attr returns the value of an attribute on op.
+// Attr returns the value of an attribute on op. It returns an error if the
+// attribute does not exist.
func (op *Operation) Attr(name string) (interface{}, error) {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
@@ -55,9 +56,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
switch meta._type {
case C.TF_ATTR_STRING:
+ if meta.list_size == 0 {
+ return []string(nil), nil
+ }
values := make([]unsafe.Pointer, meta.list_size)
lengths := make([]C.size_t, meta.list_size)
- storage := make([]C.char, meta.total_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.char, meta.total_size+1)
C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
@@ -70,6 +75,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return list, nil
case C.TF_ATTR_INT:
+ if meta.list_size == 0 {
+ return []int64(nil), nil
+ }
list := make([]C.int64_t, meta.list_size)
C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -82,6 +90,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_FLOAT:
+ if meta.list_size == 0 {
+ return []float32(nil), nil
+ }
list := make([]C.float, meta.list_size)
C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -94,6 +105,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_BOOL:
+ if meta.list_size == 0 {
+ return []bool(nil), nil
+ }
list := make([]C.uchar, meta.list_size)
C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -106,6 +120,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_TYPE:
+ if meta.list_size == 0 {
+ return []DataType(nil), nil
+ }
list := make([]C.TF_DataType, meta.list_size)
C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -118,6 +135,9 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_TENSOR:
+ if meta.list_size == 0 {
+ return []*Tensor(nil), nil
+ }
list := make([]*C.TF_Tensor, meta.list_size)
C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
if err := status.Err(); err != nil {
@@ -130,9 +150,13 @@ func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interf
return vals, nil
case C.TF_ATTR_SHAPE:
+ if meta.list_size == 0 {
+ return []Shape(nil), nil
+ }
dims := make([]*C.int64_t, meta.list_size)
numDims := make([]C.int, meta.list_size)
- storage := make([]C.int64_t, meta.total_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.int64_t, meta.total_size+1)
C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c)
if err := status.Err(); err != nil {
return nil, err
@@ -161,6 +185,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte
switch meta._type {
case C.TF_ATTR_STRING:
+ if meta.total_size == 0 {
+ return "", nil
+ }
v := make([]C.char, meta.total_size)
C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c)
if err := status.Err(); err != nil {
@@ -202,6 +229,9 @@ func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (inte
if numDims < 0 {
return Shape{}, nil
}
+ if numDims == 0 {
+ return ScalarShape(), nil
+ }
dims := make([]C.int64_t, numDims)
C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c)
if err := status.Err(); err != nil {