aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/genop
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2016-10-17 15:33:15 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-17 16:48:44 -0700
commit4187c410e552ea6a27a9a0ad6402ba5ab2dcb4bf (patch)
tree2bc1431c31911b23c0172b4cbc703fb29a1db1a1 /tensorflow/go/genop
parent66024fd508748d706b72d0ae5e8b07f917e78458 (diff)
go: Generate wrapper functions for ops
This change fills in the body of the "genop" command-line tool to generate Go wrapper functions for each TensorFlow op. The generated API is as follows: - One function generated per TensorFlow op - Arguments to the function are: Scope, all inputs, all required attributes and a variadic list of optional attributes - Outputs of the function are all outputs and an error - The optional attributes of each op get their own type (<OpName>Attr) - And there are a set of factory functions for that type, one for each optional attribute See genop_test.go for a sample of how OpDefs turn into generated source code. Operations with either: (a) Reference typed input or outputs (b) "func" valued attributes Are skipped for now. Have to work out the appropriate representation in Go for them. But a vast majority of operations do not use these, so we skip them. The "description" and "summary" fields in the OpDef are used to generate documentation comments. However, the OpDef style is to use markdown and so some of that markdown does not end well in Go documentation. Living with that for now. Another step towards #10 Change: 136412140
Diffstat (limited to 'tensorflow/go/genop')
-rw-r--r--tensorflow/go/genop/internal/genop.go399
-rw-r--r--tensorflow/go/genop/internal/genop_test.go339
2 files changed, 727 insertions, 11 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go
index be84b2322a..fdc55f5ebc 100644
--- a/tensorflow/go/genop/internal/genop.go
+++ b/tensorflow/go/genop/internal/genop.go
@@ -14,7 +14,17 @@
// Package internal generates Go source code with functions for TensorFlow operations.
//
-// The generated APIs are unstable and can change without notice.
+// The basic outline of the generated API is as follows:
+//
+// - One function for each TensorFlow operation
+// - The arguments to the function are the inputs and required attributes of the operation
+// - The function returns the outputs
+// - A function is also generated for each optional attribute of the operation.
+//
+// There is a possibility that there are name collisions between the functions
+// generated for ops and the functions generated for optional attributes. For
+// now, we ignore those, but will need to revisit if a collision is actually
+// encountered.
package internal
// #include "tensorflow/c/c_api.h"
@@ -23,6 +33,9 @@ import "C"
import (
"fmt"
"io"
+ "reflect"
+ "strings"
+ "text/template"
"unsafe"
"github.com/golang/protobuf/proto"
@@ -37,16 +50,7 @@ func GenerateFunctionsForRegisteredOps(w io.Writer) error {
if err != nil {
return err
}
- fmt.Fprintf(w, `// DO NOT EDIT
-// This file was machine generated.
-//
-// This code generation process is a work in progress and is not ready yet.
-// Eventually, the code generator will generate approximately %d wrapper
-// functions for adding TensorFlow operations to a Graph.
-
-package op
-`, len(ops.Op))
- return nil
+ return generateFunctionsForOps(w, ops)
}
func registeredOps() (*pb.OpList, error) {
@@ -62,3 +66,376 @@ func registeredOps() (*pb.OpList, error) {
)
return list, err
}
+
+func generateFunctionsForOps(w io.Writer, ops *pb.OpList) error {
+ thisPackage := reflect.TypeOf(tmplArgs{}).PkgPath()
+ if err := tmplHeader.Execute(w, thisPackage); err != nil {
+ return err
+ }
+ blacklist := map[string]bool{
+ "Const": true,
+ "PyFunc": true,
+ "PyFuncStateless": true,
+ }
+ for _, op := range ops.Op {
+ if blacklist[op.Name] {
+ continue
+ }
+ if err := generateFunctionForOp(w, op); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func generateFunctionForOp(w io.Writer, op *pb.OpDef) error {
+ if strings.HasPrefix(op.Name, "_") { // Internal operation
+ return nil
+ }
+ // Ignore operations where the Go types corresponding to the TensorFlow
+ // type haven't been worked out (such as "func"s).
+ for _, a := range op.Attr {
+ if _, err := goType(a.Type); err != nil {
+ return nil
+ }
+ }
+ // Also, haven't figured out reference types yet, so ignore those too.
+ for _, a := range op.InputArg {
+ if a.IsRef {
+ return nil
+ }
+ }
+ for _, a := range op.OutputArg {
+ if a.IsRef {
+ return nil
+ }
+ }
+ if op.Summary == "" {
+ // Undocumented operation, perhaps a sign of not being ready to
+ // export.
+ return nil
+ }
+ return tmplOp.Execute(w, newTmplArgs(op))
+}
+
+var (
+ // Go keywords that cannot be used as identifiers.
+ // From https://golang.org/ref/spec#Keywords
+ keywords = []string{
+ "break", "default", "func", "interface", "select", "case",
+ "defer", "go", "map", "struct", "chan", "else", "goto",
+ "package", "switch", "const", "fallthrough", "if", "range",
+ "type", "continue", "for", "import", "return", "var",
+ }
+
+ tmplHeader = template.Must(template.New("header").Parse(`// DO NOT EDIT
+// This file was machine generated by {{.}}
+//
+// WARNING: This generation of wrapper function for TensorFlow ops is in an
+// experimental state. The generated API can change without notice.
+
+package op
+
+import tf "github.com/tensorflow/tensorflow/tensorflow/go"
+
+// optionalAttr is an intentionally un-exported type to hide
+// details of how optional attributes to operations are implemented.
+type optionalAttr map[string]interface{}
+
+func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, int, error) {
+ size, err := op.OutputListSize(output)
+ if err != nil {
+ return nil, start, err
+ }
+ list := make([]tf.Output, size)
+ for i := 0; i < size; i++ {
+ list[i] = op.Output(start + i)
+ }
+ return list, start + size, nil
+}
+`))
+
+ tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
+ "MakeComment": makeComment,
+ "GoType": goType,
+ "CamelCase": camelCase,
+ "Identifier": identifier,
+ "IsListArg": isListArg,
+ "IsListAttr": isListAttr,
+ }).Parse(`
+{{if .OptionalAttrs -}}
+{{/* Type for specifying all optional attributes. */ -}}
+// {{.Op.Name}}Attr is an optional argument to {{.Op.Name}}.
+type {{.Op.Name}}Attr func(optionalAttr)
+
+{{range .OptionalAttrs}}
+// {{$.Op.Name}}{{CamelCase .Name}} sets the optional {{.Name}} attribute to value.
+{{- if .Description}}
+//
+// value: {{MakeComment .Description}}
+{{- end}}
+// If not specified, defaults to {{.DefaultValue}}
+{{- if .HasMinimum}}
+//
+// {{if IsListAttr .}}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
+{{- end}}
+func {{$.Op.Name}}{{CamelCase .Name}}(value {{GoType .Type}}) {{$.Op.Name}}Attr {
+ return func(m optionalAttr) {
+ m[{{printf "%q" .Name}}] = value
+ }
+}
+{{end}}
+{{end}}
+
+{{- /* Create a godoc friendly comment. */ -}}
+
+// {{MakeComment .Op.Summary}}
+
+{{- with .Op.Deprecation}}
+//
+// DEPRECATED at GraphDef version {{.Version}}: {{.Explanation}}
+{{- end -}}
+
+{{- with .Op.Description}}
+//
+// {{MakeComment .}}
+{{- end -}}
+
+{{- if .DescribeArguments}}
+//
+// Arguments:
+{{- range .Op.InputArg}}
+// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}}
+{{- end -}}
+{{- range .RequiredAttrs}}
+// {{if .Description}}{{Identifier .Name}}: {{MakeComment .Description}}{{end}}
+{{- end -}}
+{{- end -}}
+
+{{- if .DescribeOutputs}}
+//
+{{- if ((len .Op.OutputArg) eq 1) }}
+// Returns {{range .Op.OutputArg}}{{MakeComment .Description}}{{end}}
+{{- else }}
+// Returns:
+{{- range .Op.OutputArg}}
+// {{Identifier .Name}}{{if .Description}}: {{MakeComment .Description}}{{end}}
+{{- end -}}
+{{- end -}}
+{{- end -}}
+{{- /*
+
+ The function signature.
+ Since OpDef.Name is in CamelCase, it cannot conflict with a reserved keyword in Golang
+*/}}
+func {{.Op.Name}}
+
+{{- /*
+ Fill in input arguments:
+ (1) The Scope
+ (2) All input arguments (which may be either []tf.Output or tf.Output)
+ (3) All required attributes
+ (4) Variadic list of optional attributes
+*/ -}}
+
+(scope *Scope
+{{- range $i, $a := .Op.InputArg}}, {{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}
+{{range $i, $a := .RequiredAttrs}}, {{Identifier $a.Name}} {{GoType $a.Type}}{{end -}}
+{{if .OptionalAttrs}}, optional ...{{.Op.Name}}Attr{{end -}}
+)
+
+{{- /* Construct outputs: len(OpDef.OutputArg) + 1 (for error) */ -}}
+
+({{range $i,$a := .Op.OutputArg}}{{if $i}}, {{end}}{{Identifier $a.Name}} {{if IsListArg $a}}[]{{end}}tf.Output{{end -}}
+{{if .Op.OutputArg}}, {{end}}err error) {
+ {{if .HasAttrs -}}
+ attrs := map[string]interface{}{ {{- range .RequiredAttrs}}{{printf "%q" .Name}}: {{Identifier .Name}},{{end}}}
+ {{if .OptionalAttrs -}}
+ for _, a := range optional {
+ a(attrs)
+ }
+ {{end -}}
+ {{end -}}
+ opspec := tf.OpSpec{
+ Type: {{printf "%q" .Op.Name}},
+ {{if .Op.InputArg -}}
+ Input: []tf.Input{
+ {{range .Op.InputArg}}{{if IsListArg .}}tf.OutputList({{Identifier .Name}}){{else}}{{Identifier .Name}}{{end}}, {{end}}
+ },
+ {{end}}
+ {{- if .HasAttrs}}Attrs: attrs,{{end}}
+ }
+ {{if .Op.OutputArg}}op, err :={{else}}_, err ={{end}} scope.Graph().AddOperation(opspec)
+ {{- if .HasListOutput}}
+ var idx int
+ {{- range $i, $a := .Op.OutputArg}}
+ {{- if IsListArg $a}}
+ if {{Identifier .Name}}, idx, err = makeOutputList(op, idx, {{printf "%q" .Name}}); err != nil {
+ return {{range $.Op.OutputArg}}{{Identifier .Name}}, {{end}}err
+ }
+ {{- else }}
+ {{Identifier .Name}} = op.Output(idx)
+ {{- end }}
+ {{- end }}
+ return {{range .Op.OutputArg}}{{Identifier .Name}}, {{end}}err
+ {{- else }}
+ return {{range $i, $a := .Op.OutputArg}}op.Output({{$i}}), {{end}}err
+ {{- end }}
+}
+`))
+)
+
+type tmplArgs struct {
+ Op *pb.OpDef
+ // Op.Attr is split into two categories
+ // (1) Required: These must be specified by the client and are thus
+ // included in the function signature.
+ // (2) Optional: These need not be specified (as they have default
+ // values) and thus do not appear in the function signature.
+ RequiredAttrs []*pb.OpDef_AttrDef
+ OptionalAttrs []*pb.OpDef_AttrDef
+}
+
+func newTmplArgs(op *pb.OpDef) *tmplArgs {
+ ret := tmplArgs{Op: op}
+ if len(op.Attr) == 0 {
+ return &ret
+ }
+ // Attributes related to the InputArg's type are inferred automatically
+ // and are not exposed to the client.
+ inferred := make(map[string]bool)
+ for _, in := range op.InputArg {
+ switch {
+ case in.TypeAttr != "":
+ inferred[in.TypeAttr] = true
+ case in.TypeListAttr != "":
+ inferred[in.TypeListAttr] = true
+ }
+ if in.NumberAttr != "" {
+ inferred[in.NumberAttr] = true
+ }
+ }
+ for _, attr := range op.Attr {
+ if inferred[attr.Name] {
+ continue
+ }
+ if attr.DefaultValue == nil {
+ ret.RequiredAttrs = append(ret.RequiredAttrs, attr)
+ } else {
+ ret.OptionalAttrs = append(ret.OptionalAttrs, attr)
+ }
+ }
+ return &ret
+}
+
+func (a *tmplArgs) HasAttrs() bool { return len(a.RequiredAttrs)+len(a.OptionalAttrs) > 0 }
+func (a *tmplArgs) DescribeArguments() bool {
+ for _, arg := range a.Op.InputArg {
+ if arg.Description != "" {
+ return true
+ }
+ }
+ for _, attr := range a.RequiredAttrs {
+ if attr.Description != "" {
+ return true
+ }
+ }
+ return false
+
+}
+func (a *tmplArgs) DescribeOutputs() bool {
+ for _, arg := range a.Op.OutputArg {
+ if arg.Description != "" {
+ return true
+ }
+ }
+ return false
+}
+func (a *tmplArgs) HasListOutput() bool {
+ for _, arg := range a.Op.OutputArg {
+ if isListArg(arg) {
+ return true
+ }
+ }
+ return false
+}
+
+func makeComment(lines string) string {
+ return strings.Join(strings.SplitAfter(lines, "\n"), "// ")
+}
+
+// goType converts a TensorFlow "type" ('string', 'int', 'list(string)' etc.)
+// to the corresponding type in Go.
+func goType(tfType string) (string, error) {
+ list, tfType := parseTFType(tfType)
+ var gotype string
+ switch tfType {
+ case "int":
+ gotype = "int"
+ case "float":
+ gotype = "float32"
+ case "bool":
+ gotype = "bool"
+ case "type":
+ gotype = "tf.DataType"
+ case "shape":
+ gotype = "[]int64"
+ case "tensor":
+ gotype = "tf.Tensor"
+ case "string":
+ gotype = "string"
+ default:
+ return "", fmt.Errorf("%q is not a recognized DataType", tfType)
+ }
+ if list {
+ gotype = "[]" + gotype
+ }
+ return gotype, nil
+}
+
+func camelCase(snakeCase string) string {
+ words := strings.Split(snakeCase, "_")
+ for i, w := range words {
+ words[i] = strings.ToUpper(string(w[0])) + w[1:]
+ }
+ return strings.Join(words, "")
+}
+
+// identifier creates an identifier for s usable in the generated Go source
+// code.
+//
+// Avoids collisions with keywords and other identifiers used in the generated
+// code.
+func identifier(s string) string {
+ // Identifiers used in the generated code.
+ if s == "tf" || s == "scope" || s == "err" || s == "op" {
+ return s + "_"
+ }
+ for _, k := range keywords {
+ if s == k {
+ // Alternatively, make the first letter upper case.
+ return s + "_"
+ }
+ }
+ return s
+}
+
+func isListArg(argdef *pb.OpDef_ArgDef) bool {
+ return argdef.TypeListAttr != "" || argdef.NumberAttr != ""
+}
+
+func isListAttr(attrdef *pb.OpDef_AttrDef) bool {
+ list, _ := parseTFType(attrdef.Type)
+ return list
+}
+
+func parseTFType(tfType string) (list bool, typ string) {
+ const (
+ listPrefix = "list("
+ listSuffix = ")"
+ )
+ if strings.HasPrefix(tfType, listPrefix) && strings.HasSuffix(tfType, listSuffix) {
+ return true, strings.TrimSuffix(strings.TrimPrefix(tfType, listPrefix), listSuffix)
+ }
+ return false, tfType
+}
diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go
new file mode 100644
index 0000000000..dade7ce48f
--- /dev/null
+++ b/tensorflow/go/genop/internal/genop_test.go
@@ -0,0 +1,339 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package internal
+
+import (
+ "bytes"
+ "go/format"
+ "testing"
+
+ "github.com/golang/protobuf/proto"
+ pb "github.com/tensorflow/tensorflow/tensorflow/go/genop/internal/proto/tensorflow/core/framework"
+)
+
+func TestGenerateOp(t *testing.T) {
+ // TestGenerateOp validates the generated source code for an op.
+ // The OpDef for the test cases are simplified forms of real ops.
+ testdata := []struct {
+ tag string
+ opdef string
+ wanted string
+ }{
+ {
+ tag: "NoOp",
+ opdef: `
+name: "NoOp"
+summary: "No. Op."
+`,
+ wanted: `
+// No. Op.
+func NoOp(scope *Scope) (err error) {
+ opspec := tf.OpSpec{
+ Type: "NoOp",
+ }
+ _, err = scope.Graph().AddOperation(opspec)
+ return err
+}
+`,
+ },
+ {
+ tag: "NoAttributes",
+ opdef: `
+name: "Add"
+input_arg: <
+ name: "x"
+ type_attr: "T"
+>
+input_arg: <
+ name: "y"
+ type_attr: "T"
+>
+output_arg: <
+ name: "z"
+ type_attr: "T"
+>
+attr: <
+ name: "T"
+ type: "type"
+ allowed_values: <
+ list: <
+ type: DT_FLOAT
+ type: DT_INT64
+ >
+ >
+>
+summary: "Returns x + y element-wise."
+description: "Blah blah",
+`,
+ wanted: `
+// Returns x + y element-wise.
+//
+// Blah blah
+func Add(scope *Scope, x tf.Output, y tf.Output) (z tf.Output, err error) {
+ opspec := tf.OpSpec{
+ Type: "Add",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op, err := scope.Graph().AddOperation(opspec)
+ return op.Output(0), err
+}
+`,
+ },
+ {
+ tag: "RequiredAttributes",
+ opdef: `
+name: "Cast"
+input_arg: <
+ name: "x"
+ type_attr: "SrcT"
+>
+output_arg: <
+ name: "y"
+ type_attr: "DstT"
+>
+attr: <
+ name: "SrcT"
+ type: "type"
+>
+attr: <
+ name: "DstT"
+ type: "type"
+>
+summary: "Cast x of type SrcT to y of DstT."
+`,
+ wanted: `
+// Cast x of type SrcT to y of DstT.
+func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output, err error) {
+ attrs := map[string]interface{}{"DstT": DstT}
+ opspec := tf.OpSpec{
+ Type: "Cast",
+ Input: []tf.Input{
+ x,
+ },
+ Attrs: attrs,
+ }
+ op, err := scope.Graph().AddOperation(opspec)
+ return op.Output(0), err
+}
+`,
+ },
+ {
+ tag: "OptionalAttributes",
+ opdef: `
+name: "DecodeJpeg"
+input_arg: <
+ name: "contents"
+ description: "0-D. The JPEG-encoded image."
+ type: DT_STRING
+>
+output_arg: <
+ name: "image"
+ description: "3-D with shape [height, width, channels]"
+ type: DT_UINT8
+>
+attr: <
+ name: "channels"
+ type: "int"
+ default_value: <
+ i: 0
+ >
+ description: "Number of color channels for the decoded image."
+>
+attr: <
+ name: "fancy_upscaling"
+ type: "bool"
+ default_value: <
+ b: true
+ >
+ description: "If true use a slower but nicer upscaling of the\nchroma planes (yuv420/422 only)."
+>
+attr: <
+ name: "acceptable_fraction"
+ type: "float"
+ default_value: <
+ f: 1
+ >
+ description: "The minimum required fraction of lines before a truncated\ninput is accepted."
+>
+summary: "Decode a JPEG-encoded image to a uint8 tensor."
+description: "Norna dorna fjord\nkajorna\nhahaha"
+`,
+ wanted: `
+// DecodeJpegAttr is an optional argument to DecodeJpeg.
+type DecodeJpegAttr func(optionalAttr)
+
+// DecodeJpegChannels sets the optional channels attribute to value.
+//
+// value: Number of color channels for the decoded image.
+// If not specified, defaults to i:0
+func DecodeJpegChannels(value int) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["channels"] = value
+ }
+}
+
+// DecodeJpegFancyUpscaling sets the optional fancy_upscaling attribute to value.
+//
+// value: If true use a slower but nicer upscaling of the
+// chroma planes (yuv420/422 only).
+// If not specified, defaults to b:true
+func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["fancy_upscaling"] = value
+ }
+}
+
+// DecodeJpegAcceptableFraction sets the optional acceptable_fraction attribute to value.
+//
+// value: The minimum required fraction of lines before a truncated
+// input is accepted.
+// If not specified, defaults to f:1
+func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr {
+ return func(m optionalAttr) {
+ m["acceptable_fraction"] = value
+ }
+}
+
+// Decode a JPEG-encoded image to a uint8 tensor.
+//
+// Norna dorna fjord
+// kajorna
+// hahaha
+//
+// Arguments:
+// contents: 0-D. The JPEG-encoded image.
+//
+// Returns 3-D with shape [height, width, channels]
+func DecodeJpeg(scope *Scope, contents tf.Output, optional ...DecodeJpegAttr) (image tf.Output, err error) {
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeJpeg",
+ Input: []tf.Input{
+ contents,
+ },
+ Attrs: attrs,
+ }
+ op, err := scope.Graph().AddOperation(opspec)
+ return op.Output(0), err
+}
+`,
+ },
+ {
+ tag: "ListOutput",
+ opdef: `
+name: "ShapeN"
+input_arg: <
+ name: "input"
+ type_attr: "T"
+ number_attr: "N"
+>
+output_arg: <
+ name: "output"
+ type_attr: "out_type"
+ number_attr: "N"
+>
+attr: <
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+>
+attr: <
+ name: "T"
+ type: "type"
+>
+attr: <
+ name: "out_type"
+ type: "type"
+ default_value: <
+ type: DT_INT32
+ >
+ allowed_values: <
+ list: <
+ type: DT_INT32
+ type: DT_INT64
+ >
+ >
+>
+summary: "Returns shape of tensors."
+description: "Some description here."
+`,
+ wanted: `
+// ShapeNAttr is an optional argument to ShapeN.
+type ShapeNAttr func(optionalAttr)
+
+// ShapeNOutType sets the optional out_type attribute to value.
+// If not specified, defaults to type:DT_INT32
+func ShapeNOutType(value tf.DataType) ShapeNAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Returns shape of tensors.
+//
+// Some description here.
+func ShapeN(scope *Scope, input []tf.Output, optional ...ShapeNAttr) (output []tf.Output, err error) {
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ShapeN",
+ Input: []tf.Input{
+ tf.OutputList(input),
+ },
+ Attrs: attrs,
+ }
+ op, err := scope.Graph().AddOperation(opspec)
+ var idx int
+ if output, idx, err = makeOutputList(op, idx, "output"); err != nil {
+ return output, err
+ }
+ return output, err
+}
+`,
+ },
+ }
+
+ for _, test := range testdata {
+ t.Run(test.tag, func(t *testing.T) {
+ var opdef pb.OpDef
+ var buf bytes.Buffer
+ if err := proto.UnmarshalText(test.opdef, &opdef); err != nil {
+ t.Fatal(err)
+ }
+ if err := generateFunctionForOp(&buf, &opdef); err != nil {
+ t.Fatal(err)
+ }
+ got, err := format.Source(buf.Bytes())
+ if err != nil {
+ t.Fatal(err)
+ }
+ want, err := format.Source([]byte(test.wanted))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(got, want) {
+ t.Fatalf("Got:\n%s\nWant:\n%s\n", got, want)
+ }
+ })
+ }
+}