From 4ddcd6999a68335daf225fbd170d70f3d733b74f Mon Sep 17 00:00:00 2001 From: Asim Shankar Date: Wed, 11 Jul 2018 18:17:25 -0700 Subject: [Go]: Support device annotations when constructing graphs. PiperOrigin-RevId: 204225504 --- tensorflow/go/op/scope.go | 31 +++++++++++++++++++++++++++---- tensorflow/go/op/scope_test.go | 15 +++++++++++++++ 2 files changed, 42 insertions(+), 4 deletions(-) (limited to 'tensorflow/go/op') diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index 13de4294dc..ac39808d83 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -37,6 +37,7 @@ type Scope struct { namemap map[string]int namespace string controlDependencies []*tf.Operation + device string err *scopeErr } @@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation { args.Name = s.namespace + "/" + args.Name } args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...) + args.Device = s.device op, err := s.graph.AddOperation(args) if err != nil { s.UpdateErr(args.Type, err) @@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope { namespace = s.namespace + "/" + namespace } return &Scope{ - graph: s.graph, - namemap: make(map[string]int), - namespace: namespace, - err: s.err, + graph: s.graph, + namemap: make(map[string]int), + namespace: namespace, + controlDependencies: s.controlDependencies, + device: s.device, + err: s.err, } } @@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope { namemap: s.namemap, namespace: s.namespace, controlDependencies: deps, + device: s.device, + err: s.err, + } +} + +// WithDevice returns a new Scope which will cause all operations added to the +// graph to execute on devices that match the provided device specification. +// +// For example, WithDevice("/device:GPU:0") will cause operations added to +// the graph to execute on GPU #0. +// +// An empty string removes any device restrictions. +func (s *Scope) WithDevice(device string) *Scope { + return &Scope{ + graph: s.graph, + namemap: s.namemap, + namespace: s.namespace, + controlDependencies: s.controlDependencies, + device: device, err: s.err, } } diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index b58a61de98..be7b0ad892 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) { } } +func TestDevice(t *testing.T) { + s := NewScope() + matrix := Const(s, [][]float32{{3.0}}) + s = s.WithDevice("/device:GPU:0") + square := MatMul(s.SubScope("square"), matrix, matrix) + s = s.WithDevice("") + cube := MatMul(s.SubScope("cube"), square, matrix) + if got, want := square.Op.Device(), "/device:GPU:0"; got != want { + t.Errorf("Got %q, want %q", got, want) + } + if got, want := cube.Op.Device(), ""; got != want { + t.Errorf("Got %q, want %q", got, want) + } +} + func TestScopeFinalize(t *testing.T) { var ( root = NewScope() -- cgit v1.2.3