aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-07 09:56:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-07 11:04:39 -0700
commitb5b1671163c90cf786fe0c65ca6754a85d017847 (patch)
tree313b04a766bcfbc31b44f0b6237dc0b3389c5c3b /tensorflow/go/operation.go
parent318948b28076540841f0c079f12c8f4fdd3d15f1 (diff)
go: Introduce Graph.AddOperation to add operations to the Graph.
Export an API to add operations to the graph. I also intend to use this API in a code generator that will generate Go sources files containing functions for each OpDef (and all this generated code will be in a separate package). While at it, also changed some tests to use the "sub-tests" feature in Go 1.7 (https://blog.golang.org/subtests) Another step in the journey of #10 Change: 135493412
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r--tensorflow/go/operation.go63
1 files changed, 18 insertions, 45 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 59c5c2a2b0..8f4fee1cbe 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -20,7 +20,6 @@ import "C"
import (
"errors"
- "unsafe"
)
// Operation that has been added to the graph.
@@ -47,8 +46,6 @@ func (op *Operation) NumOutputs() int {
}
// Output returns the i-th output of op.
-//
-// REQUIRES: 0 <= i < op.NumOutputs()
func (op *Operation) Output(i int) Output {
return Output{op, i}
}
@@ -97,52 +94,28 @@ func (p Output) Shape() (shape []int64, err error) {
return ret, nil
}
-func (p *Output) c() C.TF_Port {
+func (p Output) c() C.TF_Port {
return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)}
}
-// opBuilder is for use by the generated op code to create new Operations.
-// Build() must be called for any in-progress Operation, or else we leak.
-type opBuilder struct {
- c *C.TF_OperationDescription
- // A reference to the Graph to prevent it from being GCed while
- // the opBuilder is still alive.
- g *Graph
-}
-
-func newOpBuilder(g *Graph, typ string, name string) *opBuilder {
- opType := C.CString(typ)
- opName := C.CString(name)
- b := &opBuilder{c: C.TF_NewOperation(g.c, opType, opName), g: g}
- C.free(unsafe.Pointer(opType))
- C.free(unsafe.Pointer(opName))
- return b
-}
-
-func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error {
- status := newStatus()
- attrName := C.CString(name)
- C.TF_SetAttrTensor(b.c, attrName, t.c(), status.c)
- C.free(unsafe.Pointer(attrName))
- return status.Err()
-}
+func (p Output) canBeAnInput() {}
-func (b *opBuilder) SetAttrType(name string, typ DataType) {
- attrName := C.CString(name)
- C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ))
- C.free(unsafe.Pointer(attrName))
+// Input is the interface for specifying inputs to an operation being added to
+// a Graph.
+//
+// Operations can have multiple inputs, each of which could be either a tensor
+// produced by another operation (an Output object), or a list of tensors
+// produced by other operations (an OutputList). Thus, this interface is
+// implemented by both Output and OutputList.
+//
+// See OpSpec.Input for more information.
+type Input interface {
+ // Unexported to preclude implementations outside this package.
+ canBeAnInput()
}
-func (b *opBuilder) AddInput(port Output) {
- C.TF_AddInput(b.c, port.c())
-}
+// OutputList represents a list of Outputs that can be provided as input to
+// another operation.
+type OutputList []Output
-func (b *opBuilder) Build() (*Operation, error) {
- status := newStatus()
- op := &Operation{c: C.TF_FinishOperation(b.c, status.c), g: b.g}
- if err := status.Err(); err != nil {
- return nil, err
- }
- b.c = nil
- return op, nil
-}
+func (l OutputList) canBeAnInput() {}