diff options
Diffstat (limited to 'tensorflow/go/op')
-rw-r--r-- | tensorflow/go/op/scope.go      |   5 +
-rw-r--r-- | tensorflow/go/op/scope_test.go |  15 +
-rw-r--r-- | tensorflow/go/op/wrappers.go   | 115 +-
3 files changed, 63 insertions(+), 72 deletions(-)
diff --git a/tensorflow/go/op/scope.go b/tensorflow/go/op/scope.go index d87833f451..a9ec79463a 100644 --- a/tensorflow/go/op/scope.go +++ b/tensorflow/go/op/scope.go @@ -49,6 +49,11 @@ func NewScope() *Scope { return &Scope{graph: tf.NewGraph(), namemap: make(map[string]int), err: new(scopeErr)} } +// NewScopeWithGraph creates a Scope initialized with the Graph thats passed in +func NewScopeWithGraph(g *tf.Graph) *Scope { + return &Scope{graph: g, namemap: make(map[string]int), err: new(scopeErr)} +} + // Finalize returns the Graph on which this scope operates on and renders s // unusable. If there was an error during graph construction, that error is // returned instead. diff --git a/tensorflow/go/op/scope_test.go b/tensorflow/go/op/scope_test.go index b74fd24b26..6fb5d32e50 100644 --- a/tensorflow/go/op/scope_test.go +++ b/tensorflow/go/op/scope_test.go @@ -95,6 +95,21 @@ func TestMultipleGeneratedOps(t *testing.T) { } } +func TestScopeWithGraph(t *testing.T) { + s1 := NewScope() + Const(s1, "hello") + graph, err := s1.Finalize() + if err != nil { + t.Fatal(err) + } + + s2 := NewScopeWithGraph(graph) + Const(s2.SubScope("addition"), "world") + if err := s2.Err(); err != nil { + t.Fatal(err) + } +} + func Example() { // This example creates a Graph that multiplies a constant matrix with // a matrix to be provided during graph execution (via diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index c414255f93..9f048d3ea0 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -1337,6 +1337,47 @@ func PlaceholderV2(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.O return op.Output(0) } +// PlaceholderAttr is an optional argument to Placeholder. +type PlaceholderAttr func(optionalAttr) + +// PlaceholderShape sets the optional shape attribute to value. +// +// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the +// shape is unconstrained. 
+// If not specified, defaults to <unknown_rank:true > +func PlaceholderShape(value tf.Shape) PlaceholderAttr { + return func(m optionalAttr) { + m["shape"] = value + } +} + +// A placeholder op for a value that will be fed into the computation. +// +// N.B. This operation will fail with an error if it is executed. It is +// intended as a way to represent a value that will always be fed, and to +// provide attrs that enable the fed value to be checked at runtime. +// +// Arguments: +// dtype: The type of elements in the tensor. +// +// Returns A placeholder tensor that must be replaced using the feed mechanism. +func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"dtype": dtype} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Placeholder", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Pads a tensor with mirrored values. // // This operation pads a `input` with mirrored values according to the `paddings` @@ -4153,7 +4194,7 @@ func UnstageSharedName(value string) UnstageAttr { // Op is similar to a lightweight Dequeue. // -// The basic funtionality is similar to dequeue with many fewer +// The basic functionality is similar to dequeue with many fewer // capabilities and options. This Op is optimized for performance. func Unstage(scope *Scope, dtypes []tf.DataType, optional ...UnstageAttr) (values []tf.Output) { if scope.Err() != nil { @@ -4724,7 +4765,7 @@ type QueueCloseV2Attr func(optionalAttr) // QueueCloseV2CancelPendingEnqueues sets the optional cancel_pending_enqueues attribute to value. // // value: If true, all pending enqueue requests that are -// blocked on the given queue will be cancelled. +// blocked on the given queue will be canceled. 
// If not specified, defaults to false func QueueCloseV2CancelPendingEnqueues(value bool) QueueCloseV2Attr { return func(m optionalAttr) { @@ -4895,76 +4936,6 @@ func FixedLengthRecordDataset(scope *Scope, filenames tf.Output, header_bytes tf return op.Output(0) } -// PlaceholderAttr is an optional argument to Placeholder. -type PlaceholderAttr func(optionalAttr) - -// PlaceholderShape sets the optional shape attribute to value. -// -// value: (Optional) The shape of the tensor. If the shape has 0 dimensions, the -// shape is unconstrained. -// If not specified, defaults to <unknown_rank:true > -func PlaceholderShape(value tf.Shape) PlaceholderAttr { - return func(m optionalAttr) { - m["shape"] = value - } -} - -// A placeholder op for a value that will be fed into the computation. -// -// N.B. This operation will fail with an error if it is executed. It is -// intended as a way to represent a value that will always be fed, and to -// provide attrs that enable the fed value to be checked at runtime. -// -// Arguments: -// dtype: The type of elements in the tensor. -// -// Returns A placeholder tensor that must be replaced using the feed mechanism. -func Placeholder(scope *Scope, dtype tf.DataType, optional ...PlaceholderAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"dtype": dtype} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Placeholder", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Creates a dataset that caches elements from `input_dataset`. -// -// A CacheDataset will iterate over the input_dataset, and store tensors. If the -// cache already exists, the cache will be used. If the cache is inappropriate -// (e.g. cannot be opened, contains tensors of the wrong shape / size), an error -// will the returned when used. -// -// Arguments: -// -// filename: A path on the filesystem where we should cache the dataset. 
Note: this -// will be a directory. -// -// -func CacheDataset(scope *Scope, input_dataset tf.Output, filename tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "CacheDataset", - Input: []tf.Input{ - input_dataset, filename, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Deprecated. Use TensorArrayGradV3 func TensorArrayGradV2(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output) { if scope.Err() != nil { |