diff options
author | 2016-10-27 11:03:48 -0800 | |
---|---|---|
committer | 2016-10-27 12:21:17 -0700 | |
commit | 7351a21714f467eb9d440703001876616d02e0fd (patch) | |
tree | 045ef82b2337c660c0440a552923105ff38016f6 /tensorflow/go/genop | |
parent | 414eeebc639eb75009ea5bbbac32a1f8a275dd30 (diff) |
go: Change generated op function API.
Errors during graph construction are held in the Scope.
This makes the op construction code mode compact and nestable
as errors do not _need_ to be handled on every op addition.
To help ensure that the client does not miss the error completely,
the Scope is treated as a "builder" of Graphs, and once the Graph
is extracted from the Scope (using Scope.Finalize), the Scope is
rendered useless.
To help tracing failures to the op, the error stores the stacktrace
pointing to when the error occurred, which will help in identifying
the precise operation that failed.
To use scopes to enhance existing Graphs, the idea is to add
a:
NewScopeWithGraph(*tf.Graph) *Scope
function, but that is not included in this change.
Change: 137423318
Diffstat (limited to 'tensorflow/go/genop')
-rw-r--r-- | tensorflow/go/genop/internal/genop.go | 40 | ||||
-rw-r--r-- | tensorflow/go/genop/internal/genop_test.go | 94 |
2 files changed, 104 insertions, 30 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go index fdc55f5ebc..5d5aa26992 100644 --- a/tensorflow/go/genop/internal/genop.go +++ b/tensorflow/go/genop/internal/genop.go @@ -244,10 +244,14 @@ func {{.Op.Name}} {{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} ) -{{- /* Construct outputs: len(OpDef.OutputArg) + 1 (for error) */ -}} +{{- /* Construct outputs: len(OpDef.OutputArg) */ -}} -({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}} -{{if .Op.OutputArg}}, {{end}}err error) { +{{if .Op.OutputArg -}} +({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}) +{{- end }} { + if scope.Err() != nil { + return + } {{if .HasAttrs -}} attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}} {{if .OptionalAttrs -}} @@ -262,25 +266,37 @@ func {{.Op.Name}} Input: []tf.Input{ {{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}} }, - {{end}} - {{- if .HasAttrs}}Attrs: attrs,{{end}} + {{- end}} + {{- if .HasAttrs}} + Attrs: attrs, + {{- end}} } - {{if .Op.OutputArg}}op, err :={{else}}_, err ={{end}} scope.Graph().AddOperation(opspec) + {{- if .Op.OutputArg}} {{- if .HasListOutput}} + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } var idx int + var err error {{- range $i, $a := .Op.OutputArg}} {{- if IsListArg $a}} if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil { - return {{range $.Op.OutputArg}}{{Identifier .Name}}, {{end}}err + scope.UpdateErr({{printf "%q" $.Op.Name}}, err) + return } {{- else }} {{Identifier .Name}} = op.Output(idx) - {{- end }} - {{- end }} - return {{range .Op.OutputArg}}{{Identifier .Name}}, {{end}}err + {{- end }}{{- /* if IsListArg */}} + {{- end }}{{- /* range .Op.OutputArg */}} + return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier .Name}}{{end}} + {{- else }} + op := scope.AddOperation(opspec) + return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} + {{- end }}{{- /* if .HasListOutput */}} {{- else }} - return {{range $i, $a := .Op.OutputArg}}op.Output({{$i}}), {{end}}err - {{- end }} + scope.AddOperation(opspec) + {{- end }}{{- /* if .Op.OutputArg */}} } `)) ) diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go index dade7ce48f..b3bcd9db05 100644 --- a/tensorflow/go/genop/internal/genop_test.go +++ b/tensorflow/go/genop/internal/genop_test.go @@ -39,12 +39,14 @@ summary: "No. Op." `, wanted: ` // No. Op. -func NoOp(scope *Scope) (err error) { +func NoOp(scope *Scope) { + if scope.Err() != nil { + return + } opspec := tf.OpSpec{ Type: "NoOp", } - _, err = scope.Graph().AddOperation(opspec) - return err + scope.AddOperation(opspec) } `, }, @@ -81,15 +83,18 @@ description: "Blah blah", // Returns x + y element-wise. // // Blah blah -func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output, err error) { +func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } opspec := tf.OpSpec{ Type: "Add", Input: []tf.Input{ x, y, }, } - op, err := scope.Graph().AddOperation(opspec) - return op.Output(0), err + op := scope.AddOperation(opspec) + return op.Output(0) } `, }, @@ -117,7 +122,10 @@ summary: "Cast x of type SrcT to y of DstT." `, wanted: ` // Cast x of type SrcT to y of DstT. -func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output, err error) { +func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) { + if scope.Err() != nil { + return + } attrs := map[string]interface{}{"DstT": DstT} opspec := tf.OpSpec{ Type: "Cast", @@ -126,8 +134,8 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output, err error) }, Attrs: attrs, } - op, err := scope.Graph().AddOperation(opspec) - return op.Output(0), err + op := scope.AddOperation(opspec) + return op.Output(0) } `, }, @@ -218,7 +226,10 @@ func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr { // contents: 0-D. The JPEG-encoded image. // // Returns 3-D with shape [height, width, channels] -func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output, err error) { +func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) { + if scope.Err() != nil { + return + } attrs := map[string]interface{}{} for _, a := range optional { a(attrs) @@ -230,8 +241,47 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i }, Attrs: attrs, } - op, err := scope.Graph().AddOperation(opspec) - return op.Output(0), err + op := scope.AddOperation(opspec) + return op.Output(0) +} +`, + }, + { + tag: "MultipleOutputs", + opdef: ` +name: "TwoOutputs" +input_arg: < + name: "input" + type_attr: "T" +> +output_arg < + name: "x" + type_attr: "T" +> +output_arg < + name: "y" + type_attr: "T" +> +attr: < + name: "T" + type: "type" +> +summary: "Op that produces multiple outputs" +`, + wanted: ` +// Op that produces multiple outputs +func TwoOutputs(scope *Scope, input tf.Output) (x tf.Output, y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "TwoOutputs", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0), op.Output(1) } `, }, @@ -290,7 +340,10 @@ func ShapeNOutType(value tf.DataType) ShapeNAttr { // Returns shape of tensors. // // Some description here. -func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output, err error) { +func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) { + if scope.Err() != nil { + return + } attrs := map[string]interface{}{} for _, a := range optional { a(attrs) @@ -302,12 +355,17 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t }, Attrs: attrs, } - op, err := scope.Graph().AddOperation(opspec) + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } var idx int + var err error if output, idx, err = makeOutputList(op, idx, "output"); err != nil { - return output, err + scope.UpdateErr("ShapeN", err) + return } - return output, err + return output } `, }, @@ -325,11 +383,11 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t } got, err := format.Source(buf.Bytes()) if err != nil { - t.Fatal(err) + t.Fatalf("Unable to format: %v\n%s", err, buf.Bytes()) } want, err := format.Source([]byte(test.wanted)) if err != nil { - t.Fatal(err) + t.Fatalf("Unable to format: %v\n%s", err, test.wanted) } if !bytes.Equal(got, want) { t.Fatalf("Got:\n%s\nWant:\n%s\n", got, want) |