author Asim Shankar <ashankar@google.com> 2016-10-07 09:56:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2016-10-07 11:04:39 -0700
commit b5b1671163c90cf786fe0c65ca6754a85d017847
tree 313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/graph.go
parent 318948b28076540841f0c079f12c8f4fdd3d15f1
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph. I also intend to use this API in a code generator that will generate Go source files containing functions for each OpDef (and all of this generated code will be in a separate package). While at it, also changed some tests to use the "sub-tests" feature in Go 1.7 (https://blog.golang.org/subtests). Another step in the journey of #10 Change: 135493412
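For context, a minimal usage sketch of the new API (not part of this commit): it assumes the binding is imported as "tf" from github.com/tensorflow/tensorflow/tensorflow/go and that NewGraph, NewTensor, Tensor.DataType, and Operation.Name already exist in the package; only OpSpec and Graph.AddOperation come from this change. A second sketch after the diff below illustrates how inputs and attributes are passed.

    package main

    import (
        "fmt"

        tf "github.com/tensorflow/tensorflow/tensorflow/go"
    )

    func main() {
        g := tf.NewGraph()

        // Wrap a Go value in a Tensor to use as the "value" attribute
        // of a Const operation (assumes tf.NewTensor is available).
        t, err := tf.NewTensor(int64(7))
        if err != nil {
            panic(err)
        }

        // Add the operation to the graph. Name defaults to Type when omitted.
        c, err := g.AddOperation(tf.OpSpec{
            Type: "Const",
            Name: "MyConst",
            Attrs: map[string]interface{}{
                "dtype": t.DataType(),
                "value": t,
            },
        })
        if err != nil {
            panic(err)
        }
        fmt.Println("added operation:", c.Name())
    }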
Diffstat (limited to 'tensorflow/go/graph.go')
-rw-r--r--  tensorflow/go/graph.go | 160
1 file changed, 159 insertions(+), 1 deletion(-)
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index e119839218..c230595b15 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -109,5 +109,163 @@ func (g *Graph) Operation(name string) *Operation {
if cop == nil {
return nil
}
- return &Operation{cop,g}
+ return &Operation{cop, g}
+}
+
+// OpSpec is the specification of an Operation to be added to a Graph
+// (using Graph.AddOperation).
+type OpSpec struct {
+ // Type of the operation (e.g., "Add", "MatMul").
+ Type string
+
+ // Name by which the added operation will be referred to in the Graph.
+ // If omitted, defaults to Type.
+ Name string
+
+ // Inputs to this operation, which in turn must be outputs
+ // of other operations already added to the Graph.
+ //
+ // An operation may have multiple inputs with individual inputs being
+ // either a single tensor produced by another operation or a list of
+ // tensors produced by multiple operations. For example, the "Concat"
+ // operation takes two inputs: (1) the dimension along which to
+ // concatenate and (2) a list of tensors to concatenate. Thus, for
+ // Concat, len(Input) must be 2, with the first element being an Output
+ // and the second being an OutputList.
+ Input []Input
+
+ // Map from attribute name to its value that will be attached to this
+ // operation.
+ Attrs map[string]interface{}
+
+ // Other possible fields: Device, ColocateWith, ControlInputs.
+}
+
+// AddOperation adds an operation to g.
+func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
+ if args.Name == "" {
+ args.Name = args.Type
+ }
+ cname := C.CString(args.Name)
+ ctype := C.CString(args.Type)
+ cdesc := C.TF_NewOperation(g.c, ctype, cname)
+ C.free(unsafe.Pointer(cname))
+ C.free(unsafe.Pointer(ctype))
+
+ for _, in := range args.Input {
+ switch in := in.(type) {
+ case Output:
+ C.TF_AddInput(cdesc, in.c())
+ case OutputList:
+ size := len(in)
+ list := make([]C.TF_Port, size)
+ for i, v := range in {
+ list[i] = v.c()
+ }
+ C.TF_AddInputList(cdesc, &list[0], C.int(size))
+ }
+ }
+ status := newStatus()
+ for name, value := range args.Attrs {
+ if err := setAttr(cdesc, status, name, value); err != nil {
+ // Memory leak here as the TF_OperationDescription
+ // object will not be cleaned up. At the time of this
+ // writing, this was next to impossible since it
+ // required value to be a string tensor with
+ // incorrectly encoded strings. Given this rarity, live
+ // with the memory leak. If it becomes a real problem,
+ // consider adding a TF_DeleteOperationDescription
+ // function to the C API.
+ return nil, fmt.Errorf("%v (memory will be leaked)", err)
+ }
+ }
+ op := &Operation{
+ c: C.TF_FinishOperation(cdesc, status.c),
+ g: g,
+ }
+ return op, status.Err()
+}
+
+func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
+ cAttrName := C.CString(name)
+ defer C.free(unsafe.Pointer(cAttrName))
+ switch value := value.(type) {
+ case string:
+ cstr := C.CString(value)
+ C.TF_SetAttrString(cdesc, cAttrName, unsafe.Pointer(cstr), C.int(len(value)))
+ C.free(unsafe.Pointer(cstr))
+ case []string:
+ size := len(value)
+ list := make([]unsafe.Pointer, size)
+ lens := make([]C.int, size)
+ for i, s := range value {
+ list[i] = unsafe.Pointer(C.CString(s))
+ lens[i] = C.int(len(s))
+ }
+ C.TF_SetAttrStringList(cdesc, cAttrName, &list[0], &lens[0], C.int(size))
+ for _, s := range list {
+ C.free(s)
+ }
+ case int64:
+ C.TF_SetAttrInt(cdesc, cAttrName, C.int64_t(value))
+ case []int64:
+ size := len(value)
+ list := make([]C.int64_t, size)
+ for i, v := range value {
+ list[i] = C.int64_t(v)
+ }
+ C.TF_SetAttrIntList(cdesc, cAttrName, &list[0], C.int(size))
+ case float32:
+ C.TF_SetAttrFloat(cdesc, cAttrName, C.float(value))
+ case []float32:
+ size := len(value)
+ list := make([]C.float, size)
+ for i, v := range value {
+ list[i] = C.float(v)
+ }
+ C.TF_SetAttrFloatList(cdesc, cAttrName, &list[0], C.int(size))
+ case bool:
+ v := C.uchar(0)
+ if value {
+ v = 1
+ }
+ C.TF_SetAttrBool(cdesc, cAttrName, v)
+ case []bool:
+ size := len(value)
+ list := make([]C.uchar, size)
+ for i, v := range value {
+ if v {
+ list[i] = 1
+ }
+ }
+ C.TF_SetAttrBoolList(cdesc, cAttrName, &list[0], C.int(size))
+ case DataType:
+ C.TF_SetAttrType(cdesc, cAttrName, C.TF_DataType(value))
+ case []DataType:
+ list := (*C.TF_DataType)(&value[0])
+ C.TF_SetAttrTypeList(cdesc, cAttrName, list, C.int(len(value)))
+ case *Tensor:
+ C.TF_SetAttrTensor(cdesc, cAttrName, value.c(), status.c)
+ if err := status.Err(); err != nil {
+ return fmt.Errorf("bad value for attribute %q: %v", name, err)
+ }
+ case []*Tensor:
+ size := len(value)
+ list := make([]*C.TF_Tensor, size)
+ for i, v := range value {
+ list[i] = v.c()
+ }
+ C.TF_SetAttrTensorList(cdesc, cAttrName, &list[0], C.int(size), status.c)
+ if err := status.Err(); err != nil {
+ return fmt.Errorf("bad value for attribute %q: %v", name, err)
+ }
+ default:
+ // Shapes could be supported too, but that will require a way to
+ // distinguish them from []int64. Which is fine: it probably
+ // makes sense to define a Shape type anyway, since that should
+ // handle partially known shapes as well and hide the special
+ // meaning of -1.
+ return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
+ }
+ return nil
}
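A second hedged sketch, also not part of the commit, showing how inputs and attributes flow through the new code: single-tensor inputs are passed as tf.Output values in OpSpec.Input (a list input such as Concat's tensor list would be a tf.OutputList instead), and each Attrs value is dispatched by its Go type through setAttr. The helper name addMatMul and its signature are illustrative only.

    package example

    import tf "github.com/tensorflow/tensorflow/tensorflow/go"

    // addMatMul adds a MatMul operation to g. a and b must be outputs of
    // operations already added to g; the two boolean attributes exercise
    // the bool branch of setAttr.
    func addMatMul(g *tf.Graph, a, b tf.Output) (*tf.Operation, error) {
        return g.AddOperation(tf.OpSpec{
            Type:  "MatMul",
            Input: []tf.Input{a, b},
            Attrs: map[string]interface{}{
                "transpose_a": false,
                "transpose_b": true,
            },
        })
    }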