diff options
Diffstat (limited to 'tensorflow/go/genop')
-rw-r--r-- | tensorflow/go/genop/internal/genop.go | 11 | ||||
-rw-r--r-- | tensorflow/go/genop/internal/genop_test.go | 6 |
2 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go index d9ebec0f8c..16e4d0e512 100644 --- a/tensorflow/go/genop/internal/genop.go +++ b/tensorflow/go/genop/internal/genop.go @@ -212,6 +212,10 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {{- end -}} {{- end -}} +{{- if (not .Op.OutputArg) }} +// +// Returns the created operation. +{{- else }} {{- if .DescribeOutputs}} // {{- if ((len .Op.OutputArg) eq 1) }} @@ -223,6 +227,7 @@ func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {{- end -}} {{- end -}} {{- end -}} +{{- end -}} {{- /* The function signature. @@ -244,10 +249,12 @@ func {{.Op.Name}} {{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}} ) -{{- /* Construct outputs: len(OpDef.OutputArg) */ -}} +{{- /* Construct outputs: len(OpDef.OutputArg) or a *tf.Operation */ -}} {{if .Op.OutputArg -}} ({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}) +{{- else -}} +(o *tf.Operation) {{- end }} { if scope.Err() != nil { return @@ -295,7 +302,7 @@ func {{.Op.Name}} return {{range $i, $a := .Op.OutputArg}}{{if $i}}, {{end}}op.Output({{$i}}){{end}} {{- end }}{{- /* if .HasListOutput */}} {{- else }} - scope.AddOperation(opspec) + return scope.AddOperation(opspec) {{- end }}{{- /* if .Op.OutputArg */}} } `)) diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go index c3057e9119..c66e38fce0 100644 --- a/tensorflow/go/genop/internal/genop_test.go +++ b/tensorflow/go/genop/internal/genop_test.go @@ -39,14 +39,16 @@ summary: "No. Op." `, wanted: ` // No. Op. -func NoOp(scope *Scope) { +// +// Returns the created operation. +func NoOp(scope *Scope) (o *tf.Operation) { if scope.Err() != nil { return } opspec := tf.OpSpec{ Type: "NoOp", } - scope.AddOperation(opspec) + return scope.AddOperation(opspec) } `, }, |