aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-07-11 18:17:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 18:21:08 -0700
commit4ddcd6999a68335daf225fbd170d70f3d733b74f (patch)
tree8d09e75cdd724842c1555d9ab887a5f277e9a38b /tensorflow/go
parent5574d6041a5a5d91c4be3449d7a456a146da4c0e (diff)
[Go]: Support device annotations when constructing graphs.
PiperOrigin-RevId: 204225504
Diffstat (limited to 'tensorflow/go')
-rw-r--r--tensorflow/go/graph.go14
-rw-r--r--tensorflow/go/op/scope.go31
-rw-r--r--tensorflow/go/op/scope_test.go15
-rw-r--r--tensorflow/go/operation.go6
-rw-r--r--tensorflow/go/operation_test.go23
5 files changed, 84 insertions, 5 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 08943a527c..32a77550ee 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -177,7 +177,14 @@ type OpSpec struct {
// being added.
ControlDependencies []*Operation
- // Other possible fields: Device, ColocateWith.
+ // The device on which the operation should be executed.
+ // If omitted, an appropriate device will automatically be selected.
+ //
+ // For example, if set of "/device:GPU:0", then the operation will
+ // execute on GPU #0.
+ Device string
+
+ // Other possible fields: ColocateWith.
}
// AddOperation adds an operation to g.
@@ -225,6 +232,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
return nil, fmt.Errorf("%v (memory will be leaked)", err)
}
}
+ if len(args.Device) > 0 {
+ cdevice := C.CString(args.Device)
+ C.TF_SetDevice(cdesc, cdevice)
+ C.free(unsafe.Pointer(cdevice))
+ }
c := C.TF_FinishOperation(cdesc, status.c)
if err := status.Err(); err != nil {
return nil, err
diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go
index 13de4294dc..ac39808d83 100644
--- a/tensorflow/go/op/scope.go
+++ b/tensorflow/go/op/scope.go
@@ -37,6 +37,7 @@ type Scope struct {
namemap map[string]int
namespace string
controlDependencies []*tf.Operation
+ device string
err *scopeErr
}
@@ -82,6 +83,7 @@ func (s *Scope) AddOperation(args tf.OpSpec) *tf.Operation {
args.Name = s.namespace + "/" + args.Name
}
args.ControlDependencies = append(args.ControlDependencies, s.controlDependencies...)
+ args.Device = s.device
op, err := s.graph.AddOperation(args)
if err != nil {
s.UpdateErr(args.Type, err)
@@ -98,10 +100,12 @@ func (s *Scope) SubScope(namespace string) *Scope {
namespace = s.namespace + "/" + namespace
}
return &Scope{
- graph: s.graph,
- namemap: make(map[string]int),
- namespace: namespace,
- err: s.err,
+ graph: s.graph,
+ namemap: make(map[string]int),
+ namespace: namespace,
+ controlDependencies: s.controlDependencies,
+ device: s.device,
+ err: s.err,
}
}
@@ -123,6 +127,25 @@ func (s *Scope) WithControlDependencies(ops ...*tf.Operation) *Scope {
namemap: s.namemap,
namespace: s.namespace,
controlDependencies: deps,
+ device: s.device,
+ err: s.err,
+ }
+}
+
+// WithDevice returns a new Scope which will cause all operations added to the
+// graph to execute on devices that match the provided device specification.
+//
+// For example, WithDevice("/device:GPU:0") will cause operations added to
+// the graph to execute on GPU #0.
+//
+// An empty string removes any device restrictions.
+func (s *Scope) WithDevice(device string) *Scope {
+ return &Scope{
+ graph: s.graph,
+ namemap: s.namemap,
+ namespace: s.namespace,
+ controlDependencies: s.controlDependencies,
+ device: device,
err: s.err,
}
}
diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go
index b58a61de98..be7b0ad892 100644
--- a/tensorflow/go/op/scope_test.go
+++ b/tensorflow/go/op/scope_test.go
@@ -112,6 +112,21 @@ func TestControlDependencies(t *testing.T) {
}
}
+func TestDevice(t *testing.T) {
+ s := NewScope()
+ matrix := Const(s, [][]float32{{3.0}})
+ s = s.WithDevice("/device:GPU:0")
+ square := MatMul(s.SubScope("square"), matrix, matrix)
+ s = s.WithDevice("")
+ cube := MatMul(s.SubScope("cube"), square, matrix)
+ if got, want := square.Op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+ if got, want := cube.Op.Device(), ""; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func TestScopeFinalize(t *testing.T) {
var (
root = NewScope()
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 25ec718703..d6a37e0a86 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -45,6 +45,12 @@ func (op *Operation) NumOutputs() int {
return int(C.TF_OperationNumOutputs(op.c))
}
+// Device returns a specification of the device on which this operation
+// will be executed, or the empty string if there is no such specification.
+func (op *Operation) Device() string {
+ return C.GoString(C.TF_OperationDevice(op.c))
+}
+
// OutputListSize returns the size of the list of Outputs that is produced by a
// named output of op.
//
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 06b65bdfb7..4af9e33ad0 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
}
}
+func TestOperationDevice(t *testing.T) {
+ graph := NewGraph()
+ v, err := NewTensor(float32(1.0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ op, err := graph.AddOperation(OpSpec{
+ Type: "Const",
+ Name: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": v.DataType(),
+ "value": v,
+ },
+ Device: "/device:GPU:0",
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)