diff options
author | Asim Shankar <ashankar@google.com> | 2016-10-07 09:56:00 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-10-07 11:04:39 -0700 |
commit | b5b1671163c90cf786fe0c65ca6754a85d017847 (patch) | |
tree | 313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/graph.go | |
parent | 318948b28076540841f0c079f12c8f4fdd3d15f1 (diff) |
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph.
I also intend to use this API in a code generator that will generate
Go sources files containing functions for each OpDef (and all
this generated code will be in a separate package).
While at it, also changed some tests to use the "sub-tests"
feature in Go 1.7 (https://blog.golang.org/subtests)
Another step in the journey of #10
Change: 135493412
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r-- | tensorflow/go/graph.go | 160 |
1 files changed, 159 insertions, 1 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index e119839218..c230595b15 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -109,5 +109,163 @@ func (g *Graph) Operation(name string) *Operation { if cop == nil { return nil } - return &Operation{cop,g} + return &Operation{cop, g} +} + +// OpSpec is the specification of an Operation to be added to a Graph +// (using Graph.AddOperation). +type OpSpec struct { + // Type of the operation (e.g., "Add", "MatMul"). + Type string + + // Name by which the added operation will be referred to in the Graph. + // If omitted, defaults to Type. + Name string + + // Inputs to this operation, which in turn must be outputs + // of other operations already added to the Graph. + // + // An operation may have multiple inputs with individual inputs being + // either a single tensor produced by another operation or a list of + // tensors produced by multiple operations. For example, the "Concat" + // operation takes two inputs: (1) the dimension along which to + // concatenate and (2) a list of tensors to concatenate. Thus, for + // Concat, len(Input) must be 2, with the first element being an Output + // and the second being an OutputList. + Input []Input + + // Map from attribute name to its value that will be attached to this + // operation. + Attrs map[string]interface{} + + // Other possible fields: Device, ColocateWith, ControlInputs. +} + +// AddOperation adds an operation to g. +func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { + if args.Name == "" { + args.Name = args.Type + } + cname := C.CString(args.Name) + ctype := C.CString(args.Type) + cdesc := C.TF_NewOperation(g.c, ctype, cname) + C.free(unsafe.Pointer(cname)) + C.free(unsafe.Pointer(ctype)) + + for _, in := range args.Input { + switch in := in.(type) { + case Output: + C.TF_AddInput(cdesc, in.c()) + case OutputList: + size := len(in) + list := make([]C.TF_Port, size) + for i, v := range in { + list[i] = v.c() + } + C.TF_AddInputList(cdesc, &list[0], C.int(size)) + } + } + status := newStatus() + for name, value := range args.Attrs { + if err := setAttr(cdesc, status, name, value); err != nil { + // Memory leak here as the TF_OperationDescription + // object will not be cleaned up. At the time of this + // writing, this was next to impossible since it + // required value to be a string tensor with + // incorrectly encoded strings. Given this rarity, live + // with the memory leak. If it becomes a real problem, + // consider adding a TF_DeleteOperationDescription + // function to the C API. + return nil, fmt.Errorf("%v (memory will be leaked)", err) + } + } + op := &Operation{ + c: C.TF_FinishOperation(cdesc, status.c), + g: g, + } + return op, status.Err() +} + +func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { + cAttrName := C.CString(name) + defer C.free(unsafe.Pointer(cAttrName)) + switch value := value.(type) { + case string: + cstr := C.CString(value) + C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.int(len(value))) + C.free(unsafe.Pointer(cstr)) + case []string: + size := len(value) + list := make([]unsafe.Pointer, size) + lens := make([]C.int, size) + for i, s := range value { + list[i] = unsafe.Pointer(C.CString(s)) + lens[i] = C.int(len(s)) + } + C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size)) + for _, s := range list { + C.free(s) + } + case int64: + C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value)) + case []int64: + size := len(value) + list := make([]C.int64_t, size) + for i, v := range value { + list[i] = C.int64_t(v) + } + C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size)) + case float32: + C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value)) + case []float32: + size := len(value) + list := make([]C.float, size) + for i, v := range value { + list[i] = C.float(v) + } + C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size)) + case bool: + v := C.uchar(0) + if value { + v = 1 + } + C.TF_SetAttrBool(cdesc, cAttrName, v) + case []bool: + size := len(value) + list := make([]C.uchar, size) + for i, v := range value { + if v { + list[i] = 1 + } + } + C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size)) + case DataType: + C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value)) + case []DataType: + list := (*C.TF_DataType)(&value[0]) + C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value))) + case *Tensor: + C.TF_SetAttrTensor(cdesc, cAttrName, value.c(), status.c) + if err := status.Err(); err != nil { + return fmt.Errorf("bad value for attribute %q: %v", name, err) + } + case []*Tensor: + size := len(value) + list := make([]*C.TF_Tensor, size) + for i, v := range value { + list[i] = v.c() + } + C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c) + if err := status.Err(); err != nil { + return fmt.Errorf("bad value for attribute %q: %v", name, err) + } + default: + // Shapes can be done, but will require that it be + // distinguishable from []int64. Which is fine, it + // probably makes sense to define a Shape type anyway, + // since that should handle partially known shapes as + // well and hide the special meaning of -1? + return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value) + } + return nil } |