diff options
author | Asim Shankar <ashankar@google.com> | 2017-01-24 16:13:44 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-24 16:26:57 -0800 |
commit | c995ee586cfcce29a10b6e05140f1cc7c6c13a16 (patch) | |
tree | 526cf540971cde16e63356245659a2aee44e635b /tensorflow/go/operation.go | |
parent | 4ac5ed3184873d530dc249a74794a55229e85e0f (diff) |
Go: Output.Shape now returns a Shape object.
Output.Shape may be only partially known, hence the recently introduced Shape
type is a more appropriate return value.
Change: 145481782
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r-- | tensorflow/go/operation.go | 31 |
1 files changed, 13 insertions, 18 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index eb8614653e..df41c40a2b 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -18,10 +18,7 @@ package tensorflow // #include "tensorflow/c/c_api.h" import "C" -import ( - "errors" - "unsafe" -) +import "unsafe" // Operation that has been added to the graph. type Operation struct { @@ -79,35 +76,33 @@ type Output struct { } // Shape returns the (possibly incomplete) shape of the tensor produced p. -// -// Returns a slice of length 0 if the tensor is a scalar. Returns a slice -// where shape[i] is the size of the i-th dimension of the tensor, or -1 if the -// size of that dimension is not known. -// -// Returns an error if the number of dimensions of the tensor is not known. -func (p Output) Shape() (shape []int64, err error) { +func (p Output) Shape() Shape { status := newStatus() port := p.c() ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c) if err := status.Err(); err != nil { - return nil, err + // This should not be possible since an error only occurs if + // the operation does not belong to the graph. It should not + // be possible to construct such an Operation object. + return Shape{} } if ndims < 0 { - return nil, errors.New("unknown number of dimensions") + return Shape{} } if ndims == 0 { - return nil, nil + return ScalarShape() } dims := make([]C.int64_t, ndims) C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c) if err := status.Err(); err != nil { - return nil, err + // Same as above, should not be possible. + return Shape{} } - ret := make([]int64, ndims) + ret := Shape{dims: make([]int64, ndims)} for i := 0; i < int(ndims); i++ { - ret[i] = int64(dims[i]) + ret.dims[i] = int64(dims[i]) } - return ret, nil + return ret } func (p Output) c() C.TF_Output { |