aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2016-08-23 09:01:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-23 10:04:53 -0700
commit783c52edeb3c676937dbb97ed0d40958015050d6 (patch)
tree80c74954f68dad26a6e76a1c0edcb979d4d1804c
parent096069687c52e16eaa18c1db6e7bbf2737639257 (diff)
Initial version of the Go API. The API is subject to change.
Remaining work to do: - Generated ops. - Generated protocol buffers. - A few calls requiring protocol buffers aren't in this change. Change: 131066649
-rw-r--r--tensorflow/BUILD3
-rw-r--r--tensorflow/core/framework/types.proto2
-rw-r--r--tensorflow/go/BUILD22
-rw-r--r--tensorflow/go/doc.go18
-rw-r--r--tensorflow/go/graph.go38
-rw-r--r--tensorflow/go/lib.go19
-rw-r--r--tensorflow/go/operation.go82
-rw-r--r--tensorflow/go/session.go187
-rw-r--r--tensorflow/go/session_test.go114
-rw-r--r--tensorflow/go/status.go65
-rw-r--r--tensorflow/go/tensor.go259
-rw-r--r--tensorflow/go/tensor_test.go97
12 files changed, 906 insertions, 0 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 2ab4f5a3a9..5fcdb7aa56 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -130,6 +130,7 @@ filegroup(
"//tensorflow/examples/tutorials/word2vec:all_files",
"//tensorflow/g3doc/how_tos/adding_an_op:all_files",
"//tensorflow/g3doc/tutorials:all_files",
+ "//tensorflow/go:all_files",
"//tensorflow/models/embedding:all_files",
"//tensorflow/models/image/alexnet:all_files",
"//tensorflow/models/image/cifar10:all_files",
@@ -167,6 +168,7 @@ cc_binary(
name = "libtensorflow.so",
linkshared = 1,
deps = [
+ "//tensorflow/c:c_api",
"//tensorflow/core:tensorflow",
],
)
@@ -175,6 +177,7 @@ cc_binary(
name = "libtensorflow_cc.so",
linkshared = 1,
deps = [
+ "//tensorflow/c:c_api",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:tensorflow",
],
diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto
index 051361bbed..c744594a49 100644
--- a/tensorflow/core/framework/types.proto
+++ b/tensorflow/core/framework/types.proto
@@ -6,6 +6,7 @@ option java_outer_classname = "TypesProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
+// LINT.IfChange
enum DataType {
// Not a legal value for DataType. Used to indicate a DataType field
// has not been set.
@@ -58,3 +59,4 @@ enum DataType {
DT_COMPLEX128_REF = 118;
DT_HALF_REF = 119;
}
+// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go)
diff --git a/tensorflow/go/BUILD b/tensorflow/go/BUILD
new file mode 100644
index 0000000000..d69233f4fe
--- /dev/null
+++ b/tensorflow/go/BUILD
@@ -0,0 +1,22 @@
+# Description:
+# Go API for TensorFlow.
+
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/go/doc.go b/tensorflow/go/doc.go
new file mode 100644
index 0000000000..4494d49556
--- /dev/null
+++ b/tensorflow/go/doc.go
@@ -0,0 +1,18 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tensorflow is a Go binding to TensorFlow.
+//
+// The API is subject to change and may break at any time.
+package tensorflow
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
new file mode 100644
index 0000000000..3e43a39817
--- /dev/null
+++ b/tensorflow/go/graph.go
@@ -0,0 +1,38 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+import (
+ "runtime"
+)
+
+// Graph represents a computation graph. Graphs may be shared between sessions.
+type Graph struct {
+ c *C.TF_Graph
+}
+
+// NewGraph returns a new Graph.
+func NewGraph() *Graph {
+ g := &Graph{C.TF_NewGraph()}
+ runtime.SetFinalizer(g, (*Graph).finalizer)
+ return g
+}
+
+func (g *Graph) finalizer() {
+ C.TF_DeleteGraph(g.c)
+}
diff --git a/tensorflow/go/lib.go b/tensorflow/go/lib.go
new file mode 100644
index 0000000000..dcab7a90f8
--- /dev/null
+++ b/tensorflow/go/lib.go
@@ -0,0 +1,19 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #cgo LDFLAGS: -ltensorflow
+// #cgo CFLAGS: -I${SRCDIR}/../../
+import "C"
diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go
new file mode 100644
index 0000000000..19eb1a5816
--- /dev/null
+++ b/tensorflow/go/operation.go
@@ -0,0 +1,82 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include <stdlib.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+import "unsafe"
+
+// Operation that has been added to the graph.
+type Operation struct {
+ c *C.TF_Operation
+}
+
+// Port represents a specific input or output of an operation, e.g. to specify
+// the specific output to pass as an input to a new op.
+//
+// Note the difference in naming convention: Port corresponds to Tensor/Output
+// in the Python API.
+type Port struct {
+ Op *Operation
+ Index int
+}
+
+func (p *Port) c() C.TF_Port {
+ return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)}
+}
+
+// opBuilder is for use by the generated op code to create new Operations.
+// Build() must be called for any in-progress Operation, or else we leak.
+type opBuilder struct {
+ c *C.TF_OperationDescription
+}
+
+func newOpBuilder(g *Graph, typ string, name string) *opBuilder {
+ opType := C.CString(typ)
+ opName := C.CString(name)
+ b := &opBuilder{c: C.TF_NewOperation(g.c, opType, opName)}
+ C.free(unsafe.Pointer(opType))
+ C.free(unsafe.Pointer(opName))
+ return b
+}
+
+func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error {
+ status := newStatus()
+ attrName := C.CString(name)
+ C.TF_SetAttrTensor(b.c, attrName, t.c(), status.c)
+ C.free(unsafe.Pointer(attrName))
+ return status.Err()
+}
+
+func (b *opBuilder) SetAttrType(name string, typ DataType) {
+ attrName := C.CString(name)
+ C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ))
+ C.free(unsafe.Pointer(attrName))
+}
+
+func (b *opBuilder) AddInput(port Port) {
+ C.TF_AddInput(b.c, port.c())
+}
+
+func (b *opBuilder) Build() (*Operation, error) {
+ status := newStatus()
+ op := &Operation{c: C.TF_FinishOperation(b.c, status.c)}
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+ b.c = nil
+ return op, nil
+}
diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go
new file mode 100644
index 0000000000..98a87602d1
--- /dev/null
+++ b/tensorflow/go/session.go
@@ -0,0 +1,187 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include <stdlib.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+import (
+ "errors"
+ "runtime"
+ "sync"
+ "unsafe"
+)
+
+// Session drives a TensorFlow graph computation.
+//
+// When a Session is created with a given target, a new Session object is bound
+// to the universe of resources specified by that target. Those resources are
+// available to this session to perform computation described in the GraphDef.
+// After creating the session with a graph, the caller uses the Run() API to
+// perform the computation and potentially fetch outputs as Tensors.
+// A Session allows concurrent calls to Run().
+type Session struct {
+ c *C.TF_SessionWithGraph
+
+ // For ensuring that:
+ // - Close() blocks on all Run() calls to complete.
+ // - Close() can be called multiple times.
+ wg sync.WaitGroup
+ mu sync.Mutex
+}
+
+// NewSession creates a new execution session with the associated graph.
+// options may be nil to use the default options.
+func NewSession(graph *Graph, options *SessionOptions) (*Session, error) {
+ status := newStatus()
+ cOpt := options.c()
+ cSess := C.TF_NewSessionWithGraph(graph.c, cOpt, status.c)
+ C.TF_DeleteSessionOptions(cOpt)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+
+ s := &Session{c: cSess}
+ runtime.SetFinalizer(s, func(s *Session) { s.Close() })
+ return s, nil
+}
+
+// Run the graph with the associated session starting with the supplied inputs.
+// inputs and outputs may be set to nil. Runs, but does not return Tensors
+// for operations specified in targets.
+//
+// On success, returns the Tensor outputs in the same order as supplied in
+// the outputs argument. If outputs is set to nil, the returned Tensor outputs
+// is empty.
+func (s *Session) Run(inputs map[Port]*Tensor, outputs []Port, targets []*Operation) ([]*Tensor, error) {
+ s.mu.Lock()
+ if s.c == nil {
+ s.mu.Unlock()
+ return nil, errors.New("session is closed")
+ }
+ s.wg.Add(1)
+ s.mu.Unlock()
+ defer s.wg.Done()
+
+ var inputPorts []C.TF_Port
+ var inputValues []*C.TF_Tensor
+ if inputs != nil {
+ for port, tensor := range inputs {
+ inputPorts = append(inputPorts, port.c())
+ inputValues = append(inputValues, tensor.c())
+ }
+ }
+
+ var outputPorts []C.TF_Port
+ for _, port := range outputs {
+ outputPorts = append(outputPorts, port.c())
+ }
+ outputValues := make([]*C.TF_Tensor, len(outputs))
+ var cTargets []*C.TF_Operation
+ for _, target := range targets {
+ cTargets = append(cTargets, target.c)
+ }
+
+ status := newStatus()
+ var inputPortsPtr *C.TF_Port
+ var inputValuesPtr **C.TF_Tensor
+ if len(inputPorts) > 0 {
+ inputPortsPtr = &inputPorts[0]
+ inputValuesPtr = &inputValues[0]
+ }
+
+ var outputPortsPtr *C.TF_Port
+ var outputValuesPtr **C.TF_Tensor
+ if len(outputPorts) > 0 {
+ outputPortsPtr = &outputPorts[0]
+ outputValuesPtr = &outputValues[0]
+ }
+
+ var cTargetsPtr **C.TF_Operation
+ if len(cTargets) > 0 {
+ cTargetsPtr = &cTargets[0]
+ }
+
+ C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c)
+ if err := status.Err(); err != nil {
+ return nil, err
+ }
+
+ var tensors []*Tensor
+ for _, val := range outputValues {
+ tensors = append(tensors, newTensorFromC(val))
+ C.TF_DeleteTensor(val)
+ }
+
+ return tensors, nil
+}
+
+// Close a session. This contacts any other processes associated with this
+// session, if applicable. Blocks until all previous calls to Run have returned.
+func (s *Session) Close() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.wg.Wait()
+ if s.c == nil {
+ return nil
+ }
+ status := newStatus()
+ C.TF_CloseSessionWithGraph(s.c, status.c)
+ if err := status.Err(); err != nil {
+ return err
+ }
+ C.TF_DeleteSessionWithGraph(s.c, status.c)
+ s.c = nil
+ return status.Err()
+}
+
+// SessionOptions contains configuration information for a session.
+type SessionOptions struct {
+ // Target indicates the TensorFlow runtime to connect to.
+ //
+ // If 'target' is empty or unspecified, the local TensorFlow runtime
+ // implementation will be used. Otherwise, the TensorFlow engine
+ // defined by 'target' will be used to perform all computations.
+ //
+ // "target" can be either a single entry or a comma separated list
+ // of entries. Each entry is a resolvable address of one of the
+ // following formats:
+ // local
+ // ip:port
+ // host:port
+ // ... other system-specific formats to identify tasks and jobs ...
+ //
+ // NOTE: at the moment 'local' maps to an in-process service-based
+ // runtime.
+ //
+ // Upon creation, a single session affines itself to one of the
+ // remote processes, with possible load balancing choices when the
+ // "target" resolves to a list of possible processes.
+ //
+ // If the session disconnects from the remote process during its
+ // lifetime, session calls may fail immediately.
+ Target string
+}
+
+// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must
+// deallocate by calling C.TF_DeleteSessionOptions().
+func (o *SessionOptions) c() *C.TF_SessionOptions {
+ opt := C.TF_NewSessionOptions()
+ t := C.CString(o.Target)
+ C.TF_SetTarget(opt, t)
+ C.free(unsafe.Pointer(t))
+ return opt
+}
diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go
new file mode 100644
index 0000000000..78f6bccfd6
--- /dev/null
+++ b/tensorflow/go/session_test.go
@@ -0,0 +1,114 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "reflect"
+ "testing"
+)
+
+func Placeholder(g *Graph, name string, dt DataType) (Port, error) {
+ b := newOpBuilder(g, "Placeholder", name)
+ b.SetAttrType("dtype", dt)
+ op, err := b.Build()
+ if err != nil {
+ return Port{}, err
+ }
+ return Port{op, 0}, nil
+}
+
+func Neg(g *Graph, name string, port Port) (Port, error) {
+ b := newOpBuilder(g, "Neg", name)
+ b.AddInput(port)
+ op, err := b.Build()
+ if err != nil {
+ return Port{}, err
+ }
+ return Port{op, 0}, nil
+}
+
+func createTestGraph(t *testing.T, dt DataType) (*Graph, Port, Port) {
+ g := NewGraph()
+ inp, err := Placeholder(g, "p1", dt)
+ if err != nil {
+ t.Fatalf("Placeholder() for %v: %v", dt, err)
+ }
+ out, err := Neg(g, "neg1", inp)
+ if err != nil {
+ t.Fatalf("Neg() for %v: %v", dt, err)
+ }
+ return g, inp, out
+}
+
+func TestSessionRunNeg(t *testing.T) {
+ var tests = []struct {
+ input interface{}
+ expected interface{}
+ }{
+ {int64(1), int64(-1)},
+ {[]float64{-1, -2, 3}, []float64{1, 2, -3}},
+ {[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}},
+ }
+
+ for _, test := range tests {
+ t1, err := NewTensor(test.input)
+ if err != nil {
+ t.Fatalf("NewTensor(%v): %v", test.input, err)
+ }
+ graph, inp, out := createTestGraph(t, t1.DataType())
+ s, err := NewSession(graph, &SessionOptions{})
+ if err != nil {
+ t.Fatalf("NewSession() for %v: %v", test.input, err)
+ }
+ output, err := s.Run(map[Port]*Tensor{inp: t1}, []Port{out}, []*Operation{out.Op})
+ if err != nil {
+ t.Fatalf("Run() for %v: %v", test.input, err)
+ }
+ if len(output) != 1 {
+ t.Errorf("%v: got %d outputs, want 1", test.input, len(output))
+ continue
+ }
+ val := output[0].Value()
+ if !reflect.DeepEqual(test.expected, val) {
+ t.Errorf("got %v, want %v", val, test.expected)
+ }
+ if err := s.Close(); err != nil {
+ t.Errorf("Close(): %v", err)
+ }
+ }
+}
+
+func TestConcurrency(t *testing.T) {
+ tensor, err := NewTensor(int64(1))
+ if err != nil {
+ t.Fatalf("NewTensor(): %v", err)
+ }
+
+ graph, inp, out := createTestGraph(t, tensor.DataType())
+ s, err := NewSession(graph, &SessionOptions{})
+ if err != nil {
+ t.Fatalf("NewSession(): %v", err)
+ }
+ for i := 0; i < 100; i++ {
+ // Session may close before Run() starts, so we don't check the error.
+ go s.Run(map[Port]*Tensor{inp: tensor}, []Port{out}, []*Operation{out.Op})
+ }
+ if err = s.Close(); err != nil {
+ t.Errorf("Close() 1: %v", err)
+ }
+ if err = s.Close(); err != nil {
+ t.Errorf("Close() 2: %v", err)
+ }
+}
diff --git a/tensorflow/go/status.go b/tensorflow/go/status.go
new file mode 100644
index 0000000000..a1f7ed5481
--- /dev/null
+++ b/tensorflow/go/status.go
@@ -0,0 +1,65 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+import "runtime"
+
+type code C.TF_Code
+
+// status holds error information returned by TensorFlow. We convert all
+// TF statuses to Go errors.
+type status struct {
+ c *C.TF_Status
+}
+
+func newStatus() *status {
+ s := &status{C.TF_NewStatus()}
+ runtime.SetFinalizer(s, (*status).finalizer)
+ return s
+}
+
+func (s *status) finalizer() {
+ C.TF_DeleteStatus(s.c)
+}
+
+func (s *status) Code() code {
+ return code(C.TF_GetCode(s.c))
+}
+
+func (s *status) String() string {
+ return C.GoString(C.TF_Message(s.c))
+}
+
+// Err converts the status to a Go error and returns nil if the status is OK.
+func (s *status) Err() error {
+ if s == nil || s.Code() == C.TF_OK {
+ return nil
+ }
+ return (*statusError)(s)
+}
+
+// statusError is distinct from status because it fulfills the error interface.
+// status itself may have a TF_OK code and is not always considered an error.
+//
+// TODO(jhseu): Make public, rename to Error, and provide a way for users to
+// check status codes.
+type statusError status
+
+func (s *statusError) Error() string {
+ return (*status)(s).String()
+}
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
new file mode 100644
index 0000000000..e364e80f86
--- /dev/null
+++ b/tensorflow/go/tensor.go
@@ -0,0 +1,259 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+// #include <string.h>
+// #include "tensorflow/c/c_api.h"
+import "C"
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "reflect"
+ "unsafe"
+)
+
+// DataType holds the type for a scalar value. E.g., one slot in a tensor.
+// The values here are identical to corresponding values in types.proto.
+type DataType C.TF_DataType
+
+// Tensor holds a multi-dimensional array of elements of a single data type.
+type Tensor struct {
+ // We create TF_Tensor on demand rather than keep a handle to C.TF_Tensor
+ // because many functions, such as Session.Run() and Operations take
+ // ownership of the C.TF_Tensor. Translating on-demand provides for a safe
+ // API.
+ //
+ // A memcpy is required because cgo rules prohibit us from maintaining
+ // a pointer to Go memory.
+ // call: https://golang.org/cmd/cgo/
+ buf *bytes.Buffer
+ dt DataType
+ shape []int64
+}
+
+// NewTensor converts from a Go value to a Tensor. Valid values are scalars,
+// slices, and arrays. Every element of a slice must have the same length so
+// that the resulting Tensor has a valid shape.
+func NewTensor(value interface{}) (*Tensor, error) {
+ val := reflect.ValueOf(value)
+ rank, dataType, err := rankAndDataTypeOf(val.Type())
+ if err != nil {
+ return nil, err
+ }
+ t := &Tensor{buf: bytes.NewBuffer(nil), dt: dataType, shape: make([]int64, rank)}
+ if err = encodeTensor(t.buf, t.shape, val); err != nil {
+ return nil, err
+ }
+ return t, nil
+}
+
+// newTensorFromC converts from a C.TF_Tensor to a Tensor.
+func newTensorFromC(ct *C.TF_Tensor) *Tensor {
+ t := &Tensor{dt: DataType(C.TF_TensorType(ct))}
+ numDims := int(C.TF_NumDims(ct))
+ for i := 0; i < numDims; i++ {
+ t.shape = append(t.shape, int64(C.TF_Dim(ct, C.int(i))))
+ }
+ b := make([]byte, int(C.TF_TensorByteSize(ct)))
+ if len(b) > 0 {
+ C.memcpy(unsafe.Pointer(&b[0]), C.TF_TensorData(ct), C.size_t(len(b)))
+ }
+ t.buf = bytes.NewBuffer(b)
+ return t
+}
+
+// DataType returns the scalar datatype of the Tensor.
+func (t *Tensor) DataType() DataType {
+ return t.dt
+}
+
+// Shape returns the shape of the Tensor.
+func (t *Tensor) Shape() []int64 {
+ return t.shape
+}
+
+// Value converts the Tensor to a Go value. For now, not all Tensor types are
+// supported, and this function may panic if it encounters an unsupported
+// DataType.
+//
+// The type of the output depends on the Tensor type and rank. For example:
+// Tensor(int64, 0): int64
+// Tensor(float64, 3): [][][]float64
+func (t *Tensor) Value() interface{} {
+ typ, err := typeOf(t.DataType(), t.Shape())
+ if err != nil {
+ panic(err)
+ }
+ val := reflect.New(typ)
+ if err := decodeTensor(t.buf, t.Shape(), typ, val); err != nil {
+ panic(err)
+ }
+ return reflect.Indirect(val).Interface()
+}
+
+// c converts the Tensor to a *C.TF_Tensor. Callers must take ownership of
+// the *C.TF_Tensor, either by passing ownership to the C API or explicitly
+// calling C.TF_DeleteTensor() on it.
+func (t *Tensor) c() *C.TF_Tensor {
+ var shapePtr *C.int64_t
+ if len(t.shape) > 0 {
+ shapePtr = (*C.int64_t)(unsafe.Pointer(&t.shape[0]))
+ }
+ tensor := C.TF_AllocateTensor(C.TF_DataType(t.dt), shapePtr, C.int(len(t.shape)), C.size_t(t.buf.Len()))
+ if t.buf.Len() > 0 {
+ slice := t.buf.Bytes() // https://github.com/golang/go/issues/14210
+ C.memcpy(C.TF_TensorData(tensor), unsafe.Pointer(&slice[0]), C.size_t(t.buf.Len()))
+ }
+ return tensor
+}
+
+// deleteCTensor only exists to delete C.TF_Tensors in tests. go test doesn't
+// support cgo.
+func deleteCTensor(ct *C.TF_Tensor) {
+ C.TF_DeleteTensor(ct)
+}
+
+var types = []struct {
+ typ reflect.Type
+ dataType C.TF_DataType
+}{
+ {reflect.TypeOf(float32(0)), C.TF_FLOAT},
+ {reflect.TypeOf(float64(0)), C.TF_DOUBLE},
+ {reflect.TypeOf(int32(0)), C.TF_INT32},
+ {reflect.TypeOf(uint8(0)), C.TF_UINT8},
+ {reflect.TypeOf(int16(0)), C.TF_INT16},
+ {reflect.TypeOf(int8(0)), C.TF_INT8},
+ {reflect.TypeOf(""), C.TF_STRING},
+ {reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64},
+ {reflect.TypeOf(int64(0)), C.TF_INT64},
+ {reflect.TypeOf(false), C.TF_BOOL},
+ {reflect.TypeOf(uint16(0)), C.TF_UINT16},
+ {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128},
+}
+
+// rankAndDataTypeOf returns the data type and rank of a Go type for use when
+// encoding. We fetch them separately from encoding to support 0-sized
+// dimensions.
+func rankAndDataTypeOf(typ reflect.Type) (int, DataType, error) {
+ rank := 0
+ elem := typ
+ for ; elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice; elem = elem.Elem() {
+ rank++
+ }
+ for _, t := range types {
+ if elem.Kind() == t.typ.Kind() {
+ return rank, DataType(t.dataType), nil
+ }
+ }
+ return 0, DataType(0), fmt.Errorf("unsupported type %v", typ)
+}
+
+// typeOf converts from a DataType and Shape to the equivalent Go type.
+func typeOf(dt DataType, shape []int64) (reflect.Type, error) {
+ var ret reflect.Type
+ for _, t := range types {
+ if dt == DataType(t.dataType) {
+ ret = t.typ
+ break
+ }
+ }
+ if ret == nil {
+ return nil, fmt.Errorf("DataType %v unsupported", dt)
+ }
+ for _ = range shape {
+ ret = reflect.SliceOf(ret)
+ }
+ return ret, nil
+}
+
+// encodeTensor writes v to the specified buffer using the format specified in
+// c_api.h
+func encodeTensor(buf *bytes.Buffer, shape []int64, v reflect.Value) error {
+ switch v.Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ if err := binary.Write(buf, nativeEndian, v.Interface()); err != nil {
+ return err
+ }
+
+ case reflect.Array, reflect.Slice:
+ // If slice elements are slices, verify that all of them have the same size.
+ // Go's type system makes that guarantee for arrays.
+ if v.Len() > 0 && v.Type().Elem().Kind() == reflect.Slice {
+ expected := v.Index(0).Len()
+ for i := 1; i < v.Len(); i++ {
+ if v.Index(i).Len() != expected {
+ return fmt.Errorf("mismatched slice lengths: %d and %d", v.Index(i).Len(), expected)
+ }
+ }
+ }
+
+ shape[0] = int64(v.Len())
+ for i := 0; i < v.Len(); i++ {
+ err := encodeTensor(buf, shape[1:], v.Index(i))
+ if err != nil {
+ return err
+ }
+ }
+
+ default:
+ return fmt.Errorf("unsupported type %v", v.Type())
+ }
+ return nil
+}
+
+// decodeTensor decodes the Tensor from the buffer to ptr using the format
+// specified in c_api.h
+func decodeTensor(buf *bytes.Buffer, shape []int64, typ reflect.Type, ptr reflect.Value) error {
+ switch typ.Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ if err := binary.Read(buf, nativeEndian, ptr.Interface()); err != nil {
+ return err
+ }
+
+ case reflect.Slice:
+ val := reflect.Indirect(ptr)
+ val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
+ for i := 0; i < val.Len(); i++ {
+ if err := decodeTensor(buf, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
+ return err
+ }
+ }
+
+ default:
+ return fmt.Errorf("unsupported type %v", typ)
+ }
+ return nil
+}
+
+// nativeEndian is the byte order for the local platform. Used to send back and
+// forth Tensors with the C API. We test for endianness at runtime because
+// some architectures can be booted into different endian modes.
+var nativeEndian binary.ByteOrder
+
+func init() {
+ buf := [2]byte{}
+ *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
+
+ switch buf {
+ case [2]byte{0xCD, 0xAB}:
+ nativeEndian = binary.LittleEndian
+ case [2]byte{0xAB, 0xCD}:
+ nativeEndian = binary.BigEndian
+ default:
+ panic("Could not determine native endianness.")
+ }
+}
diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go
new file mode 100644
index 0000000000..630d613729
--- /dev/null
+++ b/tensorflow/go/tensor_test.go
@@ -0,0 +1,97 @@
+// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tensorflow
+
+import (
+ "reflect"
+ "testing"
+)
+
+func TestNewTensor(t *testing.T) {
+ var tests = []struct {
+ shape []int64
+ value interface{}
+ }{
+ {[]int64{}, int8(5)},
+ {[]int64{}, int16(5)},
+ {[]int64{}, int32(5)},
+ {[]int64{}, int64(5)},
+ {[]int64{}, int64(5)},
+ {[]int64{}, uint8(5)},
+ {[]int64{}, uint16(5)},
+ {[]int64{}, float32(5)},
+ {[]int64{}, float64(5)},
+ {[]int64{}, complex(float32(5), float32(6))},
+ {[]int64{}, complex(float64(5), float64(6))},
+ {[]int64{1}, []float64{1}},
+ {[]int64{1}, [1]float64{1}},
+ {[]int64{3, 2}, [][]float64{{1, 2}, {3, 4}, {5, 6}}},
+ {[]int64{2, 3}, [2][3]float64{{1, 2, 3}, {3, 4, 6}}},
+ {[]int64{4, 3, 2}, [][][]float64{
+ {{1, 2}, {3, 4}, {5, 6}},
+ {{7, 8}, {9, 10}, {11, 12}},
+ {{0, -1}, {-2, -3}, {-4, -5}},
+ {{-6, -7}, {-8, -9}, {-10, -11}},
+ }},
+ {[]int64{2, 0}, [][]int64{{}, {}}},
+ }
+
+ var errorTests = []interface{}{
+ struct{ a int }{5},
+ new(int32),
+ new([]int32),
+ // native ints not supported
+ int(5),
+ []int{5},
+ // uint32 and uint64 are not supported in TensorFlow
+ uint32(5),
+ []uint32{5},
+ uint64(5),
+ []uint64{5},
+ }
+
+ for _, test := range tests {
+ tensor, err := NewTensor(test.value)
+ if err != nil {
+ t.Errorf("NewTensor(%v): %v", test.value, err)
+ continue
+ }
+ if !reflect.DeepEqual(test.shape, tensor.Shape()) {
+ t.Errorf("Tensor.Shape(): got %v, want %v", tensor.Shape(), test.shape)
+ }
+
+ // Test that encode and decode gives the same value. We skip arrays because
+ // they're returned as slices.
+ if reflect.TypeOf(test.value).Kind() != reflect.Array {
+ cTensor := tensor.c()
+ gotTensor := newTensorFromC(cTensor)
+ deleteCTensor(cTensor)
+ got := gotTensor.Value()
+ if !reflect.DeepEqual(test.value, got) {
+ t.Errorf("encode/decode: got %v, want %v", got, test.value)
+ }
+ }
+ }
+
+ for _, test := range errorTests {
+ tensor, err := NewTensor(test)
+ if err == nil {
+ t.Errorf("NewTensor(%v): %v", test, err)
+ }
+ if tensor != nil {
+ t.Errorf("NewTensor(%v) = %v, want nil", test, tensor)
+ }
+ }
+}