diff options
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r-- | tensorflow/go/operation_test.go | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 06b65bdfb7..4af9e33ad0 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) { } } +func TestOperationDevice(t *testing.T) { + graph := NewGraph() + v, err := NewTensor(float32(1.0)) + if err != nil { + t.Fatal(err) + } + op, err := graph.AddOperation(OpSpec{ + Type: "Const", + Name: "Const", + Attrs: map[string]interface{}{ + "dtype": v.DataType(), + "value": v, + }, + Device: "/device:GPU:0", + }) + if err != nil { + t.Fatal(err) + } + if got, want := op.Device(), "/device:GPU:0"; got != want { + t.Errorf("Got %q, want %q", got, want) + } +} + func forceGC() { var mem runtime.MemStats runtime.ReadMemStats(&mem) |