diff options
author | Tristan Rice <rice@fn.lc> | 2018-06-11 13:06:04 -0700 |
---|---|---|
committer | Tristan Rice <rice@fn.lc> | 2018-06-12 14:50:39 -0700 |
commit | 8eba32b6c4b259c39097b8b308532b8419d8c151 (patch) | |
tree | c973f7beadd2c82ce4887de88f613c84c20ad069 /tensorflow/go/operation.go | |
parent | a4b390bffbcb01d8f57f25c007277d457f752a69 (diff) |
tensorflow/go: add operation Input methods + tests
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r-- | tensorflow/go/operation.go | 63 |
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index 8fcad61f4c..baaac41f4e 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output { return Output{op, i} } +// NumInputs returns the number of inputs of op. +func (op *Operation) NumInputs() int { + return int(C.TF_OperationNumInputs(op.c)) +} + // Output represents one of the outputs of an operation in the graph. Has a // DataType (and eventually a Shape). May be passed as an input argument to a // function for adding operations to a graph, or to a Session's Run() method to @@ -123,6 +128,64 @@ func (p Output) c() C.TF_Output { func (p Output) canBeAnInput() {} +// Consumers returns the inputs that consume this output. +func (p Output) Consumers() []Consumer { + max := int(C.TF_OperationOutputNumConsumers(p.c())) + inputs := make([]C.TF_Input, max) + n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max)) + inputs = inputs[:int(n)] + + var consumers []Consumer + for _, consumer := range inputs { + consumers = append(consumers, Consumer{ + Index: int(consumer.index), + Op: &Operation{ + c: consumer.oper, + g: p.Op.g, + }, + }) + } + + return consumers +} + +// Consumer identifies a specific input of an operation that consumes the output +// of another operation. +type Consumer struct { + // Op is the Operation that is consuming the output of another operation. + Op *Operation + + // Index is the index of the input within Op that the output of another + // operation is connected to. + Index int +} + +func (p Consumer) c() C.TF_Input { + if p.Op == nil { + // Attempt to provide a more useful panic message than "nil + // pointer dereference". + panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers") + } + return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)} +} + +// DataType returns the type of the input. +func (p Consumer) DataType() DataType { + return DataType(C.TF_OperationInputType(p.c())) +} + +// Producer returns the Output that is connected to this Consumer. +func (p Consumer) Producer() Output { + output := C.TF_OperationInput(p.c()) + return Output{ + Op: &Operation{ + c: output.oper, + g: p.Op.g, + }, + Index: int(output.index), + } +} + // Input is the interface for specifying inputs to an operation being added to // a Graph. // |