diff options
author | Asim Shankar <ashankar@google.com> | 2016-11-21 11:24:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-21 11:44:14 -0800 |
commit | c80be35454a418c18b3fd57614bfcb5265274c33 (patch) | |
tree | 975b50379b742f7913eb6d2c14ab8fdfc552d22d /tensorflow/go/session_test.go | |
parent | 074acf38d83bf4be1e3fe2bb813d4bf32b97c2ac (diff) |
Go: Support for String tensors.
And use this support to simplify the Inception example as it can use the
DecodeJpeg op.
Also fixed a bug in generated op functions - A TensorFlow "int"
is a Go "int64".
Another step in #10
Change: 139809489
Diffstat (limited to 'tensorflow/go/session_test.go')
-rw-r--r-- | tensorflow/go/session_test.go | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 0d3660995b..14ecca402b 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -119,6 +119,46 @@ func TestSessionRunConcat(t *testing.T) { } } +func TestSessionWithStringTensors(t *testing.T) { + // Construct the graph: + // AsString(StringToHashBucketFast("PleaseHashMe")) Will be much + // prettier if using the ops package, but in this package graphs are + // constructed from first principles. + var ( + g = NewGraph() + feed, _ = Const(g, "input", "PleaseHashMe") + hash, _ = g.AddOperation(OpSpec{ + Type: "StringToHashBucketFast", + Input: []Input{feed}, + Attrs: map[string]interface{}{ + "num_buckets": int64(1 << 32), + }, + }) + str, _ = g.AddOperation(OpSpec{ + Type: "AsString", + Input: []Input{hash.Output(0)}, + }) + ) + s, err := NewSession(g, nil) + if err != nil { + t.Fatal(err) + } + output, err := s.Run(nil, []Output{str.Output(0)}, nil) + if err != nil { + t.Fatal(err) + } + if len(output) != 1 { + t.Fatal(len(output)) + } + got, ok := output[0].Value().(string) + if !ok { + t.Fatalf("Got %T, wanted string", output[0].Value()) + } + if want := "1027741475"; got != want { + t.Fatalf("Got %q, want %q", got, want) + } +} + func TestConcurrency(t *testing.T) { tensor, err := NewTensor(int64(1)) if err != nil { |