aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/genop
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-03-07 17:49:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-07 18:08:36 -0800
commitfbcb58b422f9bf2c2a23ac33d21b87daf5f31d49 (patch)
tree4fb0f3c33a78c6e3110161d3a3b97d99a250c504 /tensorflow/go/genop
parent8633b22f83d979d17ba30ba78a14fb72b109a86d (diff)
Go: Fix #7175
Change: 149490754
Diffstat (limited to 'tensorflow/go/genop')
-rw-r--r--tensorflow/go/genop/internal/genop.go20
-rw-r--r--tensorflow/go/genop/internal/genop_test.go8
2 files changed, 23 insertions, 5 deletions
diff --git a/tensorflow/go/genop/internal/genop.go b/tensorflow/go/genop/internal/genop.go
index 16e4d0e512..d17c1ca41d 100644
--- a/tensorflow/go/genop/internal/genop.go
+++ b/tensorflow/go/genop/internal/genop.go
@@ -162,6 +162,7 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in
"Identifier": identifier,
"IsListArg": isListArg,
"IsListAttr": isListAttr,
+ "StripLeadingColon": stripLeadingColon,
}).Parse(`
{{if .OptionalAttrs -}}
{{/* Type for specifying all optional attributes. */ -}}
@@ -174,7 +175,7 @@ type {{.Op.Name}}Attr func(optionalAttr)
//
// value: {{MakeComment .Description}}
{{- end}}
-// If not specified, defaults to {{.DefaultValue}}
+// If not specified, defaults to {{StripLeadingColon .DefaultValue}}
{{- if .HasMinimum}}
//
// {{if IsListAttr .}}REQUIRES: len(value) >= {{.Minimum}}{{else}}REQUIRES: value >= {{.Minimum}}{{end}}
@@ -452,6 +453,23 @@ func isListAttr(attrdef *pb.OpDef_AttrDef) bool {
return list
}
+// stripLeadingColon removes the prefix of the string up to the first colon.
+//
+// This is useful when 's' corresponds to a "oneof" protocol buffer message.
+// For example, consider the protocol buffer message:
+// oneof value { bool b = 1; int64 i = 2; }
+// String() on a Go corresponding object (using proto.CompactTextString) will
+// print "b:true", or "i:7" etc. This function strips out the leading "b:" or
+// "i:".
+func stripLeadingColon(s fmt.Stringer) string {
+ x := s.String()
+ y := strings.SplitN(x, ":", 2)
+ if len(y) < 2 {
+ return x
+ }
+ return y[1]
+}
+
func parseTFType(tfType string) (list bool, typ string) {
const (
listPrefix = "list("
diff --git a/tensorflow/go/genop/internal/genop_test.go b/tensorflow/go/genop/internal/genop_test.go
index c66e38fce0..00ac4827e4 100644
--- a/tensorflow/go/genop/internal/genop_test.go
+++ b/tensorflow/go/genop/internal/genop_test.go
@@ -189,7 +189,7 @@ type DecodeJpegAttr func(optionalAttr)
// DecodeJpegChannels sets the optional channels attribute to value.
//
// value: Number of color channels for the decoded image.
-// If not specified, defaults to i:0
+// If not specified, defaults to 0
func DecodeJpegChannels(value int64) DecodeJpegAttr {
return func(m optionalAttr) {
m["channels"] = value
@@ -200,7 +200,7 @@ func DecodeJpegChannels(value int64) DecodeJpegAttr {
//
// value: If true use a slower but nicer upscaling of the
// chroma planes (yuv420/422 only).
-// If not specified, defaults to b:true
+// If not specified, defaults to true
func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr {
return func(m optionalAttr) {
m["fancy_upscaling"] = value
@@ -211,7 +211,7 @@ func DecodeJpegFancyUpscaling(value bool) DecodeJpegAttr {
//
// value: The minimum required fraction of lines before a truncated
// input is accepted.
-// If not specified, defaults to f:1
+// If not specified, defaults to 1
func DecodeJpegAcceptableFraction(value float32) DecodeJpegAttr {
return func(m optionalAttr) {
m["acceptable_fraction"] = value
@@ -332,7 +332,7 @@ description: "Some description here."
type ShapeNAttr func(optionalAttr)
// ShapeNOutType sets the optional out_type attribute to value.
-// If not specified, defaults to type:DT_INT32
+// If not specified, defaults to DT_INT32
func ShapeNOutType(value tf.DataType) ShapeNAttr {
return func(m optionalAttr) {
m["out_type"] = value