aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/xla_helpers.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_helpers.cc')
-rw-r--r--tensorflow/compiler/tf2xla/xla_helpers.cc142
1 files changed, 142 insertions, 0 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
new file mode 100644
index 0000000000..efb0facf7b
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file defines helper routines for Tla JIT compilation.
+
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+
+xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::MinValue(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::Zero(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
+ DataType data_type) {
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ return b->ConstantLiteral(xla::LiteralUtil::One(type));
+}
+
+xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
+ xla::ComputationBuilder* b, DataType data_type, int64 value) {
+ xla::Literal literal;
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ switch (type) {
+ case xla::U8:
+ literal = *xla::LiteralUtil::CreateR0<uint8>(value);
+ break;
+ case xla::U32:
+ literal = *xla::LiteralUtil::CreateR0<uint32>(value);
+ break;
+ case xla::U64:
+ literal = *xla::LiteralUtil::CreateR0<uint64>(value);
+ break;
+ case xla::S8:
+ literal = *xla::LiteralUtil::CreateR0<int8>(value);
+ break;
+ case xla::S32:
+ literal = *xla::LiteralUtil::CreateR0<int32>(value);
+ break;
+ case xla::S64:
+ literal = *xla::LiteralUtil::CreateR0<int64>(value);
+ break;
+ case xla::F32:
+ literal = *xla::LiteralUtil::CreateR0<float>(value);
+ break;
+ case xla::F64:
+ literal = *xla::LiteralUtil::CreateR0<double>(value);
+ break;
+ case xla::PRED:
+ LOG(FATAL) << "pred element type is not integral";
+ case xla::S16:
+ case xla::U16:
+ LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case xla::F16:
+ LOG(FATAL) << "f16 literals not yet implemented";
+ case xla::TUPLE:
+ LOG(FATAL) << "tuple element type is not integral";
+ case xla::OPAQUE:
+ LOG(FATAL) << "opaque element type is not integral";
+ default:
+ LOG(FATAL) << "unhandled element type " << type;
+ }
+ return b->ConstantLiteral(literal);
+}
+
+xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
+ DataType data_type,
+ double value) {
+ xla::Literal literal;
+ xla::PrimitiveType type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
+ switch (type) {
+ case xla::F32:
+ return b->ConstantR0<float>(static_cast<float>(value));
+ break;
+ case xla::F64:
+ return b->ConstantR0<double>(value);
+ break;
+ default:
+ LOG(FATAL) << "unhandled element type " << type;
+ }
+}
+
+/* static */ Status XlaHelpers::ReshapeLiteral(
+ const xla::Literal& input, gtl::ArraySlice<int64> dimensions,
+ xla::Literal* output) {
+ if (xla::ShapeUtil::IsTuple(input.shape())) {
+ return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
+ }
+ xla::Shape shape =
+ xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
+ int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape());
+ int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
+ if (elements_before != elements_after) {
+ return errors::InvalidArgument(
+ "Shapes before and after ReshapeLiteral have different numbers of "
+ "elements.");
+ }
+
+ *output = input;
+ output->mutable_shape()->Swap(&shape);
+ return Status::OK();
+}
+
+} // end namespace tensorflow