diff options
author | 2017-11-29 11:18:38 -0800 | |
---|---|---|
committer | 2017-11-29 11:22:15 -0800 | |
commit | 78a4873cfa4562cf071492636f03e13fcb188bd8 (patch) | |
tree | f7b568c9fe34033d21a718c251a752fe10db91eb /tensorflow/go/graph.go | |
parent | c572bc4fd7c73f4b8014ae43cdf9da5b99592f59 (diff) |
Go: Bugfix: Make list-of-shape attributes in an operation work.
By respecting cgo rules on pointers.
Without the change to graph.go, the newly added test would fail with:
panic: runtime error: cgo argument has Go pointer to Go pointer
in the call to the C function TF_SetAttrShapeList.
Fixes #14891
PiperOrigin-RevId: 177336663
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r-- | tensorflow/go/graph.go | 64 |
1 files changed, 39 insertions, 25 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 46c600eab1..f200a8e00a 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -20,6 +20,24 @@ package tensorflow // // #include <stdlib.h> // #include <string.h> +// +// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc, +// const char* attr_name, +// const int64_t* flat_dims, +// const int* num_dims, +// int num_shapes) { +// const int64_t** dims = +// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes); +// for (int i = 0; i < num_shapes; i++) { +// dims[i] = flat_dims; +// if (num_dims[i] > 0) { +// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0. +// flat_dims += num_dims[i]; +// } +// } +// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes); +// free(dims); +// } import "C" import ( @@ -289,41 +307,37 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu return fmt.Errorf("bad value for attribute %q: %v", name, err) } case Shape: - ndims, dims := cshape(value) + ndims := C.int(value.NumDimensions()) var dimsp *C.int64_t if ndims > 0 { + dims := make([]C.int64_t, ndims) + for i, d := range value.dims { + dims[i] = C.int64_t(d) + } dimsp = &dims[0] } C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) case []Shape: - ndims := make([]C.int, len(value)) - dims := make([][]C.int64_t, len(value)) - dimsp := make([]*C.int64_t, len(value)) - for i, s := range value { - ndims[i], dims[i] = cshape(s) - if ndims[i] > 0 { - dimsp[i] = &dims[i][0] - } - } - if len(value) > 0 { - C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) - } else { + if len(value) == 0 { C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0) + } else { + var flatDims []C.int64_t + ndims := make([]C.int, len(value)) + for i, s := range value { + nd := s.NumDimensions() + ndims[i] = C.int(nd) + for _, d := range s.dims { + flatDims = append(flatDims, C.int64_t(d)) + } + } + var flatDimsp *C.int64_t + if len(flatDims) > 0 { + flatDimsp = &flatDims[0] + } + C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value))) } default: return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } return nil } - -func cshape(s Shape) (C.int, []C.int64_t) { - ndims := C.int(s.NumDimensions()) - if ndims < 0 { - return -1, nil - } - dims := make([]C.int64_t, ndims) - for i, s := range s.dims { - dims[i] = C.int64_t(s) - } - return ndims, dims -} |