aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/constants.h
blob: b47f5243f008ecb2045456e4505d1a571fbed745 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_

#include <type_traits>

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is
// determined at C++ run-time, rather than C++ compile-time.
// If 'value' is floating point but 'type' is not, or if 'value' is complex but
// 'type' is not, an error will be returned. This is to catch accidental
// truncation; in such cases, use an explicit cast.
template <typename T>
XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
  if (std::is_floating_point<T>::value &&
      !(primitive_util::IsFloatingPointType(type) ||
        primitive_util::IsComplexType(type))) {
    return builder->ReportError(InvalidArgument(
        "Invalid cast from floating point type to %s in ConstantR0WithType.",
        PrimitiveType_Name(type).c_str()));
  }
  if (std::is_same<T, complex64>::value &&
      !primitive_util::IsComplexType(type)) {
    return builder->ReportError(InvalidArgument(
        "Invalid cast from complex type to %s in ConstantR0WithType.",
        PrimitiveType_Name(type).c_str()));
  }
  switch (type) {
    case F16:
      return ConstantR0<half>(builder, static_cast<half>(value));
    case BF16:
      return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
    case F32:
      return ConstantR0<float>(builder, static_cast<float>(value));
    case F64:
      return ConstantR0<double>(builder, static_cast<double>(value));
    case C64:
      return ConstantR0<complex64>(builder, static_cast<complex64>(value));
    case U8:
      return ConstantR0<uint8>(builder, static_cast<uint8>(value));
    case U32:
      return ConstantR0<uint32>(builder, static_cast<uint32>(value));
    case U64:
      return ConstantR0<uint64>(builder, static_cast<uint64>(value));
    case S8:
      return ConstantR0<int8>(builder, static_cast<int8>(value));
    case S32:
      return ConstantR0<int32>(builder, static_cast<int32>(value));
    case S64:
      return ConstantR0<int64>(builder, static_cast<int64>(value));
    default:
      return builder->ReportError(
          InvalidArgument("Invalid type for ConstantR0WithType (%s).",
                          PrimitiveType_Name(type).c_str()));
  }
}

// Returns a scalar containing 'value' cast to the same run-time type as
// 'prototype'.
// If 'value' is floating point but 'prototype' is not, or if 'value' is complex
// 'prototype' is not, an error will be returned.
template <typename T>
XlaOp ScalarLike(XlaOp prototype, T value) {
  XlaBuilder* builder = prototype.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
    return ConstantR0WithType(builder, shape.element_type(), value);
  });
}

// Returns a scalar with value '0' of 'type'.
XlaOp Zero(XlaBuilder* builder, PrimitiveType type);

// Returns a zero-filled tensor with shape `shape`.
XlaOp Zeros(XlaBuilder* builder, const Shape& shape);

// Returns a zero-filled tensor with the same shape as `prototype`.
XlaOp ZerosLike(XlaOp prototype);

// Returns a scalar with value '1' of 'type'.
XlaOp One(XlaBuilder* builder, PrimitiveType type);

// Returns the machine epsilon for floating-point type `type`, i.e.,
// the difference between 1.0 and the next representable value.
XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum representable finite or infinite value for 'type'.
// Returns '-inf' for floating-point types.
XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum representable finite value for 'type'. For a floating
// point type, this is equal to -MaxFiniteValue().
XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);

// Returns the maximum representable finite or infinite value for 'type'.
// Returns 'inf' for floating-point types.
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);

// Returns the maximum representable finite value for 'type'.
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_