aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/genop
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-27 11:03:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 12:21:17 -0700
commit7351a21714f467eb9d440703001876616d02e0fd (patch)
tree045ef82b2337c660c0440a552923105ff38016f6 /tensorflow/go/genop
parent414eeebc639eb75009ea5bbbac32a1f8a275dd30 (diff)
go: Change generated op function API.
Errors during graph construction are held in the Scope. This makes the op construction code mode compact and nestable as errors do not _need_ to be handled on every op addition. To help ensure that the client does not miss the error completely, the Scope is treated as a "builder" of Graphs, and once the Graph is extracted from the Scope (using Scope.Finalize), the Scope is rendered useless. To help tracing failures to the op, the error stores the stacktrace pointing to when the error occurred, which will help in identifying the precise operation that failed. To use scopes to enhance existing Graphs, the idea is to add a: NewScopeWithGraph(*tf.Graph) *Scope function, but that is not included in this change. Change: 137423318
Diffstat (limited to 'tensorflow/go/genop')
-rw-r--r--tensorflow/go/genop/internal/genop.go40
-rw-r--r--tensorflow/go/genop/internal/genop_test.go94
2 files changed, 104 insertions, 30 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go
index fdc55f5ebc..5d5aa26992 100644
--- a/tensorflow/go/genop/internal/genop.go
+++ b/tensorflow/go/genop/internal/genop.go
@@ -244,10 +244,14 @@ func {{.Op.Name}}
{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
)
-{{- /* Construct outputs: len(OpDef.OutputArg) + 1 (for error) */ -}}
+{{- /* Construct outputs: len(OpDef.OutputArg) */ -}}
-({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}
-{{if .Op.OutputArg}}, {{end}}err error) {
+{{if .Op.OutputArg -}}
+({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}})
+{{- end }} {
+ if scope.Err() != nil {
+ return
+ }
{{if .HasAttrs -}}
attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}}
{{if .OptionalAttrs -}}
@@ -262,25 +266,37 @@ func {{.Op.Name}}
Input: []tf.Input{
{{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}}
},
- {{end}}
- {{- if .HasAttrs}}Attrs: attrs,{{end}}
+ {{- end}}
+ {{- if .HasAttrs}}
+ Attrs: attrs,
+ {{- end}}
}
- {{if .Op.OutputArg}}op, err :={{else}}_, err ={{end}} scope.Graph().AddOperation(opspec)
+ {{- if .Op.OutputArg}}
{{- if .HasListOutput}}
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
var idx int
+ var err error
{{- range $i, $a := .Op.OutputArg}}
{{- if IsListArg $a}}
if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
- return {{range $.Op.OutputArg}}{{Identifier .Name}}, {{end}}err
+ scope.UpdateErr({{printf "%q" $.Op.Name}}, err)
+ return
}
{{- else }}
{{Identifier .Name}} = op.Output(idx)
- {{- end }}
- {{- end }}
- return {{range .Op.OutputArg}}{{Identifier .Name}}, {{end}}err
+ {{- end }}{{- /* if IsListArg */}}
+ {{- end }}{{- /* range .Op.OutputArg */}}
+ return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier .Name}}{{end}}
+ {{- else }}
+ op := scope.AddOperation(opspec)
+ return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}}
+ {{- end }}{{- /* if .HasListOutput */}}
{{- else }}
- return {{range $i, $a := .Op.OutputArg}}op.Output({{$i}}), {{end}}err
- {{- end }}
+ scope.AddOperation(opspec)
+ {{- end }}{{- /* if .Op.OutputArg */}}
}
`))
)
diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go
index dade7ce48f..b3bcd9db05 100644
--- a/tensorflow/go/genop/internal/genop_test.go
+++ b/tensorflow/go/genop/internal/genop_test.go
@@ -39,12 +39,14 @@ summary: "No. Op."
`,
wanted: `
// No. Op.
-func NoOp(scope *Scope) (err error) {
+func NoOp(scope *Scope) {
+ if scope.Err() != nil {
+ return
+ }
opspec := tf.OpSpec{
Type: "NoOp",
}
- _, err = scope.Graph().AddOperation(opspec)
- return err
+ scope.AddOperation(opspec)
}
`,
},
@@ -81,15 +83,18 @@ description: "Blah blah",
// Returns x + y element-wise.
//
// Blah blah
-func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output, err error) {
+func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
opspec := tf.OpSpec{
Type: "Add",
Input: []tf.Input{
x, y,
},
}
- op, err := scope.Graph().AddOperation(opspec)
- return op.Output(0), err
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
`,
},
@@ -117,7 +122,10 @@ summary: "Cast x of type SrcT to y of DstT."
`,
wanted: `
// Cast x of type SrcT to y of DstT.
-func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output, err error) {
+func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
attrs := map[string]interface{}{"DstT": DstT}
opspec := tf.OpSpec{
Type: "Cast",
@@ -126,8 +134,8 @@ func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output, err error)
},
Attrs: attrs,
}
- op, err := scope.Graph().AddOperation(opspec)
- return op.Output(0), err
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
`,
},
@@ -218,7 +226,10 @@ func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr {
// contents: 0-D. The JPEG-encoded image.
//
// Returns 3-D with shape [height, width, channels]
-func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output, err error) {
+func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
@@ -230,8 +241,47 @@ func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (i
},
Attrs: attrs,
}
- op, err := scope.Graph().AddOperation(opspec)
- return op.Output(0), err
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+`,
+ },
+ {
+ tag: "MultipleOutputs",
+ opdef: `
+name: "TwoOutputs"
+input_arg: <
+ name: "input"
+ type_attr: "T"
+>
+output_arg <
+ name: "x"
+ type_attr: "T"
+>
+output_arg <
+ name: "y"
+ type_attr: "T"
+>
+attr: <
+ name: "T"
+ type: "type"
+>
+summary: "Op that produces multiple outputs"
+`,
+ wanted: `
+// Op that produces multiple outputs
+func TwoOutputs(scope *Scope, input tf.Output) (x tf.Output, y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TwoOutputs",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
}
`,
},
@@ -290,7 +340,10 @@ func ShapeNOutType(value tf.DataType) ShapeNAttr {
// Returns shape of tensors.
//
// Some description here.
-func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output, err error) {
+func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
attrs := map[string]interface{}{}
for _, a := range optional {
a(attrs)
@@ -302,12 +355,17 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
},
Attrs: attrs,
}
- op, err := scope.Graph().AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
var idx int
+ var err error
if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
- return output, err
+ scope.UpdateErr("ShapeN", err)
+ return
}
- return output, err
+ return output
}
`,
},
@@ -325,11 +383,11 @@ func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []t
}
got, err := format.Source(buf.Bytes())
if err != nil {
- t.Fatal(err)
+ t.Fatalf("Unable to format: %v\n%s", err, buf.Bytes())
}
want, err := format.Source([]byte(test.wanted))
if err != nil {
- t.Fatal(err)
+ t.Fatalf("Unable to format: %v\n%s", err, test.wanted)
}
if !bytes.Equal(got, want) {
t.Fatalf("Got:\n%s\nWant:\n%s\n", got, want)