aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/go')
-rw-r--r--tensorflow/go/op/wrappers.go16
1 files changed, 15 insertions, 1 deletions
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 18d7425323..6c9bf1e714 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -5225,12 +5225,26 @@ func IsBoostedTreesEnsembleInitialized(scope *Scope, tree_ensemble_handle tf.Out
return op.Output(0)
}
+// CastAttr is an optional argument to Cast.
+type CastAttr func(optionalAttr)
+
+// CastTruncate sets the optional Truncate attribute to value.
+// If not specified, defaults to false
+func CastTruncate(value bool) CastAttr {
+ return func(m optionalAttr) {
+ m["Truncate"] = value
+ }
+}
+
// Cast x of type SrcT to y of DstT.
-func Cast(scope *Scope, x tf.Output, DstT tf.DataType) (y tf.Output) {
+func Cast(scope *Scope, x tf.Output, DstT tf.DataType, optional ...CastAttr) (y tf.Output) {
if scope.Err() != nil {
return
}
attrs := map[string]interface{}{"DstT": DstT}
+ for _, a := range optional {
+ a(attrs)
+ }
opspec := tf.OpSpec{
Type: "Cast",
Input: []tf.Input{