diff options
Diffstat (limited to 'tensorflow/go/attrs_test.go')
-rw-r--r-- | tensorflow/go/attrs_test.go | 193 |
1 files changed, 193 insertions, 0 deletions
diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go new file mode 100644 index 0000000000..ea8af221ae --- /dev/null +++ b/tensorflow/go/attrs_test.go @@ -0,0 +1,193 @@ +/* +Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package tensorflow + +import ( + "fmt" + "reflect" + "testing" +) + +func TestOperationAttrs(t *testing.T) { + g := NewGraph() + + i := 0 + makeConst := func(v interface{}) Output { + op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v) + i++ + if err != nil { + t.Fatal(err) + } + return op + } + + makeTensor := func(v interface{}) *Tensor { + tensor, err := NewTensor(v) + if err != nil { + t.Fatal(err) + } + return tensor + } + + cases := []OpSpec{ + { + Name: "type", + Type: "Placeholder", + Attrs: map[string]interface{}{ + "dtype": Float, + }, + }, + { + Name: "list(float)", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{1, 2, 3, 4}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32{0, 1, 2, 3, 4, 5}, + }, + }, + { + Name: "list(float) empty", + Type: "Bucketize", + Input: []Input{ + makeConst([]float32{}), + }, + Attrs: map[string]interface{}{ + "boundaries": []float32(nil), + }, + }, + /* TODO(ashankar): debug this issue and add it back later. + { + Name: "list(type),list(shape)", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst(float32(1)), + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Float, Int32}, + "shapes": []Shape{ScalarShape(), MakeShape(1, 1)}, + }, + }, + { + Name: "list(type),list(shape) empty", + Type: "InfeedEnqueueTuple", + Input: []Input{ + OutputList([]Output{ + makeConst([][]int32{{2}}), + }), + }, + Attrs: map[string]interface{}{ + "dtypes": []DataType{Int32}, + "shapes": []Shape(nil), + }, + }, + { + Name: "list(type) empty,string empty,int", + Type: "_XlaSendFromHost", + Input: []Input{ + OutputList([]Output{}), + makeConst(""), + }, + Attrs: map[string]interface{}{ + "Tinputs": []DataType(nil), + "key": "", + "device_ordinal": int64(0), + }, + }, + */ + { + Name: "list(int),int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": []int64{1, 2}, + }, + }, + { + Name: "list(int) empty,int", + Type: "StringToHashBucketStrong", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "num_buckets": int64(2), + "key": ([]int64)(nil), + }, + }, + { + Name: "list(string),type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": []string{"foo", "bar"}, + }, + }, + { + Name: "list(string) empty,type", + Type: "TensorSummary", + Input: []Input{ + makeConst(""), + }, + Attrs: map[string]interface{}{ + "T": String, + "labels": ([]string)(nil), + }, + }, + { + Name: "tensor", + Type: "Const", + Attrs: map[string]interface{}{ + "dtype": String, + "value": makeTensor("foo"), + }, + }, + } + + for i, spec := range cases { + op, err := g.AddOperation(spec) + if err != nil { + t.Fatal(err) + } + for key, want := range spec.Attrs { + out, err := op.Attr(key) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(out, want) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want) + } + wantT, ok := want.(*Tensor) + if ok { + wantVal := wantT.Value() + outVal := out.(*Tensor).Value() + if !reflect.DeepEqual(outVal, wantVal) { + t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal) + } + } + } + } +} |