aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/graph.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-11-29 11:18:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-29 11:22:15 -0800
commit78a4873cfa4562cf071492636f03e13fcb188bd8 (patch)
treef7b568c9fe34033d21a718c251a752fe10db91eb /tensorflow/go/graph.go
parentc572bc4fd7c73f4b8014ae43cdf9da5b99592f59 (diff)
Go: Bugfix: Make list-of-shape attributes in an operation work.
By respecting cgo rules on pointers. Without the change to graph.go, the newly added test would fail with: panic: runtime error: cgo argument has Go pointer to Go pointer in the call to the C function TF_SetAttrShapeList. Fixes #14891 PiperOrigin-RevId: 177336663
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r--tensorflow/go/graph.go64
1 files changed, 39 insertions, 25 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 46c600eab1..f200a8e00a 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -20,6 +20,24 @@ package tensorflow
//
// #include <stdlib.h>
// #include <string.h>
+//
+// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
+// const char* attr_name,
+// const int64_t* flat_dims,
+// const int* num_dims,
+// int num_shapes) {
+// const int64_t** dims =
+// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
+// for (int i = 0; i < num_shapes; i++) {
+// dims[i] = flat_dims;
+// if (num_dims[i] > 0) {
+// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0.
+// flat_dims += num_dims[i];
+// }
+// }
+// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
+// free(dims);
+// }
import "C"
import (
@@ -289,41 +307,37 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
return fmt.Errorf("bad value for attribute %q: %v", name, err)
}
case Shape:
- ndims, dims := cshape(value)
+ ndims := C.int(value.NumDimensions())
var dimsp *C.int64_t
if ndims > 0 {
+ dims := make([]C.int64_t, ndims)
+ for i, d := range value.dims {
+ dims[i] = C.int64_t(d)
+ }
dimsp = &dims[0]
}
C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
case []Shape:
- ndims := make([]C.int, len(value))
- dims := make([][]C.int64_t, len(value))
- dimsp := make([]*C.int64_t, len(value))
- for i, s := range value {
- ndims[i], dims[i] = cshape(s)
- if ndims[i] > 0 {
- dimsp[i] = &dims[i][0]
- }
- }
- if len(value) > 0 {
- C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
- } else {
+ if len(value) == 0 {
C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
+ } else {
+ var flatDims []C.int64_t
+ ndims := make([]C.int, len(value))
+ for i, s := range value {
+ nd := s.NumDimensions()
+ ndims[i] = C.int(nd)
+ for _, d := range s.dims {
+ flatDims = append(flatDims, C.int64_t(d))
+ }
+ }
+ var flatDimsp *C.int64_t
+ if len(flatDims) > 0 {
+ flatDimsp = &flatDims[0]
+ }
+ C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value)))
}
default:
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
}
return nil
}
-
-func cshape(s Shape) (C.int, []C.int64_t) {
- ndims := C.int(s.NumDimensions())
- if ndims < 0 {
- return -1, nil
- }
- dims := make([]C.int64_t, ndims)
- for i, s := range s.dims {
- dims[i] = C.int64_t(s)
- }
- return ndims, dims
-}