diff options
author | 2017-12-06 18:43:24 -0800 | |
---|---|---|
committer | 2017-12-06 18:47:41 -0800 | |
commit | fe8406149feec453250905965a14285465cd2063 (patch) | |
tree | be3cd75d543f3c0f29f368da61d915abbae7fcbf /tensorflow/go/tensor.go | |
parent | 8ad62af489df718992561710123bc8c037e7d17b (diff) |
Merge changes from github.
PiperOrigin-RevId: 178185697
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r-- | tensorflow/go/tensor.go | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index cd05e2aa0a..2d25c04dc9 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -328,6 +328,14 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { } } + // Optimisation: if only one dimension is left we can use binary.Write() directly for this slice + if len(shape) == 1 && v.Len() > 0 { + switch v.Index(0).Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return binary.Write(w, nativeEndian, v.Interface()) + } + } + subShape := shape[1:] for i := 0; i < v.Len(); i++ { err := encodeTensor(w, v.Index(i), subShape) @@ -360,6 +368,15 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. case reflect.Slice: val := reflect.Indirect(ptr) val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0]))) + + // Optimization: if only one dimension is left we can use binary.Read() directly for this slice + if len(shape) == 1 && val.Len() > 0 { + switch val.Index(0).Kind() { + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + return binary.Read(r, nativeEndian, val.Interface()) + } + } + for i := 0; i < val.Len(); i++ { if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil { return err |