diff options
author | Dandelion Mané <dandelion@google.com> | 2017-03-10 14:43:23 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-10 15:18:15 -0800 |
commit | 0386a01ad3beb28364599d82199be1c0837b3fa9 (patch) | |
tree | 3a1d2ef947a7bf37286efc0e8ff760e0401ab319 /tensorflow/go/graph.go | |
parent | e73ceaebb209a1e577e7240fba41c692c89143d0 (diff) |
Merge changes from github.
Change: 149800363
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r-- | tensorflow/go/graph.go | 47 |
1 files changed, 39 insertions, 8 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index c0f91ffb30..c64ba84432 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -162,7 +162,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { for i, v := range in { list[i] = v.c() } - C.TF_AddInputList(cdesc, &list[0], C.int(size)) + if size > 0 { + C.TF_AddInputList(cdesc, &list[0], C.int(size)) + } else { + C.TF_AddInputList(cdesc, nil, 0) + } } } status := newStatus() @@ -202,7 +206,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu list[i] = unsafe.Pointer(C.CString(s)) lens[i] = C.size_t(len(s)) } - C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) + if size > 0 { + C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) + } else { + C.TF_SetAttrStringList(cdesc, cAttrName, nil, nil, 0) + } for _, s := range list { C.free(s) } @@ -214,7 +222,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu for i, v := range value { list[i] = C.int64_t(v) } - C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) + if size > 0 { + C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) + } else { + C.TF_SetAttrIntList(cdesc, cAttrName, nil, 0) + } case float32: C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) case []float32: @@ -223,7 +235,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu for i, v := range value { list[i] = C.float(v) } - C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) + if size > 0 { + C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) + } else { + C.TF_SetAttrFloatList(cdesc, cAttrName, nil, 0) + } case bool: v := C.uchar(0) if value { @@ -238,11 +254,18 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu list[i] = 1 } } - C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) + if size > 0 { + C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) + } else { + C.TF_SetAttrBoolList(cdesc, cAttrName, nil, 0) + } case DataType: C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) case []DataType: - list := (*C.TF_DataType)(&value[0]) + var list *C.TF_DataType + if len(value) > 0 { + list = (*C.TF_DataType)(&value[0]) + } C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) case *Tensor: C.TF_SetAttrTensor(cdesc, cAttrName, value.c, status.c) @@ -255,7 +278,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu for i, v := range value { list[i] = v.c } - C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c) + var plist **C.TF_Tensor + if size > 0 { + plist = &list[0] + } + C.TF_SetAttrTensorList(cdesc, cAttrName, plist, C.int(size), status.c) if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } @@ -276,7 +303,11 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu dimsp[i] = &dims[i][0] } } - C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) + if len(value) > 0 { + C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) + } else { + C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0) + } default: return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } |