diff options
-rw-r--r-- | tensorflow/go/graph.go | 8 | ||||
-rw-r--r-- | tensorflow/go/op/op_test.go | 25 | ||||
-rw-r--r-- | tensorflow/go/operation.go | 5 |
3 files changed, 34 insertions, 4 deletions
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go index e65619e80b..46c600eab1 100644 --- a/tensorflow/go/graph.go +++ b/tensorflow/go/graph.go @@ -185,11 +185,11 @@ func (g *Graph) AddOperation(args OpSpec) (*Operation, error) { return nil, fmt.Errorf("%v (memory will be leaked)", err) } } - op := &Operation{ - c: C.TF_FinishOperation(cdesc, status.c), - g: g, + c := C.TF_FinishOperation(cdesc, status.c) + if err := status.Err(); err != nil { + return nil, err } - return op, status.Err() + return &Operation{c, g}, nil } func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, value interface{}) error { diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go index 65877dca96..2451ba3606 100644 --- a/tensorflow/go/op/op_test.go +++ b/tensorflow/go/op/op_test.go @@ -19,6 +19,7 @@ limitations under the License. package op import ( + "strings" "testing" tf "github.com/tensorflow/tensorflow/tensorflow/go" @@ -33,3 +34,27 @@ func TestPlaceholder(t *testing.T) { t.Fatal(err) } } + +func TestAddOperationFailure(t *testing.T) { + // Inspired from https://github.com/tensorflow/tensorflow/issues/9931 + s := NewScope() + + resize := ResizeArea(s, Placeholder(s, tf.Float), Const(s, []int64{80, 80})) + if err := s.Err(); err == nil { + t.Fatal("ResizeArea expects an int32 Tensor for size, should fail when an int64 is provided") + } + // And any use of resize should panic with an error message more informative than SIGSEGV + defer func() { + r := recover() + if r == nil { + return + } + s, ok := r.(string) + if ok && strings.Contains(s, "see Scope.Err() for details") { + return + } + t.Errorf("Expected panic string to Scope.Err(), found %T: %q", r, r) + }() + _ = resize.Shape() + t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created") +} diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go index e8f67c4f73..8fcad61f4c 100644 --- a/tensorflow/go/operation.go +++ b/tensorflow/go/operation.go @@ -113,6 +113,11 @@ func (p Output) Shape() Shape { } func (p Output) c() C.TF_Output { + if p.Op == nil { + // Attempt to provide a more useful panic message than "nil + // pointer dereference". + panic("nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.") + } return C.TF_Output{oper: p.Op.c, index: C.int(p.Index)} } |