aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/graph.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-01-17 14:04:20 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-17 14:31:05 -0800
commitf8d75baaf4c92e76837de6bb64adccf8127a21d6 (patch)
treed3328471307fc541a14a95e8878d5cd9d32f24ab /tensorflow/go/graph.go
parent0662eabf9d6d670bd9a741ea3a3eb0c9f0005850 (diff)
Go: Support setting shape valued attributes.
Fixes #6833 Change: 144752893
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r--tensorflow/go/graph.go35
1 files changed, 30 insertions, 5 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 2eb1194610..c0f91ffb30 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -259,13 +259,38 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
if err := status.Err(); err != nil {
return fmt.Errorf("bad value for attribute %q: %v", name, err)
}
+ case Shape:
+ ndims, dims := cshape(value)
+ var dimsp *C.int64_t
+ if ndims > 0 {
+ dimsp = &dims[0]
+ }
+ C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
+ case []Shape:
+ ndims := make([]C.int, len(value))
+ dims := make([][]C.int64_t, len(value))
+ dimsp := make([]*C.int64_t, len(value))
+ for i, s := range value {
+ ndims[i], dims[i] = cshape(s)
+ if ndims[i] > 0 {
+ dimsp[i] = &dims[i][0]
+ }
+ }
+ C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
default:
- // Shapes can be done, but will require that it be
- // distinguishable from []int64. Which is fine, it
- // probably makes sense to define a Shape type anyway,
- // since that should handle partially known shapes as
- // well and hide the special meaning of -1?
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
}
return nil
}
+
+func cshape(s Shape) (C.int, []C.int64_t) {
+ ndims := C.int(s.NumDimensions())
+ if ndims < 0 {
+ return -1, nil
+ }
+ dims := make([]C.int64_t, ndims)
+ for i, s := range s.dims {
+ dims[i] = C.int64_t(s)
+ }
+ return ndims, dims
+}