diff options
author | Kay Zhu <kayzhu@google.com> | 2017-12-15 18:11:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-15 18:15:09 -0800 |
commit | 713d45278491d792c525344de6038a61ebcb2136 (patch) | |
tree | 4c4d4063b82d415a448e0ff4646fbf6f88636876 /tensorflow/compiler/xla/service/interpreter | |
parent | 9648f8040a559f6cf9bbe0501ba96f2b2c2864b1 (diff) |
[XLA] Support Map in HloEvaluator, enable Interpreter to run
xla/tests:map_map_test which tests this change.
Additionally:
- templatize Evaluate* methods to specialize on both std::unique_ptr<Literal> and const Literal* type of input literal arguments.
- add ResetVisitState to DfsHloVisitor such that a visitor instance can traverse the same HLO graph more than once.
PiperOrigin-RevId: 179263540
Diffstat (limited to 'tensorflow/compiler/xla/service/interpreter')
-rw-r--r-- | tensorflow/compiler/xla/service/interpreter/executable.cc | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 9183a1d1bf..293cc2007e 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -98,12 +98,10 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream( // Create the arguments as an vector of XLA literals std::vector<std::unique_ptr<Literal>> arg_literals; - std::vector<Literal*> arg_literals_ptrs; for (int64 p = 0; p < computation->num_parameters(); ++p) { // Create the input literal for the parameter HloInstruction* param = computation->parameter_instruction(p); arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); - arg_literals_ptrs.push_back(arg_literals.back().get()); // Copy in the data from the stream_executor buffers void* buffer = arg_literals.back()->MutableInternalData(); @@ -113,8 +111,9 @@ StatusOr<se::DeviceMemoryBase> InterpreterExecutable::ExecuteOnStream( // Execute the graph using the HloEvaluator. HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> output, - evaluator.Evaluate(*computation, arg_literals_ptrs)); + TF_ASSIGN_OR_RETURN( + std::unique_ptr<Literal> output, + evaluator.Evaluate<std::unique_ptr<Literal>>(*computation, arg_literals)); // Copy the result into the return buffer perftools::gputools::StreamExecutor* executor(stream->parent()); |