aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/graph.go
diff options
context:
space:
mode:
authorGravatar Dandelion Mané <dandelion@google.com>2017-03-10 14:43:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-10 15:18:15 -0800
commit0386a01ad3beb28364599d82199be1c0837b3fa9 (patch)
tree3a1d2ef947a7bf37286efc0e8ff760e0401ab319 /tensorflow/go/graph.go
parente73ceaebb209a1e577e7240fba41c692c89143d0 (diff)
Merge changes from github.
Change: 149800363
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r--tensorflow/go/graph.go47
1 files changed, 39 insertions, 8 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index c0f91ffb30..c64ba84432 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -162,7 +162,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
for i, v := range in {
list[i] = v.c()
}
- C.TF_AddInputList(cdesc, &list[0], C.int(size))
+ if size > 0 {
+ C.TF_AddInputList(cdesc, &list[0], C.int(size))
+ } else {
+ C.TF_AddInputList(cdesc, nil, 0)
+ }
}
}
status := newStatus()
@@ -202,7 +206,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
list[i] = unsafe.Pointer(C.CString(s))
lens[i] = C.size_t(len(s))
}
- C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size))
+ if size > 0 {
+ C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size))
+ } else {
+ C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0)
+ }
for _, s := range list {
C.free(s)
}
@@ -214,7 +222,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
for i, v := range value {
list[i] = C.int64_t(v)
}
- C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size))
+ if size > 0 {
+ C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size))
+ } else {
+ C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0)
+ }
case float32:
C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
case []float32:
@@ -223,7 +235,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
for i, v := range value {
list[i] = C.float(v)
}
- C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size))
+ if size > 0 {
+ C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size))
+ } else {
+ C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0)
+ }
case bool:
v := C.uchar(0)
if value {
@@ -238,11 +254,18 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
list[i] = 1
}
}
- C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size))
+ if size > 0 {
+ C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size))
+ } else {
+ C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0)
+ }
case DataType:
C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
case []DataType:
- list := (*C.TF_DataType)(&value[0])
+ var list *C.TF_DataType
+ if len(value) > 0 {
+ list = (*C.TF_DataType)(&value[0])
+ }
C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
case *Tensor:
C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c)
@@ -255,7 +278,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
for i, v := range value {
list[i] = v.c
}
- C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c)
+ var plist **C.TF_Tensor
+ if size > 0 {
+ plist = &list[0]
+ }
+ C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c)
if err := status.Err(); err != nil {
return fmt.Errorf("bad value for attribute %q: %v", name, err)
}
@@ -276,7 +303,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
dimsp[i] = &dims[i][0]
}
}
- C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
+ if len(value) > 0 {
+ C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
+ } else {
+ C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
+ }
default:
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
}