aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/go/tensor.go
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2017-12-06 18:43:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-06 18:47:41 -0800
commitfe8406149feec453250905965a14285465cd2063 (patch)
treebe3cd75d543f3c0f29f368da61d915abbae7fcbf /tensorflow/go/tensor.go
parent8ad62af489df718992561710123bc8c037e7d17b (diff)
Merge changes from github.
PiperOrigin-RevId: 178185697
Diffstat (limited to 'tensorflow/go/tensor.go')
-rw-r--r--tensorflow/go/tensor.go17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index cd05e2aa0a..2d25c04dc9 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -328,6 +328,14 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
}
}
+ // Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
+ if len(shape) == 1 && v.Len() > 0 {
+ switch v.Index(0).Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ return binary.Write(w, nativeEndian, v.Interface())
+ }
+ }
+
subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
err := encodeTensor(w, v.Index(i), subShape)
@@ -360,6 +368,15 @@ func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
case reflect.Slice:
val := reflect.Indirect(ptr)
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
+
+ // Optimization: if only one dimension is left we can use binary.Read() directly for this slice
+ if len(shape) == 1 && val.Len() > 0 {
+ switch val.Index(0).Kind() {
+ case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
+ return binary.Read(r, nativeEndian, val.Interface())
+ }
+ }
+
for i := 0; i < val.Len(); i++ {
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
return err