aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/op
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-07-11 18:17:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 18:21:08 -0700
commit4ddcd6999a68335daf225fbd170d70f3d733b74f (patch)
tree8d09e75cdd724842c1555d9ab887a5f277e9a38b /tensorflow/go/op
parent5574d6041a5a5d91c4be3449d7a456a146da4c0e (diff)
[Go]: Support device annotations when constructing graphs.
PiperOrigin-RevId: 204225504
Diffstat (limited to 'tensorflow/go/op')
-rw-r--r--tensorflow/go/op/scope.go31
-rw-r--r--tensorflow/go/op/scope_test.go15
2 files changed, 42 insertions, 4 deletions
diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go
index 13de4294dc..ac39808d83 100644
--- a/tensorflow/go/op/scope.go
+++ b/tensorflow/go/op/scope.go
@@ -37,6 +37,7 @@ type Scope struct {
namemap map[string]int
namespace string
controlDependencies []*tf.Operation
+ device string
err *scopeErr
}
@@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
args.Name = s.namespace + "/" + args.Name
}
args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
+ args.Device = s.device
op, err := s.graph.AddOperation(args)
if err != nil {
s.UpdateErr(args.Type, err)
@@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope {
namespace = s.namespace + "/" + namespace
}
return &Scope{
- graph: s.graph,
- namemap: make(map[string]int),
- namespace: namespace,
- err: s.err,
+ graph: s.graph,
+ namemap: make(map[string]int),
+ namespace: namespace,
+ controlDependencies: s.controlDependencies,
+ device: s.device,
+ err: s.err,
}
}
@@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
namemap: s.namemap,
namespace: s.namespace,
controlDependencies: deps,
+ device: s.device,
+ err: s.err,
+ }
+}
+
+// WithDevice returns a new Scope which will cause all operations added to the
+// graph to execute on devices that match the provided device specification.
+//
+// For example, WithDevice("/device:GPU:0") will cause operations added to
+// the graph to execute on GPU #0.
+//
+// An empty string removes any device restrictions.
+func (s *Scope) WithDevice(device string) *Scope {
+ return &Scope{
+ graph: s.graph,
+ namemap: s.namemap,
+ namespace: s.namespace,
+ controlDependencies: s.controlDependencies,
+ device: device,
err: s.err,
}
}
diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go
index b58a61de98..be7b0ad892 100644
--- a/tensorflow/go/op/scope_test.go
+++ b/tensorflow/go/op/scope_test.go
@@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) {
}
}
+func TestDevice(t *testing.T) {
+ s := NewScope()
+ matrix := Const(s, [][]float32{{3.0}})
+ s = s.WithDevice("/device:GPU:0")
+ square := MatMul(s.SubScope("square"), matrix, matrix)
+ s = s.WithDevice("")
+ cube := MatMul(s.SubScope("cube"), square, matrix)
+ if got, want := square.Op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+ if got, want := cube.Op.Device(), ""; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func TestScopeFinalize(t *testing.T) {
var (
root = NewScope()