diff options
author | 2017-01-20 10:34:19 -0800 | |
---|---|---|
committer | 2017-01-20 10:48:54 -0800 | |
commit | 8499dc1ee56b40a6e1f16ccbf3073a12523a0f58 (patch) | |
tree | 0e1579a56668c5bc8aeec42541a409e5fb5adee4 /tensorflow/go/tensor.go | |
parent | 1b71d25167236617e2085c1a90f5be3d942cb22b (diff) |
Go: Support DT_BOOL tensors.
Fixes #6969
Change: 145101014
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r-- | tensorflow/go/tensor.go | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index c50e5e30aa..f96e796e5e 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -299,8 +299,16 @@ func byteSizeOfEncodedStrings(val interface{}) uintptr { // encodeTensor writes v to the specified buffer using the format specified in // c_api.h. Use stringEncoder for String tensors. -func encodeTensor(w io.Writer, v reflect.Value) error { +func encodeTensor(w *bytes.Buffer, v reflect.Value) error { switch v.Kind() { + case reflect.Bool: + b := byte(0) + if v.Bool() { + b = 1 + } + if err := w.WriteByte(b); err != nil { + return err + } case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { return err @@ -333,8 +341,14 @@ func encodeTensor(w io.Writer, v reflect.Value) error { // decodeTensor decodes the Tensor from the buffer to ptr using the format // specified in c_api.h. Use stringDecoder for String tensors. -func decodeTensor(r io.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { +func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { switch typ.Kind() { + case reflect.Bool: + b, err := r.ReadByte() + if err != nil { + return err + } + ptr.Elem().SetBool(b == 1) case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { return err |