diff options
author | Asim Shankar <ashankar@google.com> | 2017-01-17 14:04:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-17 14:31:05 -0800 |
commit | f8d75baaf4c92e76837de6bb64adccf8127a21d6 (patch) | |
tree | d3328471307fc541a14a95e8878d5cd9d32f24ab /tensorflow/go/graph.go | |
parent | 0662eabf9d6d670bd9a741ea3a3eb0c9f0005850 (diff) |
Go: Support setting shape valued attributes.
Fixes #6833
Change: 144752893
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r-- | tensorflow/go/graph.go | 35 |
1 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 2eb1194610..c0f91ffb30 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -259,13 +259,38 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu if err := status.Err(); err != nil { return fmt.Errorf("bad value for attribute %q: %v", name, err) } + case Shape: + ndims, dims := cshape(value) + var dimsp *C.int64_t + if ndims > 0 { + dimsp = &dims[0] + } + C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims) + case []Shape: + ndims := make([]C.int, len(value)) + dims := make([][]C.int64_t, len(value)) + dimsp := make([]*C.int64_t, len(value)) + for i, s := range value { + ndims[i], dims[i] = cshape(s) + if ndims[i] > 0 { + dimsp[i] = &dims[i][0] + } + } + C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value))) default: - // Shapes can be done, but will require that it be - // distinguishable from []int64. Which is fine, it - // probably makes sense to define a Shape type anyway, - // since that should handle partially known shapes as - // well and hide the special meaning of -1? return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) } return nil } + +func cshape(s Shape) (C.int, []C.int64_t) { + ndims := C.int(s.NumDimensions()) + if ndims < 0 { + return -1, nil + } + dims := make([]C.int64_t, ndims) + for i, s := range s.dims { + dims[i] = C.int64_t(s) + } + return ndims, dims +} |