aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/attrs_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/go/attrs_test.go')
-rw-r--r--tensorflow/go/attrs_test.go193
1 files changed, 193 insertions, 0 deletions
diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go
new file mode 100644
index 0000000000..ea8af221ae
--- /dev/null
+++ b/tensorflow/go/attrs_test.go
@@ -0,0 +1,193 @@
+/*
+Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package tensorflow
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestOperationAttrs(t *testing.T) {
+ g := NewGraph()
+
+ i := 0
+ makeConst := func(v interface{}) Output {
+ op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
+ i++
+ if err != nil {
+ t.Fatal(err)
+ }
+ return op
+ }
+
+ makeTensor := func(v interface{}) *Tensor {
+ tensor, err := NewTensor(v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return tensor
+ }
+
+ cases := []OpSpec{
+ {
+ Name: "type",
+ Type: "Placeholder",
+ Attrs: map[string]interface{}{
+ "dtype": Float,
+ },
+ },
+ {
+ Name: "list(float)",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{1, 2, 3, 4}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32{0, 1, 2, 3, 4, 5},
+ },
+ },
+ {
+ Name: "list(float) empty",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32(nil),
+ },
+ },
+ /* TODO(ashankar): debug this issue and add it back later.
+ {
+ Name: "list(type),list(shape)",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst(float32(1)),
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Float, Int32},
+ "shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
+ },
+ },
+ {
+ Name: "list(type),list(shape) empty",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Int32},
+ "shapes": []Shape(nil),
+ },
+ },
+ {
+ Name: "list(type) empty,string empty,int",
+ Type: "_XlaSendFromHost",
+ Input: []Input{
+ OutputList([]Output{}),
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "Tinputs": []DataType(nil),
+ "key": "",
+ "device_ordinal": int64(0),
+ },
+ },
+ */
+ {
+ Name: "list(int),int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": []int64{1, 2},
+ },
+ },
+ {
+ Name: "list(int) empty,int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": ([]int64)(nil),
+ },
+ },
+ {
+ Name: "list(string),type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": []string{"foo", "bar"},
+ },
+ },
+ {
+ Name: "list(string) empty,type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": ([]string)(nil),
+ },
+ },
+ {
+ Name: "tensor",
+ Type: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": String,
+ "value": makeTensor("foo"),
+ },
+ },
+ }
+
+ for i, spec := range cases {
+ op, err := g.AddOperation(spec)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for key, want := range spec.Attrs {
+ out, err := op.Attr(key)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(out, want) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
+ }
+ wantT, ok := want.(*Tensor)
+ if ok {
+ wantVal := wantT.Value()
+ outVal := out.(*Tensor).Value()
+ if !reflect.DeepEqual(outVal, wantVal) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
+ }
+ }
+ }
+ }
+}