aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/operation.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-01-24 16:13:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-24 16:26:57 -0800
commitc995ee586cfcce29a10b6e05140f1cc7c6c13a16 (patch)
tree526cf540971cde16e63356245659a2aee44e635b /tensorflow/go/operation.go
parent4ac5ed3184873d530dc249a74794a55229e85e0f (diff)
Go: Output.Shape now returns a Shape object.
Output.Shape may be only partially known, hence the recently introduced Shape type is a more appropriate return value. Change: 145481782
Diffstat (limited to 'tensorflow/go/operation.go')
-rw-r--r--tensorflow/go/operation.go31
1 files changed, 13 insertions, 18 deletions
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index eb8614653e..df41c40a2b 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -18,10 +18,7 @@ package tensorflow
// #include "tensorflow/c/c_api.h"
import "C"
-import (
- "errors"
- "unsafe"
-)
+import "unsafe"
// Operation that has been added to the graph.
type Operation struct {
@@ -79,35 +76,33 @@ type Output struct {
}
// Shape returns the (possibly incomplete) shape of the tensor produced p.
-//
-// Returns a slice of length 0 if the tensor is a scalar. Returns a slice
-// where shape[i] is the size of the i-th dimension of the tensor, or -1 if the
-// size of that dimension is not known.
-//
-// Returns an error if the number of dimensions of the tensor is not known.
-func (p Output) Shape() (shape []int64, err error) {
+func (p Output) Shape() Shape {
status := newStatus()
port := p.c()
ndims := C.TF_GraphGetTensorNumDims(p.Op.g.c, port, status.c)
if err := status.Err(); err != nil {
- return nil, err
+ // This should not be possible since an error only occurs if
+ // the operation does not belong to the graph. It should not
+ // be possible to construct such an Operation object.
+ return Shape{}
}
if ndims < 0 {
- return nil, errors.New("unknown number of dimensions")
+ return Shape{}
}
if ndims == 0 {
- return nil, nil
+ return ScalarShape()
}
dims := make([]C.int64_t, ndims)
C.TF_GraphGetTensorShape(p.Op.g.c, port, &dims[0], ndims, status.c)
if err := status.Err(); err != nil {
- return nil, err
+ // Same as above, should not be possible.
+ return Shape{}
}
- ret := make([]int64, ndims)
+ ret := Shape{dims: make([]int64, ndims)}
for i := 0; i < int(ndims); i++ {
- ret[i] = int64(dims[i])
+ ret.dims[i] = int64(dims[i])
}
- return ret, nil
+ return ret
}
func (p Output) c() C.TF_Output {