diff options
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.h')
-rw-r--r-- | tensorflow/compiler/xla/literal_util.h | 19 |
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 125c268573..e02a96ae70 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -470,10 +470,11 @@ class Literal { // Populates literal values by calling the generator function for every cell // in this literal object. - template <typename NativeT> - Status Populate( - const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>& - generator); + // + // generator must be a callable of the type + // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible. + template <typename NativeT, typename FnType> + Status Populate(const FnType& generator); // Creates a Literal of the given dimensions with all elements set to the // given value. @@ -1107,12 +1108,10 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) { PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); } -template <typename NativeT> -Status Literal::Populate( - const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>& - generator) { +template <typename NativeT, typename FnType> +Status Literal::Populate(const FnType& generator) { const Shape& this_shape = shape(); - int64 rank = ShapeUtil::Rank(this_shape); + const int64 rank = ShapeUtil::Rank(this_shape); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType<NativeT>()); tensorflow::gtl::MutableArraySlice<NativeT> data = @@ -1125,7 +1124,7 @@ Status Literal::Populate( ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); auto init_function = [&](const std::vector<int64>& indexes) { - int64 index = LinearIndex(indexes); + const int64 index = LinearIndex(indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); for (int64 i = 0; i < minor_dimension_size; ++i) { minor_scan_indexes[stride_config.minor_dimension] = i; |