aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-09-28 10:37:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-28 11:48:45 -0700
commit7cacfdf03872c36cfc3f43e75e8c342234160e3a (patch)
tree0856b1decb0a1f36a68450c9a5d9b825de999c06 /tensorflow/go/operation.go
parentb410f52f6616e281d8274317938a258605ad993a (diff)
go: Add an example.
Add an example (that will appear in Go doc) for a real use of the Go TensorFlow APIs - using a pre-defined image recognition model for inference. While at it a couple of minor tweaks: - NewSession now accepts 'nil' for options as the documentation says it does - Convenience accessors for the Outputs of an Operation - Ability to extract (possibly partial) shapes from an Output Another step towards #10 Change: 134560938
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r--tensorflow/go/operation.go50
1 files changed, 49 insertions, 1 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index 0f14ea1bef..59c5c2a2b0 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -17,7 +17,11 @@ package tensorflow
// #include <stdlib.h>
// #include "tensorflow/c/c_api.h"
import "C"
-import "unsafe"
+
+import (
+ "errors"
+ "unsafe"
+)
// Operation that has been added to the graph.
type Operation struct {
@@ -37,6 +41,18 @@ func (op *Operation) Type() string {
return C.GoString(C.TF_OperationOpType(op.c))
}
+// NumOutputs returns the number of outputs of op.
+func (op *Operation) NumOutputs() int {
+ return int(C.TF_OperationNumOutputs(op.c))
+}
+
+// Output returns the i-th output of op.
+//
+// REQUIRES: 0 <= i < op.NumOutputs()
+func (op *Operation) Output(i int) Output {
+ return Output{op, i}
+}
+
// Output represents one of the outputs of an operation in the graph. Has a
// DataType (and eventually a Shape). May be passed as an input argument to a
// function for adding operations to a graph, or to a Session's Run() method to
@@ -49,6 +65,38 @@ type Output struct {
Index int
}
+// Shape returns the (possibly incomplete) shape of the tensor produced p.
+//
+// Returns a slice of length 0 if the tensor is a scalar. Returns a slice
+// where shape[i] is the size of the i-th dimension of the tensor, or -1 if the
+// size of that dimension is not known.
+//
+// Returns an error if the number of dimensions of the tensor is not known.
+func (p Output) Shape() (shape []int64, err error) {
+ status := newStatus()
+ port := p.c()
+ ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ if ndims < 0 {
+ return nil, errors.New("unknown number of dimensions")
+ }
+ if ndims == 0 {
+ return nil, nil
+ }
+ dims := make([]C.int64_t, ndims)
+ C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ ret := make([]int64, ndims)
+ for i := 0; i < int(ndims); i++ {
+ ret[i] = int64(dims[i])
+ }
+ return ret, nil
+}
+
func (p *Output) c() C.TF_Port {
return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)}
}