aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/lib/constants.cc
blob: 031d62e4ffef188082303a28866bbc72a154e9b1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/client/lib/constants.h"

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {

// Emits a scalar constant on `builder` holding LiteralUtil::Zero(type).
XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
  auto zero_literal = LiteralUtil::Zero(type);
  return ConstantLiteral(builder, zero_literal);
}

// Emits an op of shape `shape` filled with zeros. A scalar Zero() of the
// shape's element type is broadcast across every dimension of `shape`.
XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
  XlaOp scalar_zero = Zero(builder, shape.element_type());
  return Broadcast(scalar_zero, AsInt64Slice(shape.dimensions()));
}

// Emits zeros with the same shape as `prototype`, on prototype's builder.
// If the prototype's shape cannot be queried, the error status is handed to
// the builder via ReportErrorOrReturn instead of producing an op.
XlaOp ZerosLike(XlaOp prototype) {
  XlaBuilder* const b = prototype.builder();
  // Invoked synchronously by ReportErrorOrReturn, so the reference capture
  // never outlives this frame.
  auto make_zeros = [&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape prototype_shape, b->GetShape(prototype));
    return Zeros(b, prototype_shape);
  };
  return b->ReportErrorOrReturn(make_zeros);
}

// Emits a scalar constant on `builder` holding LiteralUtil::One(type).
XlaOp One(XlaBuilder* builder, PrimitiveType type) {
  auto one_literal = LiteralUtil::One(type);
  return ConstantLiteral(builder, one_literal);
}

// Emits a scalar constant holding the machine epsilon of `type`.
// Only F16, BF16, F32 and F64 are supported. Any other element type causes
// an InvalidArgument error to be reported on `builder`.
XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
  switch (type) {
    case F16: {
      const auto half_eps = static_cast<Eigen::half>(
          Eigen::NumTraits<Eigen::half>::epsilon());
      return ConstantR0<Eigen::half>(builder, half_eps);
    }
    case BF16: {
      const auto bf16_eps = bfloat16::epsilon();
      return ConstantR0<bfloat16>(builder, bf16_eps);
    }
    case F32: {
      const float f32_eps = std::numeric_limits<float>::epsilon();
      return ConstantR0<float>(builder, f32_eps);
    }
    case F64: {
      const double f64_eps = std::numeric_limits<double>::epsilon();
      return ConstantR0<double>(builder, f64_eps);
    }
    default:
      break;
  }
  // Reached only for element types without a defined epsilon here.
  return builder->ReportError(InvalidArgument(
      "Invalid type for Epsilon (%s).", PrimitiveType_Name(type).c_str()));
}

// Emits a scalar constant holding LiteralUtil::MinValue(type).
// NOTE(review): for floating-point types this is presumably -infinity; see
// MinFiniteValue for the finite alternative. Confirm against LiteralUtil.
XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
  auto min_literal = LiteralUtil::MinValue(type);
  return ConstantLiteral(builder, min_literal);
}

// Emits a scalar constant holding the most negative finite value of `type`.
// F64, F32, BF16 and F16 are handled explicitly. Every other type is
// delegated to MinValue().
XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
  switch (type) {
    case F64:
      return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
    case F32:
      return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
    case BF16:
      return ConstantR0<bfloat16>(builder, bfloat16::lowest());
    case F16:
      return ConstantR0<Eigen::half>(builder,
                                     Eigen::NumTraits<Eigen::half>::lowest());
    default:
      break;
  }
  return MinValue(builder, type);
}

// Emits a scalar constant holding LiteralUtil::MaxValue(type).
// NOTE(review): for floating-point types this is presumably +infinity; see
// MaxFiniteValue for the finite alternative. Confirm against LiteralUtil.
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
  auto max_literal = LiteralUtil::MaxValue(type);
  return ConstantLiteral(builder, max_literal);
}

// Emits a scalar constant holding the largest finite value of `type`.
// F64, F32, BF16 and F16 are handled explicitly. Every other type is
// delegated to MaxValue().
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
  switch (type) {
    case F64:
      return ConstantR0<double>(builder, std::numeric_limits<double>::max());
    case F32:
      return ConstantR0<float>(builder, std::numeric_limits<float>::max());
    case BF16:
      return ConstantR0<bfloat16>(builder, bfloat16::highest());
    case F16:
      return ConstantR0<Eigen::half>(builder,
                                     Eigen::NumTraits<Eigen::half>::highest());
    default:
      break;
  }
  return MaxValue(builder, type);
}

}  // namespace xla