aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/session.go
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2016-08-23 09:01:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 10:04:53 -0700
commit783c52edeb3c676937dbb97ed0d40958015050d6 (patch)
tree80c74954f68dad26a6e76a1c0edcb979d4d1804c /tensorflow/go/session.go
parent096069687c52e16eaa18c1db6e7bbf2737639257 (diff)
Initial version of the Go API. The API is subject to change.
Remaining work to do: - Generated ops. - Generated protocol buffers. - A few calls requiring protocol buffers aren't in this change. Change: 131066649
Diffstat (limited to 'tensorflow/go/session.go')
-rw-r--r--tensorflow/go/session.go187
1 files changed, 187 insertions, 0 deletions
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
new file mode 100644
index 0000000000..98a87602d1
--- /dev/null
+++ b/tensorflow/go/session.go
@@ -0,0 +1,187 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include <stdlib.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+import (
+ "errors"
+ "runtime"
+ "sync"
+ "unsafe"
+)
+
+// Session drives a TensorFlow graph computation.
+//
+// When a Session is created with a given target, a new Session object is bound
+// to the universe of resources specified by that target. Those resources are
+// available to this session to perform computation described in the GraphDef.
+// After creating the session with a graph, the caller uses the Run() API to
+// perform the computation and potentially fetch outputs as Tensors.
+// A Session allows concurrent calls to Run().
+type Session struct {
+ c *C.TF_SessionWithGraph
+
+ // For ensuring that:
+ // - Close() blocks on all Run() calls to complete.
+ // - Close() can be called multiple times.
+ wg sync.WaitGroup
+ mu sync.Mutex
+}
+
+// NewSession creates a new execution session with the associated graph.
+// options may be nil to use the default options.
+func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
+ status := newStatus()
+ cOpt := options.c()
+ cSess := C.TF_NewSessionWithGraph(graph.c, cOpt, status.c)
+ C.TF_DeleteSessionOptions(cOpt)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+
+ s := &Session{c: cSess}
+ runtime.SetFinalizer(s, func(s *Session) { s.Close() })
+ return s, nil
+}
+
+// Run the graph with the associated session starting with the supplied inputs.
+// inputs and outputs may be set to nil. Runs, but does not return Tensors
+// for operations specified in targets.
+//
+// On success, returns the Tensor outputs in the same order as supplied in
+// the outputs argument. If outputs is set to nil, the returned Tensor outputs
+// is empty.
+func (s *Session) Run(inputs map[Port]*Tensor, outputs []Port, targets []*Operation) ([]*Tensor, error) {
+ s.mu.Lock()
+ if s.c == nil {
+ s.mu.Unlock()
+ return nil, errors.New("session is closed")
+ }
+ s.wg.Add(1)
+ s.mu.Unlock()
+ defer s.wg.Done()
+
+ var inputPorts []C.TF_Port
+ var inputValues []*C.TF_Tensor
+ if inputs != nil {
+ for port, tensor := range inputs {
+ inputPorts = append(inputPorts, port.c())
+ inputValues = append(inputValues, tensor.c())
+ }
+ }
+
+ var outputPorts []C.TF_Port
+ for _, port := range outputs {
+ outputPorts = append(outputPorts, port.c())
+ }
+ outputValues := make([]*C.TF_Tensor, len(outputs))
+ var cTargets []*C.TF_Operation
+ for _, target := range targets {
+ cTargets = append(cTargets, target.c)
+ }
+
+ status := newStatus()
+ var inputPortsPtr *C.TF_Port
+ var inputValuesPtr **C.TF_Tensor
+ if len(inputPorts) > 0 {
+ inputPortsPtr = &inputPorts[0]
+ inputValuesPtr = &inputValues[0]
+ }
+
+ var outputPortsPtr *C.TF_Port
+ var outputValuesPtr **C.TF_Tensor
+ if len(outputPorts) > 0 {
+ outputPortsPtr = &outputPorts[0]
+ outputValuesPtr = &outputValues[0]
+ }
+
+ var cTargetsPtr **C.TF_Operation
+ if len(cTargets) > 0 {
+ cTargetsPtr = &cTargets[0]
+ }
+
+ C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+
+ var tensors []*Tensor
+ for _, val := range outputValues {
+ tensors = append(tensors, newTensorFromC(val))
+ C.TF_DeleteTensor(val)
+ }
+
+ return tensors, nil
+}
+
+// Close a session. This contacts any other processes associated with this
+// session, if applicable. Blocks until all previous calls to Run have returned.
+func (s *Session) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.wg.Wait()
+ if s.c == nil {
+ return nil
+ }
+ status := newStatus()
+ C.TF_CloseSessionWithGraph(s.c, status.c)
+ if err := status.Err(); err != nil {
+ return err
+ }
+ C.TF_DeleteSessionWithGraph(s.c, status.c)
+ s.c = nil
+ return status.Err()
+}
+
+// SessionOptions contains configuration information for a session.
+type SessionOptions struct {
+ // Target indicates the TensorFlow runtime to connect to.
+ //
+ // If 'target' is empty or unspecified, the local TensorFlow runtime
+ // implementation will be used. Otherwise, the TensorFlow engine
+ // defined by 'target' will be used to perform all computations.
+ //
+ // "target" can be either a single entry or a comma separated list
+ // of entries. Each entry is a resolvable address of one of the
+ // following formats:
+ // local
+ // ip:port
+ // host:port
+ // ... other system-specific formats to identify tasks and jobs ...
+ //
+ // NOTE: at the moment 'local' maps to an in-process service-based
+ // runtime.
+ //
+ // Upon creation, a single session affines itself to one of the
+ // remote processes, with possible load balancing choices when the
+ // "target" resolves to a list of possible processes.
+ //
+ // If the session disconnects from the remote process during its
+ // lifetime, session calls may fail immediately.
+ Target string
+}
+
+// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must
+// deallocate by calling C.TF_DeleteSessionOptions().
+func (o *SessionOptions) c() *C.TF_SessionOptions {
+ opt := C.TF_NewSessionOptions()
+ t := C.CString(o.Target)
+ C.TF_SetTarget(opt, t)
+ C.free(unsafe.Pointer(t))
+ return opt
+}