aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session_test.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-11-21 11:24:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-21 11:44:14 -0800
commitc80be35454a418c18b3fd57614bfcb5265274c33 (patch)
tree975b50379b742f7913eb6d2c14ab8fdfc552d22d /tensorflow/go/session_test.go
parent074acf38d83bf4be1e3fe2bb813d4bf32b97c2ac (diff)
Go: Support for String tensors.
And use this support to simplify the Inception example as it can use the DecodeJpeg op. Also fixed a bug in generated op functions - A TensorFlow "int" is a Go "int64". Another step in #10 Change: 139809489
Diffstat (limited to 'tensorflow/go/session_test.go')
-rw-r--r--tensorflow/go/session_test.go40
1 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index 0d3660995b..14ecca402b 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -119,6 +119,46 @@ func TestSessionRunConcat(t *testing.T) {
}
}
+func TestSessionWithStringTensors(t *testing.T) {
+ // Construct the graph:
+ // AsString(StringToHashBucketFast("PleaseHashMe")) Will be much
+ // prettier if using the ops package, but in this package graphs are
+ // constructed from first principles.
+ var (
+ g = NewGraph()
+ feed, _ = Const(g, "input", "PleaseHashMe")
+ hash, _ = g.AddOperation(OpSpec{
+ Type: "StringToHashBucketFast",
+ Input: []Input{feed},
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(1 << 32),
+ },
+ })
+ str, _ = g.AddOperation(OpSpec{
+ Type: "AsString",
+ Input: []Input{hash.Output(0)},
+ })
+ )
+ s, err := NewSession(g, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ output, err := s.Run(nil, []Output{str.Output(0)}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(output) != 1 {
+ t.Fatal(len(output))
+ }
+ got, ok := output[0].Value().(string)
+ if !ok {
+ t.Fatalf("Got %T, wanted string", output[0].Value())
+ }
+ if want := "1027741475"; got != want {
+ t.Fatalf("Got %q, want %q", got, want)
+ }
+}
+
func TestConcurrency(t *testing.T) {
tensor, err := NewTensor(int64(1))
if err != nil {