path: root/tensorflow/go
author    Asim Shankar <ashankar@google.com>  2016-10-07 09:56:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-10-07 11:04:39 -0700
commit    b5b1671163c90cf786fe0c65ca6754a85d017847 (patch)
tree      313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go
parent    318948b28076540841f0c079f12c8f4fdd3d15f1 (diff)
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph. I also intend to use this API in a code generator that will generate Go source files containing functions for each OpDef (all of this generated code will live in a separate package).

While at it, also changed some tests to use the "sub-tests" feature in Go 1.7 (https://blog.golang.org/subtests).

Another step in the journey of #10
Change: 135493412
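For a sense of how the new API fits together, here is a minimal sketch, not part of this patch: the tf import path is an assumption, while NewGraph, NewTensor, OpSpec, AddOperation, NewSession, and Run are all taken from the code in this change.

    package main

    import (
    	"fmt"

    	tf "github.com/tensorflow/tensorflow/tensorflow/go"
    )

    func main() {
    	// Build a tiny graph entirely through the new Graph.AddOperation API.
    	g := tf.NewGraph()
    	t, err := tf.NewTensor(int64(3))
    	if err != nil {
    		panic(err)
    	}
    	c, err := g.AddOperation(tf.OpSpec{
    		Type:  "Const",
    		Name:  "c",
    		Attrs: map[string]interface{}{"dtype": t.DataType(), "value": t},
    	})
    	if err != nil {
    		panic(err)
    	}
    	// "Neg" takes the constant's first output as its single input.
    	neg, err := g.AddOperation(tf.OpSpec{
    		Type:  "Neg",
    		Input: []tf.Input{c.Output(0)},
    	})
    	if err != nil {
    		panic(err)
    	}
    	s, err := tf.NewSession(g, &tf.SessionOptions{})
    	if err != nil {
    		panic(err)
    	}
    	defer s.Close()
    	out, err := s.Run(nil, []tf.Output{neg.Output(0)}, nil)
    	if err != nil {
    		panic(err)
    	}
    	fmt.Println(out[0].Value()) // prints -3
    }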
Diffstat (limited to 'tensorflow/go')
-rw-r--r--  tensorflow/go/graph.go           | 160
-rw-r--r--  tensorflow/go/operation.go       |  63
-rw-r--r--  tensorflow/go/operation_test.go  |  40
-rw-r--r--  tensorflow/go/session_test.go    |  95
-rw-r--r--  tensorflow/go/util_test.go       |  45
5 files changed, 297 insertions(+), 106 deletions(-)
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index e119839218..c230595b15 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -109,5 +109,163 @@ func (g *Graph) Operation(name string) *Operation {
if cop == nil {
return nil
}
- return &Operation{cop,g}
+ return &Operation{cop, g}
+}
+
+// OpSpec is the specification of an Operation to be added to a Graph
+// (using Graph.AddOperation).
+type OpSpec struct {
+ // Type of the operation (e.g., "Add", "MatMul").
+ Type string
+
+ // Name by which the added operation will be referred to in the Graph.
+ // If omitted, defaults to Type.
+ Name string
+
+ // Inputs to this operation, which in turn must be outputs
+ // of other operations already added to the Graph.
+ //
+ // An operation may have multiple inputs with individual inputs being
+ // either a single tensor produced by another operation or a list of
+ // tensors produced by multiple operations. For example, the "Concat"
+ // operation takes two inputs: (1) the dimension along which to
+ // concatenate and (2) a list of tensors to concatenate. Thus, for
+ // Concat, len(Input) must be 2, with the first element being an Output
+ // and the second being an OutputList.
+ Input []Input
+
+ // Map from attribute name to its value that will be attached to this
+ // operation.
+ Attrs map[string]interface{}
+
+ // Other possible fields: Device, ColocateWith, ControlInputs.
+}
+
+// AddOperation adds an operation to g.
+func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
+ if args.Name == "" {
+ args.Name = args.Type
+ }
+ cname := C.CString(args.Name)
+ ctype := C.CString(args.Type)
+ cdesc := C.TF_NewOperation(g.c, ctype, cname)
+ C.free(unsafe.Pointer(cname))
+ C.free(unsafe.Pointer(ctype))
+
+ for _, in := range args.Input {
+ switch in := in.(type) {
+ case Output:
+ C.TF_AddInput(cdesc, in.c())
+ case OutputList:
+ size := len(in)
+ list := make([]C.TF_Port, size)
+ for i, v := range in {
+ list[i] = v.c()
+ }
+ C.TF_AddInputList(cdesc, &list[0], C.int(size))
+ }
+ }
+ status := newStatus()
+ for name, value := range args.Attrs {
+ if err := setAttr(cdesc, status, name, value); err != nil {
+ // Memory leak here as the TF_OperationDescription
+ // object will not be cleaned up. At the time of this
+ // writing, this was next to impossible since it
+ // required value to be a string tensor with
+ // incorrectly encoded strings. Given this rarity, live
+ // with the memory leak. If it becomes a real problem,
+ // consider adding a TF_DeleteOperationDescription
+ // function to the C API.
+ return nil, fmt.Errorf("%v (memory will be leaked)", err)
+ }
+ }
+ op := &Operation{
+ c: C.TF_FinishOperation(cdesc, status.c),
+ g: g,
+ }
+ return op, status.Err()
+}
+
+func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
+ cAttrName := C.CString(name)
+ defer C.free(unsafe.Pointer(cAttrName))
+ switch value := value.(type) {
+ case string:
+ cstr := C.CString(value)
+ C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.int(len(value)))
+ C.free(unsafe.Pointer(cstr))
+ case []string:
+ size := len(value)
+ list := make([]unsafe.Pointer, size)
+ lens := make([]C.int, size)
+ for i, s := range value {
+ list[i] = unsafe.Pointer(C.CString(s))
+ lens[i] = C.int(len(s))
+ }
+ C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size))
+ for _, s := range list {
+ C.free(s)
+ }
+ case int64:
+ C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value))
+ case []int64:
+ size := len(value)
+ list := make([]C.int64_t, size)
+ for i, v := range value {
+ list[i] = C.int64_t(v)
+ }
+ C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size))
+ case float32:
+ C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
+ case []float32:
+ size := len(value)
+ list := make([]C.float, size)
+ for i, v := range value {
+ list[i] = C.float(v)
+ }
+ C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size))
+ case bool:
+ v := C.uchar(0)
+ if value {
+ v = 1
+ }
+ C.TF_SetAttrBool(cdesc, cAttrName, v)
+ case []bool:
+ size := len(value)
+ list := make([]C.uchar, size)
+ for i, v := range value {
+ if v {
+ list[i] = 1
+ }
+ }
+ C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size))
+ case DataType:
+ C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
+ case []DataType:
+ list := (*C.TF_DataType)(&value[0])
+ C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
+ case *Tensor:
+ C.TF_SetAttrTensor(cdesc, cAttrName, value.c(), status.c)
+ if err := status.Err(); err != nil {
+ return fmt.Errorf("bad value for attribute %q: %v", name, err)
+ }
+ case []*Tensor:
+ size := len(value)
+ list := make([]*C.TF_Tensor, size)
+ for i, v := range value {
+ list[i] = v.c()
+ }
+ C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c)
+ if err := status.Err(); err != nil {
+ return fmt.Errorf("bad value for attribute %q: %v", name, err)
+ }
+ default:
+ // Shapes can be done, but will require that it be
+ // distinguishable from []int64. Which is fine, it
+ // probably makes sense to define a Shape type anyway,
+ // since that should handle partially known shapes as
+ // well and hide the special meaning of -1?
+ return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
+ }
+ return nil
}
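A quick note on usage (a hypothetical sketch, not in the patch, reusing the tf import alias assumed above; the attribute name "oops" is made up): setAttr handles only the Go types enumerated in its switch, so any other value type, for example a plain int instead of int64, makes AddOperation return an error.

    // attrDemo is illustrative only; "Placeholder" has no "oops" attribute.
    func attrDemo(g *tf.Graph, dt tf.DataType) error {
    	_, err := g.AddOperation(tf.OpSpec{
    		Type: "Placeholder",
    		Name: "p",
    		Attrs: map[string]interface{}{
    			"dtype": dt, // DataType values go through TF_SetAttrType.
    			"oops":  7,  // a plain int hits setAttr's default case.
    		},
    	})
    	// err is non-nil, roughly: attribute "oops" has a type (int) which is
    	// not valid for operation attributes (memory will be leaked)
    	return err
    }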
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 59c5c2a2b0..8f4fee1cbe 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -20,7 +20,6 @@ import "C"
import (
"errors"
- "unsafe"
)
// Operation that has been added to the graph.
@@ -47,8 +46,6 @@ func (op *Operation) NumOutputs() int {
}
// Output returns the i-th output of op.
-//
-// REQUIRES: 0 <= i < op.NumOutputs()
func (op *Operation) Output(i int) Output {
return Output{op, i}
}
@@ -97,52 +94,28 @@ func (p Output) Shape() (shape []int64, err error) {
return ret, nil
}
-func (p *Output) c() C.TF_Port {
+func (p Output) c() C.TF_Port {
return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)}
}
-// opBuilder is for use by the generated op code to create new Operations.
-// Build() must be called for any in-progress Operation, or else we leak.
-type opBuilder struct {
- c *C.TF_OperationDescription
- // A reference to the Graph to prevent it from being GCed while
- // the opBuilder is still alive.
- g *Graph
-}
-
-func newOpBuilder(g *Graph, typ string, name string) *opBuilder {
- opType := C.CString(typ)
- opName := C.CString(name)
- b := &opBuilder{c: C.TF_NewOperation(g.c, opType, opName), g: g}
- C.free(unsafe.Pointer(opType))
- C.free(unsafe.Pointer(opName))
- return b
-}
-
-func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error {
- status := newStatus()
- attrName := C.CString(name)
- C.TF_SetAttrTensor(b.c, attrName, t.c(), status.c)
- C.free(unsafe.Pointer(attrName))
- return status.Err()
-}
+func (p Output) canBeAnInput() {}
-func (b *opBuilder) SetAttrType(name string, typ DataType) {
- attrName := C.CString(name)
- C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ))
- C.free(unsafe.Pointer(attrName))
+// Input is the interface for specifying inputs to an operation being added to
+// a Graph.
+//
+// Operations can have multiple inputs, each of which could be either a tensor
+// produced by another operation (an Output object), or a list of tensors
+// produced by other operations (an OutputList). Thus, this interface is
+// implemented by both Output and OutputList.
+//
+// See OpSpec.Input for more information.
+type Input interface {
+ // Unexported to preclude implementations outside this package.
+ canBeAnInput()
}
-func (b *opBuilder) AddInput(port Output) {
- C.TF_AddInput(b.c, port.c())
-}
+// OutputList represents a list of Outputs that can be provided as input to
+// another operation.
+type OutputList []Output
-func (b *opBuilder) Build() (*Operation, error) {
- status := newStatus()
- op := &Operation{c: C.TF_FinishOperation(b.c, status.c), g: b.g}
- if err := status.Err(); err != nil {
- return nil, err
- }
- b.c = nil
- return op, nil
-}
+func (l OutputList) canBeAnInput() {}
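In other words, both Output and OutputList satisfy Input, so single tensors and tensor lists mix freely in OpSpec.Input; the Concat test further down does exactly this. Condensed sketch (dim, m1, m2 stand for Outputs obtained from earlier AddOperation or Const calls):

    spec := tf.OpSpec{
    	Type: "Concat",
    	// dim is a single tensor input; m1 and m2 together form one list input.
    	Input: []tf.Input{dim, tf.OutputList{m1, m2}},
    }

The unexported canBeAnInput method is what keeps the set of Input implementations closed to this package.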
diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go
index 35a5d95357..1f04528e1b 100644
--- a/tensorflow/go/operation_test.go
+++ b/tensorflow/go/operation_test.go
@@ -74,30 +74,24 @@ func TestOutputShape(t *testing.T) {
},
}
for idx, test := range testdata {
- tensor, err := NewTensor(test.Value)
- if err != nil {
- t.Errorf("#%d: NewTensor(%T) failed: %v", idx, test.Value, err)
- continue
- }
- c, err := Const(graph, fmt.Sprintf("test%d", idx), tensor)
- if err != nil {
- t.Errorf("#%d: Const(%T) failed: %v", idx, test.Value, err)
- continue
- }
- shape, err := c.Shape()
- if err != nil {
- t.Errorf("#%d: Shape() failed for %T: %v", idx, test.Value, err)
- continue
- }
- if got, want := len(shape), len(test.Shape); got != want {
- t.Errorf("#%d: %T: Got a shape with %d dimensions, want %d", idx, test.Value, got, want)
- continue
- }
- for i := 0; i < len(shape); i++ {
- if got, want := shape[i], test.Shape[i]; got != want {
- t.Errorf("#%d: %T: Got %d, want %d for dimension #%d/%d", idx, test.Value, got, want, i, len(shape))
+ t.Run(fmt.Sprintf("#%d Value %T", idx, test.Value), func(t *testing.T) {
+ c, err := Const(graph, fmt.Sprintf("const%d", idx), test.Value)
+ if err != nil {
+ t.Fatal(err)
}
- }
+ shape, err := c.Shape()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(shape), len(test.Shape); got != want {
+ t.Fatalf("Got a shape with %d dimensions, want %d", got, want)
+ }
+ for i := 0; i < len(shape); i++ {
+ if got, want := shape[i], test.Shape[i]; got != want {
+ t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(shape))
+ }
+ }
+ })
}
// Unknown number of dimensions
dummyTensor, err := NewTensor(float64(0))
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
index 6e823d1841..0d3660995b 100644
--- a/tensorflow/go/session_test.go
+++ b/tensorflow/go/session_test.go
@@ -15,6 +15,7 @@
package tensorflow
import (
+ "fmt"
"reflect"
"testing"
)
@@ -43,30 +44,78 @@ func TestSessionRunNeg(t *testing.T) {
}
for _, test := range tests {
- t1, err := NewTensor(test.input)
- if err != nil {
- t.Fatalf("NewTensor(%v): %v", test.input, err)
- }
- graph, inp, out := createTestGraph(t, t1.DataType())
- s, err := NewSession(graph, &SessionOptions{})
- if err != nil {
- t.Fatalf("NewSession() for %v: %v", test.input, err)
- }
- output, err := s.Run(map[Output]*Tensor{inp: t1}, []Output{out}, []*Operation{out.Op})
- if err != nil {
- t.Fatalf("Run() for %v: %v", test.input, err)
- }
- if len(output) != 1 {
- t.Errorf("%v: got %d outputs, want 1", test.input, len(output))
- continue
- }
- val := output[0].Value()
- if !reflect.DeepEqual(test.expected, val) {
- t.Errorf("got %v, want %v", val, test.expected)
- }
- if err := s.Close(); err != nil {
- t.Errorf("Close(): %v", err)
+ t.Run(fmt.Sprint(test.input), func(t *testing.T) {
+ t1, err := NewTensor(test.input)
+ if err != nil {
+ t.Fatal(err)
+ }
+ graph, inp, out := createTestGraph(t, t1.DataType())
+ s, err := NewSession(graph, &SessionOptions{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ output, err := s.Run(map[Output]*Tensor{inp: t1}, []Output{out}, []*Operation{out.Op})
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(output) != 1 {
+ t.Fatalf("got %d outputs, want 1", len(output))
+ }
+ val := output[0].Value()
+ if !reflect.DeepEqual(test.expected, val) {
+ t.Errorf("got %v, want %v", val, test.expected)
+ }
+ if err := s.Close(); err != nil {
+ t.Error(err)
+ }
+ })
+ }
+}
+
+func TestSessionRunConcat(t *testing.T) {
+ // Runs the Concat operation on two matrices: m1 and m2, along the
+ // first dimension (dim1).
+ // This tests the use of both Output and OutputList as inputs to the
+ // Concat operation.
+ var (
+ g = NewGraph()
+ dim1, _ = Const(g, "dim1", int32(1))
+ m1, _ = Const(g, "m1", [][]int64{
+ {1, 2, 3},
+ {4, 5, 6},
+ })
+ m2, _ = Const(g, "m2", [][]int64{
+ {7, 8, 9},
+ {10, 11, 12},
+ })
+ want = [][]int64{
+ {1, 2, 3, 7, 8, 9},
+ {4, 5, 6, 10, 11, 12},
}
+ )
+ concat, err := g.AddOperation(OpSpec{
+ Type: "Concat",
+ Input: []Input{
+ dim1,
+ OutputList{m1, m2},
+ },
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ s, err := NewSession(g, &SessionOptions{})
+ if err != nil {
+ t.Fatal(err)
+ }
+ output, err := s.Run(nil, []Output{concat.Output(0)}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(output) != 1 {
+ t.Fatal(len(output))
+ }
+ if got := output[0].Value(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("Got %v, want %v", got, want)
}
}
diff --git a/tensorflow/go/util_test.go b/tensorflow/go/util_test.go
index 06b4e61d0f..8ab365c656 100644
--- a/tensorflow/go/util_test.go
+++ b/tensorflow/go/util_test.go
@@ -15,23 +15,40 @@
package tensorflow
func Placeholder(g *Graph, name string, dt DataType) (Output, error) {
- b := newOpBuilder(g, "Placeholder", name)
- b.SetAttrType("dtype", dt)
- op, err := b.Build()
- return Output{op, 0}, err
+ op, err := g.AddOperation(OpSpec{
+ Type: "Placeholder",
+ Name: name,
+ Attrs: map[string]interface{}{
+ "dtype": dt,
+ },
+ })
+ return op.Output(0), err
}
-func Const(g *Graph, name string, t *Tensor) (Output, error) {
- b := newOpBuilder(g, "Const", name)
- b.SetAttrType("dtype", t.DataType())
- b.SetAttrTensor("value", t)
- op, err := b.Build()
- return Output{op, 0}, err
+func Const(g *Graph, name string, value interface{}) (Output, error) {
+ t, ok := value.(*Tensor)
+ if !ok {
+ var err error
+ if t, err = NewTensor(value); err != nil {
+ return Output{}, err
+ }
+ }
+ op, err := g.AddOperation(OpSpec{
+ Type: "Const",
+ Name: name,
+ Attrs: map[string]interface{}{
+ "dtype": t.DataType(),
+ "value": t,
+ },
+ })
+ return op.Output(0), err
}
func Neg(g *Graph, name string, port Output) (Output, error) {
- b := newOpBuilder(g, "Neg", name)
- b.AddInput(port)
- op, err := b.Build()
- return Output{op, 0}, err
+ op, err := g.AddOperation(OpSpec{
+ Type: "Neg",
+ Name: name,
+ Input: []Input{port},
+ })
+ return op.Output(0), err
}
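Finally, for context on the code generator mentioned in the commit message: a generated per-OpDef function would presumably be a thin wrapper over Graph.AddOperation, along these lines. This is purely illustrative; no such generator or wrapper is part of this change, and only transpose_a/transpose_b are taken from the public MatMul OpDef.

    // MatMul is what a generated wrapper around the "MatMul" OpDef might look like.
    func MatMul(g *tf.Graph, name string, a, b tf.Output, transposeA, transposeB bool) (tf.Output, error) {
    	op, err := g.AddOperation(tf.OpSpec{
    		Type:  "MatMul",
    		Name:  name,
    		Input: []tf.Input{a, b},
    		Attrs: map[string]interface{}{
    			"transpose_a": transposeA,
    			"transpose_b": transposeB,
    		},
    	})
    	if err != nil {
    		return tf.Output{}, err
    	}
    	return op.Output(0), nil
    }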