aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/go/operation_test.go')
-rw-r--r--tensorflow/go/operation_test.go23
1 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 06b65bdfb7..4af9e33ad0 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -228,6 +228,29 @@ func TestOperationConsumers(t *testing.T) {
}
}
+func TestOperationDevice(t *testing.T) {
+ graph := NewGraph()
+ v, err := NewTensor(float32(1.0))
+ if err != nil {
+ t.Fatal(err)
+ }
+ op, err := graph.AddOperation(OpSpec{
+ Type: "Const",
+ Name: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": v.DataType(),
+ "value": v,
+ },
+ Device: "/device:GPU:0",
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := op.Device(), "/device:GPU:0"; got != want {
+ t.Errorf("Got %q, want %q", got, want)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)