aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/go/graph.go8
-rw-r--r--tensorflow/go/op/op_test.go25
-rw-r--r--tensorflow/go/operation.go5
3 files changed, 34 insertions, 4 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index e65619e80b..46c600eab1 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -185,11 +185,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) {
return nil, fmt.Errorf("%v (memory will be leaked)", err)
}
}
- op := &Operation{
- c: C.TF_FinishOperation(cdesc, status.c),
- g: g,
+ c := C.TF_FinishOperation(cdesc, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
}
- return op, status.Err()
+ return &Operation{c, g}, nil
}
func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error {
diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go
index 65877dca96..2451ba3606 100644
--- a/tensorflow/go/op/op_test.go
+++ b/tensorflow/go/op/op_test.go
@@ -19,6 +19,7 @@ limitations under the License.
package op
import (
+ "strings"
"testing"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
@@ -33,3 +34,27 @@ func TestPlaceholder(t *testing.T) {
t.Fatal(err)
}
}
+
+func TestAddOperationFailure(t *testing.T) {
+ // Inspired from https://github.com/tensorflow/tensorflow/issues/9931
+ s := NewScope()
+
+ resize := ResizeArea(s, Placeholder(s, tf.Float), Const(s, []int64{80, 80}))
+ if err := s.Err(); err == nil {
+ t.Fatal("ResizeArea expects an int32 Tensor for size, should fail when an int64 is provided")
+ }
+ // And any use of resize should panic with an error message more informative than SIGSEGV
+ defer func() {
+ r := recover()
+ if r == nil {
+ return
+ }
+ s, ok := r.(string)
+ if ok && strings.Contains(s, "see Scope.Err() for details") {
+ return
+ }
+ t.Errorf("Expected panic string to Scope.Err(), found %T: %q", r, r)
+ }()
+ _ = resize.Shape()
+ t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created")
+}
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
index e8f67c4f73..8fcad61f4c 100644
--- a/tensorflow/go/operation.go
+++ b/tensorflow/go/operation.go
@@ -113,6 +113,11 @@ func (p Output) Shape() Shape {
}
func (p Output) c() C.TF_Output {
+ if p.Op == nil {
+ // Attempt to provide a more useful panic message than "nil
+ // pointer dereference".
+ panic("nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.")
+ }
return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)}
}