aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation.go
diff options
context:
space:
mode:
authorGravatar Tristan Rice <rice@fn.lc>2018-06-11 13:06:04 -0700
committerGravatar Tristan Rice <rice@fn.lc>2018-06-12 14:50:39 -0700
commit8eba32b6c4b259c39097b8b308532b8419d8c151 (patch)
treec973f7beadd2c82ce4887de88f613c84c20ad069 /tensorflow/go/operation.go
parenta4b390bffbcb01d8f57f25c007277d457f752a69 (diff)
tensorflow/go: add operation Input methods + tests
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r--tensorflow/go/operation.go63
1 files changed, 63 insertions, 0 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 8fcad61f4c..baaac41f4e 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -65,6 +65,11 @@ func (op *Operation) Output(i int) Output {
return Output{op, i}
}
+// NumInputs returns the number of inputs of op.
+func (op *Operation) NumInputs() int {
+ return int(C.TF_OperationNumInputs(op.c))
+}
+
// Output represents one of the outputs of an operation in the graph. Has a
// DataType (and eventually a Shape). May be passed as an input argument to a
// function for adding operations to a graph, or to a Session's Run() method to
@@ -123,6 +128,64 @@ func (p Output) c() C.TF_Output {
func (p Output) canBeAnInput() {}
+// Consumers returns the inputs that consume this output.
+func (p Output) Consumers() []Consumer {
+ max := int(C.TF_OperationOutputNumConsumers(p.c()))
+ inputs := make([]C.TF_Input, max)
+ n := C.TF_OperationOutputConsumers(p.c(), (*C.TF_Input)(unsafe.Pointer(&inputs[0])), C.int(max))
+ inputs = inputs[:int(n)]
+
+ var consumers []Consumer
+ for _, consumer := range inputs {
+ consumers = append(consumers, Consumer{
+ Index: int(consumer.index),
+ Op: &Operation{
+ c: consumer.oper,
+ g: p.Op.g,
+ },
+ })
+ }
+
+ return consumers
+}
+
+// Consumer identifies a specific input of an operation that consumes the output
+// of another operation.
+type Consumer struct {
+ // Op is the Operation that is consuming the output of another operation.
+ Op *Operation
+
+ // Index is the index of the input within Op that the output of another
+ // operation is connected to.
+ Index int
+}
+
+func (p Consumer) c() C.TF_Input {
+ if p.Op == nil {
+ // Attempt to provide a more useful panic message than "nil
+ // pointer dereference".
+ panic("nil-Operation. Consumer objects should only be created by a call to Output.Consumers")
+ }
+ return C.TF_Input{oper: p.Op.c, index: C.int(p.Index)}
+}
+
+// DataType returns the type of the input.
+func (p Consumer) DataType() DataType {
+ return DataType(C.TF_OperationInputType(p.c()))
+}
+
+// Producer returns the Output that is connected to this Consumer.
+func (p Consumer) Producer() Output {
+ output := C.TF_OperationInput(p.c())
+ return Output{
+ Op: &Operation{
+ c: output.oper,
+ g: p.Op.g,
+ },
+ Index: int(output.index),
+ }
+}
+
// Input is the interface for specifying inputs to an operation being added to
// a Graph.
//