diff options
-rw-r--r-- | tensorflow/go/graph.go | 160 | ||||
-rw-r--r-- | tensorflow/go/operation.go | 63 | ||||
-rw-r--r-- | tensorflow/go/operation_test.go | 40 | ||||
-rw-r--r-- | tensorflow/go/session_test.go | 95 | ||||
-rw-r--r-- | tensorflow/go/util_test.go | 45 |
5 files changed, 297 insertions, 106 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index e119839218..c230595b15 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -109,5 +109,163 @@ func (g *Graph) Operation(name string) *Operation { if cop == nil { return nil } - return &Operation{cop,g} + return &Operation{cop, g} +} + +// OpSpec is the specification of an Operation to be added to a Graph +// (using Graph.AddOperation). +type OpSpec struct { + // Type of the operation (e.g., "Add", "MatMul"). + Type string + + // Name by which the added operation will be referred to in the Graph. + // If omitted, defaults to Type. + Name string + + // Inputs to this operation, which in turn must be outputs + // of other operations already added to the Graph. + // + // An operation may have multiple inputs with individual inputs being + // either a single tensor produced by another operation or a list of + // tensors produced by multiple operations. For example, the "Concat" + // operation takes two inputs: (1) the dimension along which to + // concatenate and (2) a list of tensors to concatenate. Thus, for + // Concat, len(Input) must be 2, with the first element being an Output + // and the second being an OutputList. + Input []Input + + // Map from attribute name to its value that will be attached to this + // operation. + Attrs map[string]interface{} + + // Other possible fields: Device, ColocateWith, ControlInputs. +} + +// AddOperation adds an operation to g. +func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { + if args.Name == "" { + args.Name = args.Type + } + cname := C.CString(args.Name) + ctype := C.CString(args.Type) + cdesc := C.TF_NewOperation(g.c, ctype, cname) + C.free(unsafe.Pointer(cname)) + C.free(unsafe.Pointer(ctype)) + + for _, in := range args.Input { + switch in := in.(type) { + case Output: + C.TF_AddInput(cdesc, in.c()) + case OutputList: + size := len(in) + list := make([]C.TF_Port, size) + for i, v := range in { + list[i] = v.c() + } + C.TF_AddInputList(cdesc, &list[0], C.int(size)) + } + } + status := newStatus() + for name, value := range args.Attrs { + if err := setAttr(cdesc, status, name, value); err != nil { + // Memory leak here as the TF_OperationDescription + // object will not be cleaned up. At the time of this + // writing, this was next to impossible since it + // required value to be a string tensor with + // incorrectly encoded strings. Given this rarity, live + // with the memory leak. If it becomes a real problem, + // consider adding a TF_DeleteOperationDescription + // function to the C API. + return nil, fmt.Errorf("%v (memory will be leaked)", err) + } + } + op := &Operation{ + c: C.TF_FinishOperation(cdesc, status.c), + g: g, + } + return op, status.Err() +} + +func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { + cAttrName := C.CString(name) + defer C.free(unsafe.Pointer(cAttrName)) + switch value := value.(type) { + case string: + cstr := C.CString(value) + C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.int(len(value))) + C.free(unsafe.Pointer(cstr)) + case []string: + size := len(value) + list := make([]unsafe.Pointer, size) + lens := make([]C.int, size) + for i, s := range value { + list[i] = unsafe.Pointer(C.CString(s)) + lens[i] = C.int(len(s)) + } + C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) + for _, s := range list { + C.free(s) + } + case int64: + C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) + case []int64: + size := len(value) + list := make([]C.int64_t, size) + for i, v := range value { + list[i] = C.int64_t(v) + } + C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) + case float32: + C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) + case []float32: + size := len(value) + list := make([]C.float, size) + for i, v := range value { + list[i] = C.float(v) + } + C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) + case bool: + v := C.uchar(0) + if value { + v = 1 + } + C.TF_SetAttrBool(cdesc, cAttrName, v) + case []bool: + size := len(value) + list := make([]C.uchar, size) + for i, v := range value { + if v { + list[i] = 1 + } + } + C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) + case DataType: + C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) + case []DataType: + list := (*C.TF_DataType)(&value[0]) + C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) + case *Tensor: + C.TF_SetAttrTensor(cdesc, cAttrName, value.c(), status.c) + if err := status.Err(); err != nil { + return fmt.Errorf("bad value for attribute %q: %v", name, err) + } + case []*Tensor: + size := len(value) + list := make([]*C.TF_Tensor, size) + for i, v := range value { + list[i] = v.c() + } + C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c) + if err := status.Err(); err != nil { + return fmt.Errorf("bad value for attribute %q: %v", name, err) + } + default: + // Shapes can be done, but will require that it be + // distinguishable from []int64. Which is fine, it + // probably makes sense to define a Shape type anyway, + // since that should handle partially known shapes as + // well and hide the special meaning of -1? + return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) + } + return nil } diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index 59c5c2a2b0..8f4fee1cbe 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -20,7 +20,6 @@ import "C" import ( "errors" - "unsafe" ) // Operation that has been added to the graph. @@ -47,8 +46,6 @@ func (op *Operation) NumOutputs() int { } // Output returns the i-th output of op. -// -// REQUIRES: 0 <= i < op.NumOutputs() func (op *Operation) Output(i int) Output { return Output{op, i} } @@ -97,52 +94,28 @@ func (p Output) Shape() (shape []int64, err error) { return ret, nil } -func (p *Output) c() C.TF_Port { +func (p Output) c() C.TF_Port { return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)} } -// opBuilder is for use by the generated op code to create new Operations. -// Build() must be called for any in-progress Operation, or else we leak. -type opBuilder struct { - c *C.TF_OperationDescription - // A reference to the Graph to prevent it from being GCed while - // the opBuilder is still alive. - g *Graph -} - -func newOpBuilder(g *Graph, typ string, name string) *opBuilder { - opType := C.CString(typ) - opName := C.CString(name) - b := &opBuilder{c: C.TF_NewOperation(g.c, opType, opName), g: g} - C.free(unsafe.Pointer(opType)) - C.free(unsafe.Pointer(opName)) - return b -} - -func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error { - status := newStatus() - attrName := C.CString(name) - C.TF_SetAttrTensor(b.c, attrName, t.c(), status.c) - C.free(unsafe.Pointer(attrName)) - return status.Err() -} +func (p Output) canBeAnInput() {} -func (b *opBuilder) SetAttrType(name string, typ DataType) { - attrName := C.CString(name) - C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ)) - C.free(unsafe.Pointer(attrName)) +// Input is the interface for specifying inputs to an operation being added to +// a Graph. +// +// Operations can have multiple inputs, each of which could be either a tensor +// produced by another operation (an Output object), or a list of tensors +// produced by other operations (an OutputList). Thus, this interface is +// implemented by both Output and OutputList. +// +// See OpSpec.Input for more information. +type Input interface { + // Unexported to preclude implementations outside this package. + canBeAnInput() } -func (b *opBuilder) AddInput(port Output) { - C.TF_AddInput(b.c, port.c()) -} +// OutputList represents a list of Outputs that can be provided as input to +// another operation. +type OutputList []Output -func (b *opBuilder) Build() (*Operation, error) { - status := newStatus() - op := &Operation{c: C.TF_FinishOperation(b.c, status.c), g: b.g} - if err := status.Err(); err != nil { - return nil, err - } - b.c = nil - return op, nil -} +func (l OutputList) canBeAnInput() {} diff --git a/tensorflow/go/operation_test.go b/tensorflow/go/operation_test.go index 35a5d95357..1f04528e1b 100644 --- a/tensorflow/go/operation_test.go +++ b/tensorflow/go/operation_test.go @@ -74,30 +74,24 @@ func TestOutputShape(t *testing.T) { }, } for idx, test := range testdata { - tensor, err := NewTensor(test.Value) - if err != nil { - t.Errorf("#%d: NewTensor(%T) failed: %v", idx, test.Value, err) - continue - } - c, err := Const(graph, fmt.Sprintf("test%d", idx), tensor) - if err != nil { - t.Errorf("#%d: Const(%T) failed: %v", idx, test.Value, err) - continue - } - shape, err := c.Shape() - if err != nil { - t.Errorf("#%d: Shape() failed for %T: %v", idx, test.Value, err) - continue - } - if got, want := len(shape), len(test.Shape); got != want { - t.Errorf("#%d: %T: Got a shape with %d dimensions, want %d", idx, test.Value, got, want) - continue - } - for i := 0; i < len(shape); i++ { - if got, want := shape[i], test.Shape[i]; got != want { - t.Errorf("#%d: %T: Got %d, want %d for dimension #%d/%d", idx, test.Value, got, want, i, len(shape)) + t.Run(fmt.Sprintf("#%d Value %T", idx, test.Value), func(t *testing.T) { + c, err := Const(graph, fmt.Sprintf("const%d", idx), test.Value) + if err != nil { + t.Fatal(err) } - } + shape, err := c.Shape() + if err != nil { + t.Fatal(err) + } + if got, want := len(shape), len(test.Shape); got != want { + t.Fatalf("Got a shape with %d dimensions, want %d", got, want) + } + for i := 0; i < len(shape); i++ { + if got, want := shape[i], test.Shape[i]; got != want { + t.Errorf("Got %d, want %d for dimension #%d/%d", got, want, i, len(shape)) + } + } + }) } // Unknown number of dimensions dummyTensor, err := NewTensor(float64(0)) diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go index 6e823d1841..0d3660995b 100644 --- a/tensorflow/go/session_test.go +++ b/tensorflow/go/session_test.go @@ -15,6 +15,7 @@ package tensorflow import ( + "fmt" "reflect" "testing" ) @@ -43,30 +44,78 @@ func TestSessionRunNeg(t *testing.T) { } for _, test := range tests { - t1, err := NewTensor(test.input) - if err != nil { - t.Fatalf("NewTensor(%v): %v", test.input, err) - } - graph, inp, out := createTestGraph(t, t1.DataType()) - s, err := NewSession(graph, &SessionOptions{}) - if err != nil { - t.Fatalf("NewSession() for %v: %v", test.input, err) - } - output, err := s.Run(map[Output]*Tensor{inp: t1}, []Output{out}, []*Operation{out.Op}) - if err != nil { - t.Fatalf("Run() for %v: %v", test.input, err) - } - if len(output) != 1 { - t.Errorf("%v: got %d outputs, want 1", test.input, len(output)) - continue - } - val := output[0].Value() - if !reflect.DeepEqual(test.expected, val) { - t.Errorf("got %v, want %v", val, test.expected) - } - if err := s.Close(); err != nil { - t.Errorf("Close(): %v", err) + t.Run(fmt.Sprint(test.input), func(t *testing.T) { + t1, err := NewTensor(test.input) + if err != nil { + t.Fatal(err) + } + graph, inp, out := createTestGraph(t, t1.DataType()) + s, err := NewSession(graph, &SessionOptions{}) + if err != nil { + t.Fatal(err) + } + output, err := s.Run(map[Output]*Tensor{inp: t1}, []Output{out}, []*Operation{out.Op}) + if err != nil { + t.Fatal(err) + } + if len(output) != 1 { + t.Fatalf("got %d outputs, want 1", len(output)) + } + val := output[0].Value() + if !reflect.DeepEqual(test.expected, val) { + t.Errorf("got %v, want %v", val, test.expected) + } + if err := s.Close(); err != nil { + t.Error(err) + } + }) + } +} + +func TestSessionRunConcat(t *testing.T) { + // Runs the Concat operation on two matrices: m1 and m2, along the + // first dimension (dim1). + // This tests the use of both Output and OutputList as inputs to the + // Concat operation. + var ( + g = NewGraph() + dim1, _ = Const(g, "dim1", int32(1)) + m1, _ = Const(g, "m1", [][]int64{ + {1, 2, 3}, + {4, 5, 6}, + }) + m2, _ = Const(g, "m2", [][]int64{ + {7, 8, 9}, + {10, 11, 12}, + }) + want = [][]int64{ + {1, 2, 3, 7, 8, 9}, + {4, 5, 6, 10, 11, 12}, } + ) + concat, err := g.AddOperation(OpSpec{ + Type: "Concat", + Input: []Input{ + dim1, + OutputList{m1, m2}, + }, + }) + if err != nil { + t.Fatal(err) + } + s, err := NewSession(g, &SessionOptions{}) + if err != nil { + t.Fatal(err) + } + output, err := s.Run(nil, []Output{concat.Output(0)}, nil) + if err != nil { + t.Fatal(err) + } + if len(output) != 1 { + t.Fatal(len(output)) + } + if got := output[0].Value(); !reflect.DeepEqual(got, want) { + t.Fatalf("Got %v, want %v", got, want) } } diff --git a/tensorflow/go/util_test.go b/tensorflow/go/util_test.go index 06b4e61d0f..8ab365c656 100644 --- a/tensorflow/go/util_test.go +++ b/tensorflow/go/util_test.go @@ -15,23 +15,40 @@ package tensorflow func Placeholder(g *Graph, name string, dt DataType) (Output, error) { - b := newOpBuilder(g, "Placeholder", name) - b.SetAttrType("dtype", dt) - op, err := b.Build() - return Output{op, 0}, err + op, err := g.AddOperation(OpSpec{ + Type: "Placeholder", + Name: name, + Attrs: map[string]interface{}{ + "dtype": dt, + }, + }) + return op.Output(0), err } -func Const(g *Graph, name string, t *Tensor) (Output, error) { - b := newOpBuilder(g, "Const", name) - b.SetAttrType("dtype", t.DataType()) - b.SetAttrTensor("value", t) - op, err := b.Build() - return Output{op, 0}, err +func Const(g *Graph, name string, value interface{}) (Output, error) { + t, ok := value.(*Tensor) + if !ok { + var err error + if t, err = NewTensor(value); err != nil { + return Output{}, err + } + } + op, err := g.AddOperation(OpSpec{ + Type: "Const", + Name: name, + Attrs: map[string]interface{}{ + "dtype": t.DataType(), + "value": t, + }, + }) + return op.Output(0), err } func Neg(g *Graph, name string, port Output) (Output, error) { - b := newOpBuilder(g, "Neg", name) - b.AddInput(port) - op, err := b.Build() - return Output{op, 0}, err + op, err := g.AddOperation(OpSpec{ + Type: "Neg", + Name: name, + Input: []Input{port}, + }) + return op.Output(0), err } |