diff options
author | Jonathan Hseu <jhseu@google.com> | 2016-08-23 09:01:25 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-23 10:04:53 -0700 |
commit | 783c52edeb3c676937dbb97ed0d40958015050d6 (patch) | |
tree | 80c74954f68dad26a6e76a1c0edcb979d4d1804c | |
parent | 096069687c52e16eaa18c1db6e7bbf2737639257 (diff) |
Initial version of the Go API. The API is subject to change.
Remaining work to do:
- Generated ops.
- Generated protocol buffers.
- A few calls requiring protocol buffers aren't in this change.
Change: 131066649
-rw-r--r-- | tensorflow/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/core/framework/types.proto | 2 | ||||
-rw-r--r-- | tensorflow/go/BUILD | 22 | ||||
-rw-r--r-- | tensorflow/go/doc.go | 18 | ||||
-rw-r--r-- | tensorflow/go/graph.go | 38 | ||||
-rw-r--r-- | tensorflow/go/lib.go | 19 | ||||
-rw-r--r-- | tensorflow/go/operation.go | 82 | ||||
-rw-r--r-- | tensorflow/go/session.go | 187 | ||||
-rw-r--r-- | tensorflow/go/session_test.go | 114 | ||||
-rw-r--r-- | tensorflow/go/status.go | 65 | ||||
-rw-r--r-- | tensorflow/go/tensor.go | 259 | ||||
-rw-r--r-- | tensorflow/go/tensor_test.go | 97 |
12 files changed, 906 insertions, 0 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 2ab4f5a3a9..5fcdb7aa56 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -130,6 +130,7 @@ filegroup( "//tensorflow/examples/tutorials/word2vec:all_files", "//tensorflow/g3doc/how_tos/adding_an_op:all_files", "//tensorflow/g3doc/tutorials:all_files", + "//tensorflow/go:all_files", "//tensorflow/models/embedding:all_files", "//tensorflow/models/image/alexnet:all_files", "//tensorflow/models/image/cifar10:all_files", @@ -167,6 +168,7 @@ cc_binary( name = "libtensorflow.so", linkshared = 1, deps = [ + "//tensorflow/c:c_api", "//tensorflow/core:tensorflow", ], ) @@ -175,6 +177,7 @@ cc_binary( name = "libtensorflow_cc.so", linkshared = 1, deps = [ + "//tensorflow/c:c_api", "//tensorflow/cc:cc_ops", "//tensorflow/core:tensorflow", ], diff --git a/tensorflow/core/framework/types.proto b/tensorflow/core/framework/types.proto index 051361bbed..c744594a49 100644 --- a/tensorflow/core/framework/types.proto +++ b/tensorflow/core/framework/types.proto @@ -6,6 +6,7 @@ option java_outer_classname = "TypesProtos"; option java_multiple_files = true; option java_package = "org.tensorflow.framework"; +// LINT.IfChange enum DataType { // Not a legal value for DataType. Used to indicate a DataType field // has not been set. @@ -58,3 +59,4 @@ enum DataType { DT_COMPLEX128_REF = 118; DT_HALF_REF = 119; } +// LINT.ThenChange(https://www.tensorflow.org/code/tensorflow/c/c_api.h,https://www.tensorflow.org/code/tensorflow/go/tensor.go) diff --git a/tensorflow/go/BUILD b/tensorflow/go/BUILD new file mode 100644 index 0000000000..d69233f4fe --- /dev/null +++ b/tensorflow/go/BUILD @@ -0,0 +1,22 @@ +# Description: +# Go API for TensorFlow. + +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/go/doc.go b/tensorflow/go/doc.go new file mode 100644 index 0000000000..4494d49556 --- /dev/null +++ b/tensorflow/go/doc.go @@ -0,0 +1,18 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tensorflow is a Go binding to TensorFlow. +// +// The API is subject to change and may break at any time. +package tensorflow diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go new file mode 100644 index 0000000000..3e43a39817 --- /dev/null +++ b/tensorflow/go/graph.go @@ -0,0 +1,38 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +// #include "tensorflow/c/c_api.h" +import "C" + +import ( + "runtime" +) + +// Graph represents a computation graph. Graphs may be shared between sessions. +type Graph struct { + c *C.TF_Graph +} + +// NewGraph returns a new Graph. +func NewGraph() *Graph { + g := &Graph{C.TF_NewGraph()} + runtime.SetFinalizer(g, (*Graph).finalizer) + return g +} + +func (g *Graph) finalizer() { + C.TF_DeleteGraph(g.c) +} diff --git a/tensorflow/go/lib.go b/tensorflow/go/lib.go new file mode 100644 index 0000000000..dcab7a90f8 --- /dev/null +++ b/tensorflow/go/lib.go @@ -0,0 +1,19 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +// #cgo LDFLAGS: -ltensorflow +// #cgo CFLAGS: -I${SRCDIR}/../../ +import "C" diff --git a/tensorflow/go/operation.go b/tensorflow/go/operation.go new file mode 100644 index 0000000000..19eb1a5816 --- /dev/null +++ b/tensorflow/go/operation.go @@ -0,0 +1,82 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +// #include <stdlib.h> +// #include "tensorflow/c/c_api.h" +import "C" +import "unsafe" + +// Operation that has been added to the graph. +type Operation struct { + c *C.TF_Operation +} + +// Port represents a specific input or output of an operation, e.g. to specify +// the specific output to pass as an input to a new op. +// +// Note the difference in naming convention: Port corresponds to Tensor/Output +// in the Python API. +type Port struct { + Op *Operation + Index int +} + +func (p *Port) c() C.TF_Port { + return C.TF_Port{oper: p.Op.c, index: C.int(p.Index)} +} + +// opBuilder is for use by the generated op code to create new Operations. +// Build() must be called for any in-progress Operation, or else we leak. +type opBuilder struct { + c *C.TF_OperationDescription +} + +func newOpBuilder(g *Graph, typ string, name string) *opBuilder { + opType := C.CString(typ) + opName := C.CString(name) + b := &opBuilder{c: C.TF_NewOperation(g.c, opType, opName)} + C.free(unsafe.Pointer(opType)) + C.free(unsafe.Pointer(opName)) + return b +} + +func (b *opBuilder) SetAttrTensor(name string, t *Tensor) error { + status := newStatus() + attrName := C.CString(name) + C.TF_SetAttrTensor(b.c, attrName, t.c(), status.c) + C.free(unsafe.Pointer(attrName)) + return status.Err() +} + +func (b *opBuilder) SetAttrType(name string, typ DataType) { + attrName := C.CString(name) + C.TF_SetAttrType(b.c, attrName, C.TF_DataType(typ)) + C.free(unsafe.Pointer(attrName)) +} + +func (b *opBuilder) AddInput(port Port) { + C.TF_AddInput(b.c, port.c()) +} + +func (b *opBuilder) Build() (*Operation, error) { + status := newStatus() + op := &Operation{c: C.TF_FinishOperation(b.c, status.c)} + if err := status.Err(); err != nil { + return nil, err + } + b.c = nil + return op, nil +} diff --git a/tensorflow/go/session.go b/tensorflow/go/session.go new file mode 100644 index 0000000000..98a87602d1 --- /dev/null +++ b/tensorflow/go/session.go @@ -0,0 +1,187 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +// #include <stdlib.h> +// #include "tensorflow/c/c_api.h" +import "C" + +import ( + "errors" + "runtime" + "sync" + "unsafe" +) + +// Session drives a TensorFlow graph computation. +// +// When a Session is created with a given target, a new Session object is bound +// to the universe of resources specified by that target. Those resources are +// available to this session to perform computation described in the GraphDef. +// After creating the session with a graph, the caller uses the Run() API to +// perform the computation and potentially fetch outputs as Tensors. +// A Session allows concurrent calls to Run(). +type Session struct { + c *C.TF_SessionWithGraph + + // For ensuring that: + // - Close() blocks on all Run() calls to complete. + // - Close() can be called multiple times. + wg sync.WaitGroup + mu sync.Mutex +} + +// NewSession creates a new execution session with the associated graph. +// options may be nil to use the default options. +func NewSession(graph *Graph, options *SessionOptions) (*Session, error) { + status := newStatus() + cOpt := options.c() + cSess := C.TF_NewSessionWithGraph(graph.c, cOpt, status.c) + C.TF_DeleteSessionOptions(cOpt) + if err := status.Err(); err != nil { + return nil, err + } + + s := &Session{c: cSess} + runtime.SetFinalizer(s, func(s *Session) { s.Close() }) + return s, nil +} + +// Run the graph with the associated session starting with the supplied inputs. +// inputs and outputs may be set to nil. Runs, but does not return Tensors +// for operations specified in targets. +// +// On success, returns the Tensor outputs in the same order as supplied in +// the outputs argument. If outputs is set to nil, the returned Tensor outputs +// is empty. +func (s *Session) Run(inputs map[Port]*Tensor, outputs []Port, targets []*Operation) ([]*Tensor, error) { + s.mu.Lock() + if s.c == nil { + s.mu.Unlock() + return nil, errors.New("session is closed") + } + s.wg.Add(1) + s.mu.Unlock() + defer s.wg.Done() + + var inputPorts []C.TF_Port + var inputValues []*C.TF_Tensor + if inputs != nil { + for port, tensor := range inputs { + inputPorts = append(inputPorts, port.c()) + inputValues = append(inputValues, tensor.c()) + } + } + + var outputPorts []C.TF_Port + for _, port := range outputs { + outputPorts = append(outputPorts, port.c()) + } + outputValues := make([]*C.TF_Tensor, len(outputs)) + var cTargets []*C.TF_Operation + for _, target := range targets { + cTargets = append(cTargets, target.c) + } + + status := newStatus() + var inputPortsPtr *C.TF_Port + var inputValuesPtr **C.TF_Tensor + if len(inputPorts) > 0 { + inputPortsPtr = &inputPorts[0] + inputValuesPtr = &inputValues[0] + } + + var outputPortsPtr *C.TF_Port + var outputValuesPtr **C.TF_Tensor + if len(outputPorts) > 0 { + outputPortsPtr = &outputPorts[0] + outputValuesPtr = &outputValues[0] + } + + var cTargetsPtr **C.TF_Operation + if len(cTargets) > 0 { + cTargetsPtr = &cTargets[0] + } + + C.TF_SessionRun(s.c, nil, inputPortsPtr, inputValuesPtr, C.int(len(inputPorts)), outputPortsPtr, outputValuesPtr, C.int(len(outputPorts)), cTargetsPtr, C.int(len(cTargets)), nil, status.c) + if err := status.Err(); err != nil { + return nil, err + } + + var tensors []*Tensor + for _, val := range outputValues { + tensors = append(tensors, newTensorFromC(val)) + C.TF_DeleteTensor(val) + } + + return tensors, nil +} + +// Close a session. This contacts any other processes associated with this +// session, if applicable. Blocks until all previous calls to Run have returned. +func (s *Session) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + s.wg.Wait() + if s.c == nil { + return nil + } + status := newStatus() + C.TF_CloseSessionWithGraph(s.c, status.c) + if err := status.Err(); err != nil { + return err + } + C.TF_DeleteSessionWithGraph(s.c, status.c) + s.c = nil + return status.Err() +} + +// SessionOptions contains configuration information for a session. +type SessionOptions struct { + // Target indicates the TensorFlow runtime to connect to. + // + // If 'target' is empty or unspecified, the local TensorFlow runtime + // implementation will be used. Otherwise, the TensorFlow engine + // defined by 'target' will be used to perform all computations. + // + // "target" can be either a single entry or a comma separated list + // of entries. Each entry is a resolvable address of one of the + // following formats: + // local + // ip:port + // host:port + // ... other system-specific formats to identify tasks and jobs ... + // + // NOTE: at the moment 'local' maps to an in-process service-based + // runtime. + // + // Upon creation, a single session affines itself to one of the + // remote processes, with possible load balancing choices when the + // "target" resolves to a list of possible processes. + // + // If the session disconnects from the remote process during its + // lifetime, session calls may fail immediately. + Target string +} + +// c converts the SessionOptions to the C API's TF_SessionOptions. Callers must +// deallocate by calling C.TF_DeleteSessionOptions(). +func (o *SessionOptions) c() *C.TF_SessionOptions { + opt := C.TF_NewSessionOptions() + t := C.CString(o.Target) + C.TF_SetTarget(opt, t) + C.free(unsafe.Pointer(t)) + return opt +} diff --git a/tensorflow/go/session_test.go b/tensorflow/go/session_test.go new file mode 100644 index 0000000000..78f6bccfd6 --- /dev/null +++ b/tensorflow/go/session_test.go @@ -0,0 +1,114 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "reflect" + "testing" +) + +func Placeholder(g *Graph, name string, dt DataType) (Port, error) { + b := newOpBuilder(g, "Placeholder", name) + b.SetAttrType("dtype", dt) + op, err := b.Build() + if err != nil { + return Port{}, err + } + return Port{op, 0}, nil +} + +func Neg(g *Graph, name string, port Port) (Port, error) { + b := newOpBuilder(g, "Neg", name) + b.AddInput(port) + op, err := b.Build() + if err != nil { + return Port{}, err + } + return Port{op, 0}, nil +} + +func createTestGraph(t *testing.T, dt DataType) (*Graph, Port, Port) { + g := NewGraph() + inp, err := Placeholder(g, "p1", dt) + if err != nil { + t.Fatalf("Placeholder() for %v: %v", dt, err) + } + out, err := Neg(g, "neg1", inp) + if err != nil { + t.Fatalf("Neg() for %v: %v", dt, err) + } + return g, inp, out +} + +func TestSessionRunNeg(t *testing.T) { + var tests = []struct { + input interface{} + expected interface{} + }{ + {int64(1), int64(-1)}, + {[]float64{-1, -2, 3}, []float64{1, 2, -3}}, + {[][]float32{{1, -2}, {-3, 4}}, [][]float32{{-1, 2}, {3, -4}}}, + } + + for _, test := range tests { + t1, err := NewTensor(test.input) + if err != nil { + t.Fatalf("NewTensor(%v): %v", test.input, err) + } + graph, inp, out := createTestGraph(t, t1.DataType()) + s, err := NewSession(graph, &SessionOptions{}) + if err != nil { + t.Fatalf("NewSession() for %v: %v", test.input, err) + } + output, err := s.Run(map[Port]*Tensor{inp: t1}, []Port{out}, []*Operation{out.Op}) + if err != nil { + t.Fatalf("Run() for %v: %v", test.input, err) + } + if len(output) != 1 { + t.Errorf("%v: got %d outputs, want 1", test.input, len(output)) + continue + } + val := output[0].Value() + if !reflect.DeepEqual(test.expected, val) { + t.Errorf("got %v, want %v", val, test.expected) + } + if err := s.Close(); err != nil { + t.Errorf("Close(): %v", err) + } + } +} + +func TestConcurrency(t *testing.T) { + tensor, err := NewTensor(int64(1)) + if err != nil { + t.Fatalf("NewTensor(): %v", err) + } + + graph, inp, out := createTestGraph(t, tensor.DataType()) + s, err := NewSession(graph, &SessionOptions{}) + if err != nil { + t.Fatalf("NewSession(): %v", err) + } + for i := 0; i < 100; i++ { + // Session may close before Run() starts, so we don't check the error. + go s.Run(map[Port]*Tensor{inp: tensor}, []Port{out}, []*Operation{out.Op}) + } + if err = s.Close(); err != nil { + t.Errorf("Close() 1: %v", err) + } + if err = s.Close(); err != nil { + t.Errorf("Close() 2: %v", err) + } +} diff --git a/tensorflow/go/status.go b/tensorflow/go/status.go new file mode 100644 index 0000000000..a1f7ed5481 --- /dev/null +++ b/tensorflow/go/status.go @@ -0,0 +1,65 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +// #include "tensorflow/c/c_api.h" +import "C" + +import "runtime" + +type code C.TF_Code + +// status holds error information returned by TensorFlow. We convert all +// TF statuses to Go errors. +type status struct { + c *C.TF_Status +} + +func newStatus() *status { + s := &status{C.TF_NewStatus()} + runtime.SetFinalizer(s, (*status).finalizer) + return s +} + +func (s *status) finalizer() { + C.TF_DeleteStatus(s.c) +} + +func (s *status) Code() code { + return code(C.TF_GetCode(s.c)) +} + +func (s *status) String() string { + return C.GoString(C.TF_Message(s.c)) +} + +// Err converts the status to a Go error and returns nil if the status is OK. +func (s *status) Err() error { + if s == nil || s.Code() == C.TF_OK { + return nil + } + return (*statusError)(s) +} + +// statusError is distinct from status because it fulfills the error interface. +// status itself may have a TF_OK code and is not always considered an error. +// +// TODO(jhseu): Make public, rename to Error, and provide a way for users to +// check status codes. +type statusError status + +func (s *statusError) Error() string { + return (*status)(s).String() +} diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go new file mode 100644 index 0000000000..e364e80f86 --- /dev/null +++ b/tensorflow/go/tensor.go @@ -0,0 +1,259 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +// #include <string.h> +// #include "tensorflow/c/c_api.h" +import "C" + +import ( + "bytes" + "encoding/binary" + "fmt" + "reflect" + "unsafe" +) + +// DataType holds the type for a scalar value. E.g., one slot in a tensor. +// The values here are identical to corresponding values in types.proto. +type DataType C.TF_DataType + +// Tensor holds a multi-dimensional array of elements of a single data type. +type Tensor struct { + // We create TF_Tensor on demand rather than keep a handle to C.TF_Tensor + // because many functions, such as Session.Run() and Operations take + // ownership of the C.TF_Tensor. Translating on-demand provides for a safe + // API. + // + // A memcpy is required because cgo rules prohibit us from maintaining + // a pointer to Go memory. + // call: https://golang.org/cmd/cgo/ + buf *bytes.Buffer + dt DataType + shape []int64 +} + +// NewTensor converts from a Go value to a Tensor. Valid values are scalars, +// slices, and arrays. Every element of a slice must have the same length so +// that the resulting Tensor has a valid shape. +func NewTensor(value interface{}) (*Tensor, error) { + val := reflect.ValueOf(value) + rank, dataType, err := rankAndDataTypeOf(val.Type()) + if err != nil { + return nil, err + } + t := &Tensor{buf: bytes.NewBuffer(nil), dt: dataType, shape: make([]int64, rank)} + if err = encodeTensor(t.buf, t.shape, val); err != nil { + return nil, err + } + return t, nil +} + +// newTensorFromC converts from a C.TF_Tensor to a Tensor. +func newTensorFromC(ct *C.TF_Tensor) *Tensor { + t := &Tensor{dt: DataType(C.TF_TensorType(ct))} + numDims := int(C.TF_NumDims(ct)) + for i := 0; i < numDims; i++ { + t.shape = append(t.shape, int64(C.TF_Dim(ct, C.int(i)))) + } + b := make([]byte, int(C.TF_TensorByteSize(ct))) + if len(b) > 0 { + C.memcpy(unsafe.Pointer(&b[0]), C.TF_TensorData(ct), C.size_t(len(b))) + } + t.buf = bytes.NewBuffer(b) + return t +} + +// DataType returns the scalar datatype of the Tensor. +func (t *Tensor) DataType() DataType { + return t.dt +} + +// Shape returns the shape of the Tensor. +func (t *Tensor) Shape() []int64 { + return t.shape +} + +// Value converts the Tensor to a Go value. For now, not all Tensor types are +// supported, and this function may panic if it encounters an unsupported +// DataType. +// +// The type of the output depends on the Tensor type and rank. For example: +// Tensor(int64, 0): int64 +// Tensor(float64, 3): [][][]float64 +func (t *Tensor) Value() interface{} { + typ, err := typeOf(t.DataType(), t.Shape()) + if err != nil { + panic(err) + } + val := reflect.New(typ) + if err := decodeTensor(t.buf, t.Shape(), typ, val); err != nil { + panic(err) + } + return reflect.Indirect(val).Interface() +} + +// c converts the Tensor to a *C.TF_Tensor. Callers must take ownership of +// the *C.TF_Tensor, either by passing ownership to the C API or explicitly +// calling C.TF_DeleteTensor() on it. +func (t *Tensor) c() *C.TF_Tensor { + var shapePtr *C.int64_t + if len(t.shape) > 0 { + shapePtr = (*C.int64_t)(unsafe.Pointer(&t.shape[0])) + } + tensor := C.TF_AllocateTensor(C.TF_DataType(t.dt), shapePtr, C.int(len(t.shape)), C.size_t(t.buf.Len())) + if t.buf.Len() > 0 { + slice := t.buf.Bytes() // https://github.com/golang/go/issues/14210 + C.memcpy(C.TF_TensorData(tensor), unsafe.Pointer(&slice[0]), C.size_t(t.buf.Len())) + } + return tensor +} + +// deleteCTensor only exists to delete C.TF_Tensors in tests. go test doesn't +// support cgo. +func deleteCTensor(ct *C.TF_Tensor) { + C.TF_DeleteTensor(ct) +} + +var types = []struct { + typ reflect.Type + dataType C.TF_DataType +}{ + {reflect.TypeOf(float32(0)), C.TF_FLOAT}, + {reflect.TypeOf(float64(0)), C.TF_DOUBLE}, + {reflect.TypeOf(int32(0)), C.TF_INT32}, + {reflect.TypeOf(uint8(0)), C.TF_UINT8}, + {reflect.TypeOf(int16(0)), C.TF_INT16}, + {reflect.TypeOf(int8(0)), C.TF_INT8}, + {reflect.TypeOf(""), C.TF_STRING}, + {reflect.TypeOf(complex(float32(0), float32(0))), C.TF_COMPLEX64}, + {reflect.TypeOf(int64(0)), C.TF_INT64}, + {reflect.TypeOf(false), C.TF_BOOL}, + {reflect.TypeOf(uint16(0)), C.TF_UINT16}, + {reflect.TypeOf(complex(float64(0), float64(0))), C.TF_COMPLEX128}, +} + +// rankAndDataTypeOf returns the data type and rank of a Go type for use when +// encoding. We fetch them separately from encoding to support 0-sized +// dimensions. +func rankAndDataTypeOf(typ reflect.Type) (int, DataType, error) { + rank := 0 + elem := typ + for ; elem.Kind() == reflect.Array || elem.Kind() == reflect.Slice; elem = elem.Elem() { + rank++ + } + for _, t := range types { + if elem.Kind() == t.typ.Kind() { + return rank, DataType(t.dataType), nil + } + } + return 0, DataType(0), fmt.Errorf("unsupported type %v", typ) +} + +// typeOf converts from a DataType and Shape to the equivalent Go type. +func typeOf(dt DataType, shape []int64) (reflect.Type, error) { + var ret reflect.Type + for _, t := range types { + if dt == DataType(t.dataType) { + ret = t.typ + break + } + } + if ret == nil { + return nil, fmt.Errorf("DataType %v unsupported", dt) + } + for _ = range shape { + ret = reflect.SliceOf(ret) + } + return ret, nil +} + +// encodeTensor writes v to the specified buffer using the format specified in +// c_api.h +func encodeTensor(buf *bytes.Buffer, shape []int64, v reflect.Value) error { + switch v.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + if err := binary.Write(buf, nativeEndian, v.Interface()); err != nil { + return err + } + + case reflect.Array, reflect.Slice: + // If slice elements are slices, verify that all of them have the same size. + // Go's type system makes that guarantee for arrays. + if v.Len() > 0 && v.Type().Elem().Kind() == reflect.Slice { + expected := v.Index(0).Len() + for i := 1; i < v.Len(); i++ { + if v.Index(i).Len() != expected { + return fmt.Errorf("mismatched slice lengths: %d and %d", v.Index(i).Len(), expected) + } + } + } + + shape[0] = int64(v.Len()) + for i := 0; i < v.Len(); i++ { + err := encodeTensor(buf, shape[1:], v.Index(i)) + if err != nil { + return err + } + } + + default: + return fmt.Errorf("unsupported type %v", v.Type()) + } + return nil +} + +// decodeTensor decodes the Tensor from the buffer to ptr using the format +// specified in c_api.h +func decodeTensor(buf *bytes.Buffer, shape []int64, typ reflect.Type, ptr reflect.Value) error { + switch typ.Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + if err := binary.Read(buf, nativeEndian, ptr.Interface()); err != nil { + return err + } + + case reflect.Slice: + val := reflect.Indirect(ptr) + val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0]))) + for i := 0; i < val.Len(); i++ { + if err := decodeTensor(buf, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil { + return err + } + } + + default: + return fmt.Errorf("unsupported type %v", typ) + } + return nil +} + +// nativeEndian is the byte order for the local platform. Used to send back and +// forth Tensors with the C API. We test for endianness at runtime because +// some architectures can be booted into different endian modes. +var nativeEndian binary.ByteOrder + +func init() { + buf := [2]byte{} + *(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD) + + switch buf { + case [2]byte{0xCD, 0xAB}: + nativeEndian = binary.LittleEndian + case [2]byte{0xAB, 0xCD}: + nativeEndian = binary.BigEndian + default: + panic("Could not determine native endianness.") + } +} diff --git a/tensorflow/go/tensor_test.go b/tensorflow/go/tensor_test.go new file mode 100644 index 0000000000..630d613729 --- /dev/null +++ b/tensorflow/go/tensor_test.go @@ -0,0 +1,97 @@ +// Copyright 2016 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tensorflow + +import ( + "reflect" + "testing" +) + +func TestNewTensor(t *testing.T) { + var tests = []struct { + shape []int64 + value interface{} + }{ + {[]int64{}, int8(5)}, + {[]int64{}, int16(5)}, + {[]int64{}, int32(5)}, + {[]int64{}, int64(5)}, + {[]int64{}, int64(5)}, + {[]int64{}, uint8(5)}, + {[]int64{}, uint16(5)}, + {[]int64{}, float32(5)}, + {[]int64{}, float64(5)}, + {[]int64{}, complex(float32(5), float32(6))}, + {[]int64{}, complex(float64(5), float64(6))}, + {[]int64{1}, []float64{1}}, + {[]int64{1}, [1]float64{1}}, + {[]int64{3, 2}, [][]float64{{1, 2}, {3, 4}, {5, 6}}}, + {[]int64{2, 3}, [2][3]float64{{1, 2, 3}, {3, 4, 6}}}, + {[]int64{4, 3, 2}, [][][]float64{ + {{1, 2}, {3, 4}, {5, 6}}, + {{7, 8}, {9, 10}, {11, 12}}, + {{0, -1}, {-2, -3}, {-4, -5}}, + {{-6, -7}, {-8, -9}, {-10, -11}}, + }}, + {[]int64{2, 0}, [][]int64{{}, {}}}, + } + + var errorTests = []interface{}{ + struct{ a int }{5}, + new(int32), + new([]int32), + // native ints not supported + int(5), + []int{5}, + // uint32 and uint64 are not supported in TensorFlow + uint32(5), + []uint32{5}, + uint64(5), + []uint64{5}, + } + + for _, test := range tests { + tensor, err := NewTensor(test.value) + if err != nil { + t.Errorf("NewTensor(%v): %v", test.value, err) + continue + } + if !reflect.DeepEqual(test.shape, tensor.Shape()) { + t.Errorf("Tensor.Shape(): got %v, want %v", tensor.Shape(), test.shape) + } + + // Test that encode and decode gives the same value. We skip arrays because + // they're returned as slices. + if reflect.TypeOf(test.value).Kind() != reflect.Array { + cTensor := tensor.c() + gotTensor := newTensorFromC(cTensor) + deleteCTensor(cTensor) + got := gotTensor.Value() + if !reflect.DeepEqual(test.value, got) { + t.Errorf("encode/decode: got %v, want %v", got, test.value) + } + } + } + + for _, test := range errorTests { + tensor, err := NewTensor(test) + if err == nil { + t.Errorf("NewTensor(%v): %v", test, err) + } + if tensor != nil { + t.Errorf("NewTensor(%v) = %v, want nil", test, tensor) + } + } +} |