diff options
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r-- | tensorflow/go/graph.go | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 08943a527c..32a77550ee 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -177,7 +177,14 @@ type OpSpec struct { // being added. ControlDependencies []*Operation - // Other possible fields: Device, ColocateWith. + // The device on which the operation should be executed. + // If omitted, an appropriate device will automatically be selected. + // + // For example, if set of "/device:GPU:0", then the operation will + // execute on GPU #0. + Device string + + // Other possible fields: ColocateWith. } // AddOperation adds an operation to g. @@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { return nil, fmt.Errorf("%v (memory will be leaked)", err) } } + if len(args.Device) > 0 { + cdevice := C.CString(args.Device) + C.TF_SetDevice(cdesc, cdevice) + C.free(unsafe.Pointer(cdevice)) + } c := C.TF_FinishOperation(cdesc, status.c) if err := status.Err(); err != nil { return nil, err |