diff options
-rw-r--r-- | tensorflow/go/op/wrappers.go | 46 |
1 files changed, 40 insertions, 6 deletions
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 9e3a4666b9..8dd2931703 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -20544,23 +20544,40 @@ func Requantize(scope *Scope, input tf.Output, input_min tf.Output, input_max tf return op.Output(0), op.Output(1), op.Output(2) } +// ArgMinAttr is an optional argument to ArgMin. +type ArgMinAttr func(optionalAttr) + +// ArgMinOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMinOutputType(value tf.DataType) ArgMinAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + // Returns the index with the smallest value across dimensions of a tensor. // // Note that in case of ties the identity of the return value is not guaranteed. // // Arguments: // -// dimension: int32, 0 <= dimension < rank(input). Describes which dimension -// of the input Tensor to reduce across. For vectors, use dimension = 0. -func ArgMin(scope *Scope, input tf.Output, dimension tf.Output) (output tf.Output) { +// dimension: int32 or int64, 0 <= dimension < rank(input). Describes +// which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. +func ArgMin(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMinAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ Type: "ArgMin", Input: []tf.Input{ input, dimension, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) @@ -21818,23 +21835,40 @@ func IsFinite(scope *Scope, x tf.Output) (y tf.Output) { return op.Output(0) } +// ArgMaxAttr is an optional argument to ArgMax. +type ArgMaxAttr func(optionalAttr) + +// ArgMaxOutputType sets the optional output_type attribute to value. +// If not specified, defaults to DT_INT64 +func ArgMaxOutputType(value tf.DataType) ArgMaxAttr { + return func(m optionalAttr) { + m["output_type"] = value + } +} + // Returns the index with the largest value across dimensions of a tensor. // // Note that in case of ties the identity of the return value is not guaranteed. // // Arguments: // -// dimension: int32, 0 <= dimension < rank(input). Describes which dimension -// of the input Tensor to reduce across. For vectors, use dimension = 0. -func ArgMax(scope *Scope, input tf.Output, dimension tf.Output) (output tf.Output) { +// dimension: int32 or int64, 0 <= dimension < rank(input). Describes +// which dimension of the input Tensor to reduce across. For vectors, +// use dimension = 0. +func ArgMax(scope *Scope, input tf.Output, dimension tf.Output, optional ...ArgMaxAttr) (output tf.Output) { if scope.Err() != nil { return } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ Type: "ArgMax", Input: []tf.Input{ input, dimension, }, + Attrs: attrs, } op := scope.AddOperation(opspec) return op.Output(0) |