diff options
author | 2016-09-28 10:37:31 -0800 | |
---|---|---|
committer | 2016-09-28 11:48:45 -0700 | |
commit | 7cacfdf03872c36cfc3f43e75e8c342234160e3a (patch) | |
tree | 0856b1decb0a1f36a68450c9a5d9b825de999c06 /tensorflow/go/operation.go | |
parent | b410f52f6616e281d8274317938a258605ad993a (diff) |
go: Add an example.
Add an example (that will appear in Go doc) for a real use of the Go TensorFlow
APIs - using a pre-defined image recognition model for inference.
While at it a couple of minor tweaks:
- NewSession now accepts 'nil' for options as the documentation says it does
- Convenience accessors for the Outputs of an Operation
- Ability to extract (possibly partial) shapes from an Output
Another step towards #10
Change: 134560938
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r-- | tensorflow/go/operation.go | 50 |
1 files changed, 49 insertions, 1 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index 0f14ea1bef..59c5c2a2b0 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -17,7 +17,11 @@ package tensorflow // #include <stdlib.h> // #include "tensorflow/c/c_api.h" import "C" -import "unsafe" + +import ( + "errors" + "unsafe" +) // Operation that has been added to the graph. type Operation struct { @@ -37,6 +41,18 @@ func (op *Operation) Type() string { return C.GoString(C.TF_OperationOpType(op.c)) } +// NumOutputs returns the number of outputs of op. +func (op *Operation) NumOutputs() int { + return int(C.TF_OperationNumOutputs(op.c)) +} + +// Output returns the i-th output of op. +// +// REQUIRES: 0 <= i < op.NumOutputs() +func (op *Operation) Output(i int) Output { + return Output{op, i} +} + // Output represents one of the outputs of an operation in the graph. Has a // DataType (and eventually a Shape). May be passed as an input argument to a // function for adding operations to a graph, or to a Session's Run() method to @@ -49,6 +65,38 @@ type Output struct { Index int } +// Shape returns the (possibly incomplete) shape of the tensor produced p. +// +// Returns a slice of length 0 if the tensor is a scalar. Returns a slice +// where shape[i] is the size of the i-th dimension of the tensor, or -1 if the +// size of that dimension is not known. +// +// Returns an error if the number of dimensions of the tensor is not known. +func (p Output) Shape() (shape []int64, err error) { + status := newStatus() + port := p.c() + ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c) + if err := status.Err(); err != nil { + return nil, err + } + if ndims < 0 { + return nil, errors.New("unknown number of dimensions") + } + if ndims == 0 { + return nil, nil + } + dims := make([]C.int64_t, ndims) + C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c) + if err := status.Err(); err != nil { + return nil, err + } + ret := make([]int64, ndims) + for i := 0; i < int(ndims); i++ { + ret[i] = int64(dims[i]) + } + return ret, nil +} + func (p *Output) c() C.TF_Port { return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)} } |