aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2018-07-11 18:17:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-11 18:21:08 -0700
commit4ddcd6999a68335daf225fbd170d70f3d733b74f (patch)
tree8d09e75cdd724842c1555d9ab887a5f277e9a38b /tensorflow/go/operation_test.go
parent5574d6041a5a5d91c4be3449d7a456a146da4c0e (diff)
[Go]: Support device annotations when constructing graphs.
PiperOrigin-RevId: 204225504
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 06b65bdfb7..4af9e33ad0 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
}
}
+func TestOperationDevice(t *testing.T) {
+ graph := NewGraph()
+ v, err := NewTensor(float32(1.0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ op, err := graph.AddOperation(OpSpec{
+ Type: "Const",
+ Name: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": v.DataType(),
+ "value": v,
+ },
+ Device: "/device:GPU:0",
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)