diff options
Diffstat (limited to 'tensorflow/go')
-rw-r--r-- | tensorflow/go/op/wrappers.go | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 18d7425323..6c9bf1e714 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -5225,12 +5225,26 @@ func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Out return op.Output(0) } +// CastAttr is an optional argument to Cast. +type CastAttr func(optionalAttr) + +// CastTruncate sets the optional Truncate attribute to value. +// If not specified, defaults to false +func CastTruncate(value bool) CastAttr { + return func(m optionalAttr) { + m["Truncate"] = value + } +} + // Cast x of type SrcT to y of DstT. -func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) { +func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) { if scope.Err() != nil { return } attrs := map[string]interface{}{"DstT": DstT} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ Type: "Cast", Input: []tf.Input{ |