From 4ddcd6999a68335daf225fbd170d70f3d733b74f Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Wed, 11 Jul 2018 18:17:25 -0700 Subject: [Go]: Support device annotations when constructing graphs. PiperOrigin-RevId: 204225504 --- tensorflow/go/graph.go | 14 +++++++++++++- tensorflow/go/op/scope.go | 31 +++++++++++++++++++++++++++---- tensorflow/go/op/scope_test.go | 15 +++++++++++++++ tensorflow/go/operation.go | 6 ++++++ tensorflow/go/operation_test.go | 23 +++++++++++++++++++++++ 5 files changed, 84 insertions(+), 5 deletions(-) (limited to 'tensorflow/go') diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index 08943a527c..32a77550ee 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -177,7 +177,14 @@ type OpSpec struct { // being added. ControlDependencies []*Operation - // Other possible fields: Device, ColocateWith. + // The device on which the operation should be executed. + // If omitted, an appropriate device will automatically be selected. + // + // For example, if set of "/device:GPU:0", then the operation will + // execute on GPU #0. + Device string + + // Other possible fields: ColocateWith. } // AddOperation adds an operation to g. @@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { return nil, fmt.Errorf("%v (memory will be leaked)", err) } } + if len(args.Device) > 0 { + cdevice := C.CString(args.Device) + C.TF_SetDevice(cdesc, cdevice) + C.free(unsafe.Pointer(cdevice)) + } c := C.TF_FinishOperation(cdesc, status.c) if err := status.Err(); err != nil { return nil, err diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index 13de4294dc..ac39808d83 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -37,6 +37,7 @@ type Scope struct { namemap map[string]int namespace string controlDependencies []*tf.Operation + device string err *scopeErr } @@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation { args.Name = s.namespace + "/" + args.Name } args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) + args.Device = s.device op, err := s.graph.AddOperation(args) if err != nil { s.UpdateErr(args.Type, err) @@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope { namespace = s.namespace + "/" + namespace } return &Scope{ - graph: s.graph, - namemap: make(map[string]int), - namespace: namespace, - err: s.err, + graph: s.graph, + namemap: make(map[string]int), + namespace: namespace, + controlDependencies: s.controlDependencies, + device: s.device, + err: s.err, } } @@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { namemap: s.namemap, namespace: s.namespace, controlDependencies: deps, + device: s.device, + err: s.err, + } +} + +// WithDevice returns a new Scope which will cause all operations added to the +// graph to execute on devices that match the provided device specification. +// +// For example, WithDevice("/device:GPU:0") will cause operations added to +// the graph to execute on GPU #0. +// +// An empty string removes any device restrictions. +func (s *Scope) WithDevice(device string) *Scope { + return &Scope{ + graph: s.graph, + namemap: s.namemap, + namespace: s.namespace, + controlDependencies: s.controlDependencies, + device: device, err: s.err, } } diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index b58a61de98..be7b0ad892 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) { } } +func TestDevice(t *testing.T) { + s := NewScope() + matrix := Const(s, [][]float32{{3.0}}) + s = s.WithDevice("/device:GPU:0") + square := MatMul(s.SubScope("square"), matrix, matrix) + s = s.WithDevice("") + cube := MatMul(s.SubScope("cube"), square, matrix) + if got, want := square.Op.Device(), "/device:GPU:0"; got != want { + t.Errorf("Got %q, want %q", got, want) + } + if got, want := cube.Op.Device(), ""; got != want { + t.Errorf("Got %q, want %q", got, want) + } +} + func TestScopeFinalize(t *testing.T) { var ( root = NewScope() diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index 25ec718703..d6a37e0a86 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int { return int(C.TF_OperationNumOutputs(op.c)) } +// Device returns a specification of the device on which this operation +// will be executed, or the empty string if there is no such specification. +func (op *Operation) Device() string { + return C.GoString(C.TF_OperationDevice(op.c)) +} + // OutputListSize returns the size of the list of Outputs that is produced by a // named output of op. // diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 06b65bdfb7..4af9e33ad0 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) { } } +func TestOperationDevice(t *testing.T) { + graph := NewGraph() + v, err := NewTensor(float32(1.0)) + if err != nil { + t.Fatal(err) + } + op, err := graph.AddOperation(OpSpec{ + Type: "Const", + Name: "Const", + Attrs: map[string]interface{}{ + "dtype": v.DataType(), + "value": v, + }, + Device: "/device:GPU:0", + }) + if err != nil { + t.Fatal(err) + } + if got, want := op.Device(), "/device:GPU:0"; got != want { + t.Errorf("Got %q, want %q", got, want) + } +} + func forceGC() { var mem runtime.MemStats runtime.ReadMemStats(&mem) -- cgit v1.2.3