path: root/tensorflow/go
author     Mingxing Tan <tanmingxing@google.com>  2018-06-28 19:13:20 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-28 19:16:41 -0700
commit     1e7b0e4ad6d0f57f3241fe0b80a65f2c2a7f11b0 (patch)
tree       af92d172cedfc41e544c01a349c1d3b30bc3ff85 /tensorflow/go
parent     3cee10e61c1c90734317c62ea3388ec44acc8d08 (diff)
Merge changes from github.
PiperOrigin-RevId: 202585094
Diffstat (limited to 'tensorflow/go')
-rw-r--r--  tensorflow/go/attrs.go           | 245
-rw-r--r--  tensorflow/go/attrs_test.go      | 193
-rw-r--r--  tensorflow/go/op/wrappers.go     |   9
-rw-r--r--  tensorflow/go/operation.go       |  66
-rw-r--r--  tensorflow/go/operation_test.go  |  62
5 files changed, 571 insertions, 4 deletions
diff --git a/tensorflow/go/attrs.go b/tensorflow/go/attrs.go
new file mode 100644
index 0000000000..f86c5737bc
--- /dev/null
+++ b/tensorflow/go/attrs.go
@@ -0,0 +1,245 @@
+/*
+Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package tensorflow
+
+// #include <stdlib.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+import (
+ "fmt"
+ "unsafe"
+)
+
+// makeCShape converts a shape specified in C.int64_t into a Shape.
+func makeCShape(shape []C.int64_t) Shape {
+ s := Shape{dims: make([]int64, len(shape))}
+ for i, n := range shape {
+ s.dims[i] = int64(n)
+ }
+ return s
+}
+
+// Attr returns the value of an attribute on op. It returns an error if the
+// attribute does not exist.
+func (op *Operation) Attr(name string) (interface{}, error) {
+ cname := C.CString(name)
+ defer C.free(unsafe.Pointer(cname))
+
+ status := newStatus()
+ meta := C.TF_OperationGetAttrMetadata(op.c, cname, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+
+ if meta.is_list == 1 {
+ return listAttribute(op, cname, meta)
+ }
+ return scalarAttribute(op, cname, meta)
+}
+
+func listAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
+ status := newStatus()
+
+ switch meta._type {
+ case C.TF_ATTR_STRING:
+ if meta.list_size == 0 {
+ return []string(nil), nil
+ }
+ values := make([]unsafe.Pointer, meta.list_size)
+ lengths := make([]C.size_t, meta.list_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.char, meta.total_size+1)
+ C.TF_OperationGetAttrStringList(op.c, cname, &values[0], &lengths[0], C.int(meta.list_size), unsafe.Pointer(&storage[0]), C.size_t(meta.total_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ list := make([]string, meta.list_size)
+ for i, val := range values {
+ length := lengths[i]
+ list[i] = C.GoStringN((*C.char)(val), C.int(length))
+ }
+ return list, nil
+
+ case C.TF_ATTR_INT:
+ if meta.list_size == 0 {
+ return []int64(nil), nil
+ }
+ list := make([]C.int64_t, meta.list_size)
+ C.TF_OperationGetAttrIntList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]int64, meta.list_size)
+ for i, val := range list {
+ vals[i] = int64(val)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_FLOAT:
+ if meta.list_size == 0 {
+ return []float32(nil), nil
+ }
+ list := make([]C.float, meta.list_size)
+ C.TF_OperationGetAttrFloatList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]float32, meta.list_size)
+ for i, val := range list {
+ vals[i] = float32(val)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_BOOL:
+ if meta.list_size == 0 {
+ return []bool(nil), nil
+ }
+ list := make([]C.uchar, meta.list_size)
+ C.TF_OperationGetAttrBoolList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]bool, meta.list_size)
+ for i, val := range list {
+ vals[i] = val == 1
+ }
+ return vals, nil
+
+ case C.TF_ATTR_TYPE:
+ if meta.list_size == 0 {
+ return []DataType(nil), nil
+ }
+ list := make([]C.TF_DataType, meta.list_size)
+ C.TF_OperationGetAttrTypeList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]DataType, meta.list_size)
+ for i, val := range list {
+ vals[i] = DataType(val)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_TENSOR:
+ if meta.list_size == 0 {
+ return []*Tensor(nil), nil
+ }
+ list := make([]*C.TF_Tensor, meta.list_size)
+ C.TF_OperationGetAttrTensorList(op.c, cname, &list[0], C.int(meta.list_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ vals := make([]*Tensor, meta.list_size)
+ for i, t := range list {
+ vals[i] = newTensorFromC(t)
+ }
+ return vals, nil
+
+ case C.TF_ATTR_SHAPE:
+ if meta.list_size == 0 {
+ return []Shape(nil), nil
+ }
+ dims := make([]*C.int64_t, meta.list_size)
+ numDims := make([]C.int, meta.list_size)
+ // Add one element in case total_size is zero.
+ storage := make([]C.int64_t, meta.total_size+1)
+ C.TF_OperationGetAttrShapeList(op.c, cname, &dims[0], &numDims[0], C.int(meta.list_size), &storage[0], C.int(meta.total_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ list := make([]Shape, meta.list_size)
+ for i, dim := range dims {
+ numDim := numDims[i]
+ // If the number of dimensions is unknown, default to empty shape.
+ if numDim < 0 {
+ continue
+ }
+ // A []C.int64_t slice backed by C memory.
+ // See: https://github.com/golang/go/wiki/cgo#turning-c-arrays-into-go-slices
+ slice := (*[1 << 30]C.int64_t)(unsafe.Pointer(dim))[:numDim:numDim]
+ list[i] = makeCShape(slice)
+ }
+ return list, nil
+
+ default:
+ return nil, fmt.Errorf("list type %v not supported", meta._type)
+ }
+}
+
+func scalarAttribute(op *Operation, cname *C.char, meta C.TF_AttrMetadata) (interface{}, error) {
+ status := newStatus()
+
+ switch meta._type {
+ case C.TF_ATTR_STRING:
+ if meta.total_size == 0 {
+ return "", nil
+ }
+ v := make([]C.char, meta.total_size)
+ C.TF_OperationGetAttrString(op.c, cname, unsafe.Pointer(&v[0]), C.size_t(meta.total_size), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ return C.GoStringN(&v[0], C.int(meta.total_size)), nil
+
+ case C.TF_ATTR_INT:
+ var v C.int64_t
+ C.TF_OperationGetAttrInt(op.c, cname, &v, status.c)
+ return int64(v), status.Err()
+
+ case C.TF_ATTR_FLOAT:
+ var v C.float
+ C.TF_OperationGetAttrFloat(op.c, cname, &v, status.c)
+ return float32(v), status.Err()
+
+ case C.TF_ATTR_BOOL:
+ var v C.uchar
+ C.TF_OperationGetAttrBool(op.c, cname, &v, status.c)
+ return v == 1, status.Err()
+
+ case C.TF_ATTR_TYPE:
+ var v C.TF_DataType
+ C.TF_OperationGetAttrType(op.c, cname, &v, status.c)
+ return DataType(v), status.Err()
+
+ case C.TF_ATTR_TENSOR:
+ var v *C.TF_Tensor
+ C.TF_OperationGetAttrTensor(op.c, cname, &v, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ return newTensorFromC(v), nil
+
+ case C.TF_ATTR_SHAPE:
+ numDims := meta.total_size
+ // If number of dims is unknown return empty shape to indicate that.
+ if numDims < 0 {
+ return Shape{}, nil
+ }
+ if numDims == 0 {
+ return ScalarShape(), nil
+ }
+ dims := make([]C.int64_t, numDims)
+ C.TF_OperationGetAttrShape(op.c, cname, (*C.int64_t)(unsafe.Pointer(&dims[0])), C.int(numDims), status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ return makeCShape(dims), nil
+
+ default:
+ return nil, fmt.Errorf("type %v not supported", meta._type)
+ }
+}
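
The new attrs.go above exposes operation attributes to Go callers via Operation.Attr, which returns an interface{} holding the concrete type for each attribute kind (scalar or list). As an illustrative aside, not part of this commit, here is a minimal sketch of how the accessor might be used; it assumes the standard Go bindings import paths and that the scope names the op "Placeholder":

package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	// Build a small graph containing a single Placeholder with a fixed
	// dtype and shape attribute.
	s := op.NewScope()
	op.Placeholder(s, tf.Float, op.PlaceholderShape(tf.MakeShape(2, 3)))
	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}

	// The scope is assumed to name the first Placeholder op "Placeholder".
	p := graph.Operation("Placeholder")

	// Scalar attribute: comes back as a DataType.
	dtype, err := p.Attr("dtype")
	if err != nil {
		log.Fatal(err)
	}
	// Shape attribute: comes back as a Shape.
	shape, err := p.Attr("shape")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(dtype.(tf.DataType), shape.(tf.Shape))
}

In real code the caller would switch on the returned type rather than assert blindly, since Attr hands back []string, []int64, Shape, *Tensor, and so on depending on the attribute.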
diff --git a/tensorflow/go/attrs_test.go b/tensorflow/go/attrs_test.go
new file mode 100644
index 0000000000..ea8af221ae
--- /dev/null
+++ b/tensorflow/go/attrs_test.go
@@ -0,0 +1,193 @@
+/*
+Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package tensorflow
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+)
+
+func TestOperationAttrs(t *testing.T) {
+ g := NewGraph()
+
+ i := 0
+ makeConst := func(v interface{}) Output {
+ op, err := Const(g, fmt.Sprintf("const/%d/%+v", i, v), v)
+ i++
+ if err != nil {
+ t.Fatal(err)
+ }
+ return op
+ }
+
+ makeTensor := func(v interface{}) *Tensor {
+ tensor, err := NewTensor(v)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return tensor
+ }
+
+ cases := []OpSpec{
+ {
+ Name: "type",
+ Type: "Placeholder",
+ Attrs: map[string]interface{}{
+ "dtype": Float,
+ },
+ },
+ {
+ Name: "list(float)",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{1, 2, 3, 4}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32{0, 1, 2, 3, 4, 5},
+ },
+ },
+ {
+ Name: "list(float) empty",
+ Type: "Bucketize",
+ Input: []Input{
+ makeConst([]float32{}),
+ },
+ Attrs: map[string]interface{}{
+ "boundaries": []float32(nil),
+ },
+ },
+ /* TODO(ashankar): debug this issue and add it back later.
+ {
+ Name: "list(type),list(shape)",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst(float32(1)),
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Float, Int32},
+ "shapes": []Shape{ScalarShape(), MakeShape(1, 1)},
+ },
+ },
+ {
+ Name: "list(type),list(shape) empty",
+ Type: "InfeedEnqueueTuple",
+ Input: []Input{
+ OutputList([]Output{
+ makeConst([][]int32{{2}}),
+ }),
+ },
+ Attrs: map[string]interface{}{
+ "dtypes": []DataType{Int32},
+ "shapes": []Shape(nil),
+ },
+ },
+ {
+ Name: "list(type) empty,string empty,int",
+ Type: "_XlaSendFromHost",
+ Input: []Input{
+ OutputList([]Output{}),
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "Tinputs": []DataType(nil),
+ "key": "",
+ "device_ordinal": int64(0),
+ },
+ },
+ */
+ {
+ Name: "list(int),int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": []int64{1, 2},
+ },
+ },
+ {
+ Name: "list(int) empty,int",
+ Type: "StringToHashBucketStrong",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "num_buckets": int64(2),
+ "key": ([]int64)(nil),
+ },
+ },
+ {
+ Name: "list(string),type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": []string{"foo", "bar"},
+ },
+ },
+ {
+ Name: "list(string) empty,type",
+ Type: "TensorSummary",
+ Input: []Input{
+ makeConst(""),
+ },
+ Attrs: map[string]interface{}{
+ "T": String,
+ "labels": ([]string)(nil),
+ },
+ },
+ {
+ Name: "tensor",
+ Type: "Const",
+ Attrs: map[string]interface{}{
+ "dtype": String,
+ "value": makeTensor("foo"),
+ },
+ },
+ }
+
+ for i, spec := range cases {
+ op, err := g.AddOperation(spec)
+ if err != nil {
+ t.Fatal(err)
+ }
+ for key, want := range spec.Attrs {
+ out, err := op.Attr(key)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(out, want) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, out, want)
+ }
+ wantT, ok := want.(*Tensor)
+ if ok {
+ wantVal := wantT.Value()
+ outVal := out.(*Tensor).Value()
+ if !reflect.DeepEqual(outVal, wantVal) {
+ t.Fatalf("%d. %q: Got %#v, wanted %#v", i, key, outVal, wantVal)
+ }
+ }
+ }
+ }
+}
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index b2dbdafc5f..6d9cb7c6ec 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -11210,7 +11210,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
-// supplied image within in this range.
+// supplied image within this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
@@ -17969,9 +17969,10 @@ func SparseFillEmptyRowsGrad(scope *Scope, reverse_index_map tf.Output, grad_val
}
// Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
-//
// if < 0, `scale * features` otherwise.
//
+// Assumes weights to have zero mean and variance 1.0 / fan_in.
+//
// See [Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515)
func Selu(scope *Scope, features tf.Output) (activations tf.Output) {
if scope.Err() != nil {
@@ -21655,7 +21656,7 @@ func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
// generated sequentially as '*tag*/image/0', '*tag*/image/1', etc.
//
// The `bad_color` argument is the color to use in the generated images for
-// non-finite input values. It is a `unit8` 1-D tensor of length `channels`.
+// non-finite input values. It is a `uint8` 1-D tensor of length `channels`.
// Each element must be in the range `[0, 255]` (It represents the value of a
// pixel in the output image). Non-finite values in the input tensor are
// replaced by this tensor in the output image. The default value is the color
@@ -24048,7 +24049,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
// SampleDistortedBoundingBoxV2AreaRange sets the optional area_range attribute to value.
//
// value: The cropped area of the image must contain a fraction of the
-// supplied image within in this range.
+// supplied image within this range.
// If not specified, defaults to <f:0.05 f:1 >
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
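
The Selu doc comment touched in the hunk above spells out the activation as `scale * alpha * (exp(features) - 1)` for negative inputs and `scale * features` otherwise. As a purely illustrative aside (not part of this commit), a plain-Go sketch of that formula, using the alpha/scale constants from the cited Self-Normalizing Neural Networks paper:

package main

import (
	"fmt"
	"math"
)

// selu mirrors the documented formula: scale*alpha*(exp(x)-1) for x < 0,
// and scale*x otherwise.
func selu(xs []float64) []float64 {
	const (
		alpha = 1.6732632423543772 // constants from Klambauer et al. (2017)
		scale = 1.0507009873554805
	)
	out := make([]float64, len(xs))
	for i, x := range xs {
		if x < 0 {
			out[i] = scale * alpha * (math.Exp(x) - 1)
		} else {
			out[i] = scale * x
		}
	}
	return out
}

func main() {
	// Negative inputs saturate toward -scale*alpha ≈ -1.7581; non-negative
	// inputs are simply scaled.
	fmt.Println(selu([]float64{-2, -1, 0, 1, 2}))
}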
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 8fcad61f4c..25ec718703 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output {
return Output{op, i}
}
+// NumInputs returns the number of inputs of op.
+func (op *Operation) NumInputs() int {
+ return int(C.TF_OperationNumInputs(op.c))
+}
+
// Output represents one of the outputs of an operation in the graph. Has a
// DataType (and eventually a Shape). May be passed as an input argument to a
// function for adding operations to a graph, or to a Session's Run() method to
@@ -123,6 +128,67 @@ func (p Output) c() C.TF_Output {
func (p Output) canBeAnInput() {}
+// Consumers returns the inputs that consume this output.
+func (p Output) Consumers() []Consumer {
+ max := int(C.TF_OperationOutputNumConsumers(p.c()))
+ if max == 0 {
+ return nil
+ }
+ inputs := make([]C.TF_Input, max)
+ n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
+ inputs = inputs[:int(n)]
+
+ var consumers []Consumer
+ for _, consumer := range inputs {
+ consumers = append(consumers, Consumer{
+ Index: int(consumer.index),
+ Op: &Operation{
+ c: consumer.oper,
+ g: p.Op.g,
+ },
+ })
+ }
+
+ return consumers
+}
+
+// Consumer identifies a specific input of an operation that consumes the output
+// of another operation.
+type Consumer struct {
+ // Op is the Operation that is consuming the output of another operation.
+ Op *Operation
+
+ // Index is the index of the input within Op that the output of another
+ // operation is connected to.
+ Index int
+}
+
+func (p Consumer) c() C.TF_Input {
+ if p.Op == nil {
+ // Attempt to provide a more useful panic message than "nil
+ // pointer dereference".
+ panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers")
+ }
+ return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)}
+}
+
+// DataType returns the type of the input.
+func (p Consumer) DataType() DataType {
+ return DataType(C.TF_OperationInputType(p.c()))
+}
+
+// Producer returns the Output that is connected to this Consumer.
+func (p Consumer) Producer() Output {
+ output := C.TF_OperationInput(p.c())
+ return Output{
+ Op: &Operation{
+ c: output.oper,
+ g: p.Op.g,
+ },
+ Index: int(output.index),
+ }
+}
+
// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//
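
The operation.go additions above introduce NumInputs on Operation and the Consumer type reached through Output.Consumers. As an illustrative aside, not part of this commit, a small sketch exercising them together, assuming the standard Go bindings import paths:

package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	// x feeds two Neg operations, so the output x has two consumers.
	s := op.NewScope()
	x := op.Placeholder(s, tf.Float)
	op.Neg(s.SubScope("a"), x)
	op.Neg(s.SubScope("b"), x)
	if _, err := s.Finalize(); err != nil {
		log.Fatal(err)
	}

	for _, c := range x.Consumers() {
		// A Consumer pairs an operation with the index of the input that
		// is wired to x; Producer walks back to x itself.
		fmt.Printf("input %d of %s is fed by %s (consumer has %d inputs)\n",
			c.Index, c.Op.Name(), c.Producer().Op.Name(), c.Op.NumInputs())
	}
}

Consumers returns nil when an output feeds nothing, matching the behaviour asserted for b in the test diff below.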
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 40c951ab8c..06b65bdfb7 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -166,6 +166,68 @@ func TestOutputDataTypeAndShape(t *testing.T) {
}
}
+func TestOperationInputs(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ y, err := Placeholder(g, "y", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ add, err := Add(g, "add", x, y)
+ if err != nil {
+ t.Fatal(err)
+ }
+ addOp := add.Op
+
+ if out := addOp.NumInputs(); out != 2 {
+ t.Fatalf("Got %d inputs, wanted 2", out)
+ }
+}
+
+func TestOperationConsumers(t *testing.T) {
+ g := NewGraph()
+ x, err := Placeholder(g, "x", Float)
+ if err != nil {
+ t.Fatal(err)
+ }
+ a, err := Neg(g, "a", x)
+ if err != nil {
+ t.Fatal(err)
+ }
+ b, err := Neg(g, "b", x)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ consumers := []*Operation{a.Op, b.Op}
+
+ xConsumers := x.Consumers()
+ if out := len(xConsumers); out != 2 {
+ t.Fatalf("Got %d consumers, wanted 2", out)
+ }
+
+ for i, consumer := range xConsumers {
+ got := consumer.Op.Name()
+ want := consumers[i].Name()
+ if got != want {
+ t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
+ }
+
+ got = consumer.Producer().Op.Name()
+ want = x.Op.Name()
+ if got != want {
+ t.Fatalf("%d. Got op name %q, wanted %q", i, got, want)
+ }
+ }
+
+ if len(b.Consumers()) != 0 {
+ t.Fatalf("expected %+v to have no consumers", b)
+ }
+}
+
func forceGC() {
var mem runtime.MemStats
runtime.ReadMemStats(&mem)