aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
authorGravatar Asim Shankar <ashankar@google.com>2017-01-20 10:34:19 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-20 10:48:54 -0800
commit8499dc1ee56b40a6e1f16ccbf3073a12523a0f58 (patch)
tree0e1579a56668c5bc8aeec42541a409e5fb5adee4 /tensorflow/go/tensor.go
parent1b71d25167236617e2085c1a90f5be3d942cb22b (diff)
Go: Support DT_BOOL tensors.
Fixes #6969 Change: 145101014
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r--tensorflow/go/tensor.go18
1 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index c50e5e30aa..f96e796e5e 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -299,8 +299,16 @@ func byteSizeOfEncodedStrings(val interface{}) uintptr {
// encodeTensor writes v to the specified buffer using the format specified in
// c_api.h. Use stringEncoder for String tensors.
-func encodeTensor(w io.Writer, v reflect.Value) error {
+func encodeTensor(w *bytes.Buffer, v reflect.Value) error {
switch v.Kind() {
+ case reflect.Bool:
+ b := byte(0)
+ if v.Bool() {
+ b = 1
+ }
+ if err := w.WriteByte(b); err != nil {
+ return err
+ }
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
@@ -333,8 +341,14 @@ func encodeTensor(w io.Writer, v reflect.Value) error {
// decodeTensor decodes the Tensor from the buffer to ptr using the format
// specified in c_api.h. Use stringDecoder for String tensors.
-func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
+func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
+ case reflect.Bool:
+ b, err := r.ReadByte()
+ if err != nil {
+ return err
+ }
+ ptr.Elem().SetBool(b == 1)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err