aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/literal_util.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/literal_util.h')
-rw-r--r--tensorflow/compiler/xla/literal_util.h19
1 files changed, 9 insertions, 10 deletions
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index 125c268573..e02a96ae70 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -470,10 +470,11 @@ class Literal {
// Populates literal values by calling the generator function for every cell
// in this literal object.
- template <typename NativeT>
- Status Populate(
- const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
- generator);
+ //
+ // generator must be a callable of the type
+ // NativeT(tensorflow::gtl::ArraySlice<int64> indexes) or compatible.
+ template <typename NativeT, typename FnType>
+ Status Populate(const FnType& generator);
// Creates a Literal of the given dimensions with all elements set to the
// given value.
@@ -1107,12 +1108,10 @@ void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
}
-template <typename NativeT>
-Status Literal::Populate(
- const std::function<NativeT(tensorflow::gtl::ArraySlice<int64> indexes)>&
- generator) {
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
const Shape& this_shape = shape();
- int64 rank = ShapeUtil::Rank(this_shape);
+ const int64 rank = ShapeUtil::Rank(this_shape);
TF_RET_CHECK(this_shape.element_type() ==
primitive_util::NativeToPrimitiveType<NativeT>());
tensorflow::gtl::MutableArraySlice<NativeT> data =
@@ -1125,7 +1124,7 @@ Status Literal::Populate(
ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
auto init_function = [&](const std::vector<int64>& indexes) {
- int64 index = LinearIndex(indexes);
+ const int64 index = LinearIndex(indexes);
std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
for (int64 i = 0; i < minor_dimension_size; ++i) {
minor_scan_indexes[stride_config.minor_dimension] = i;