// Source: lesson_14_types.cpp (from a fork of halide/Halide).
// Halide tutorial lesson 14: The Halide type system
// This lesson more precisely describes Halide's type system.
// On linux, you can compile and run it like so:
// g++ lesson_14*.cpp -g -I ../include -L ../bin -lHalide -lpthread -ldl -o lesson_14 -std=c++11
// LD_LIBRARY_PATH=../bin ./lesson_14
// On os x:
// g++ lesson_14*.cpp -g -I ../include -L ../bin -lHalide -o lesson_14 -std=c++11
// DYLD_LIBRARY_PATH=../bin ./lesson_14
// If you have the entire Halide source tree, you can also build it by
// running:
// make tutorial_lesson_14_types
// in a shell with the current directory at the top of the halide
// source tree.
#include "Halide.h"
#include <stdio.h>
using namespace Halide;
// This function is used to demonstrate generic code at the end of
// this lesson. It averages two Exprs that have the same type and
// returns an Expr of that type. It is declared here so that main()
// can call it; the definition appears after main() at the end of
// this file.
Expr average(Expr a, Expr b);
int main(int argc, char **argv) {
// All Exprs have a scalar type, and all Funcs evaluate to one or
// more scalar types. The scalar types in Halide are unsigned
// integers of various bit widths, signed integers of the same set
// of bit widths, floating point numbers in single and double
// precision, and opaque handles (equivalent to void *). The
// following array contains all the legal types.
Type valid_halide_types[] = {
UInt(8), UInt(16), UInt(32), UInt(64),
Int(8), Int(16), Int(32), Int(64),
Float(32), Float(64), Handle()
};
// Constructing and inspecting types.
{
// You can programmatically examine the properties of a Halide
// type. This is useful when you write a C function that has
// Expr arguments and you wish to check their types:
assert(UInt(8).bits() == 8);
assert(Int(8).is_int());
// You can also programmatically construct Types as a function of other Types.
Type t = UInt(8);
t = t.with_bits(t.bits() * 2);
assert(t == UInt(16));
// Or construct a Type from a C scalar type
assert(type_of<float>() == Float(32));
// The Type struct is also capable of representing vector types,
// but this is reserved for Halide's internal use. You should
// vectorize code by using Func::vectorize, not by attempting to
// construct vector expressions directly. You may encounter vector
// types if you programmatically manipulate lowered Halide code,
// but this is an advanced topic (see Func::add_custom_lowering_pass).
// You can query any Halide Expr for its type. An Expr
// representing a Var has type Int(32):
Var x;
assert(Expr(x).type() == Int(32));
// Most transcendental functions in Halide cast their inputs to a
// Float(32) and return a Float(32):
assert(sin(x).type() == Float(32));
// You can cast an Expr from one Type to another using the cast operator:
assert(cast(UInt(8), x).type() == UInt(8));
// This also comes in a template form that takes a C type.
assert(cast<uint8_t>(x).type() == UInt(8));
// You can also query any defined Func for the types it produces.
Func f1;
f1(x) = cast<uint8_t>(x);
assert(f1.output_types()[0] == UInt(8));
Func f2;
f2(x) = {x, sin(x)};
assert(f2.output_types()[0] == Int(32) &&
f2.output_types()[1] == Float(32));
}
// Type promotion rules.
{
// When you combine Exprs of different types (e.g. using ' ',
// '*', etc), Halide uses a system of type promotion
// rules. These differ to C's rules. To demonstrate these
// we'll make some Exprs of each type.
Var x;
Expr u8 = cast<uint8_t>(x);
Expr u16 = cast<uint16_t>(x);
Expr u32 = cast<uint32_t>(x);
Expr u64 = cast<uint64_t>(x);
Expr s8 = cast<int8_t>(x);
Expr s16 = cast<int16_t>(x);
Expr s32 = cast<int32_t>(x);
Expr s64 = cast<int64_t>(x);
Expr f32 = cast<float>(x);
Expr f64 = cast<double>(x);
// The rules are as follows, and are applied in the order they are
// written below.
// 1) It is an error to cast or use arithmetic operators on Exprs of type Handle().
// 2) If the types are the same, then no type conversions occur.
for (Type t : valid_halide_types) {
// Skip the handle type.
if (t.is_handle()) continue;
Expr e = cast(t, x);
assert((e e).type() == e.type());
}
// 3) If one type is a float but the other is not, then the
// non-float argument is promoted to a float (possibly causing a
// loss of precision for large integers).
assert((u8 f32).type() == Float(32));
assert((f32 s64).type() == Float(32));
assert((u16 f64).type() == Float(64));
assert((f64 s32).type() == Float(64));
// 4) If both types are float, then the narrower argument is
// promoted to the wider bit-width.
assert((f64 f32).type() == Float(64));
// The rules above handle all the floating-point cases. The
// following three rules handle the integer cases.
// 5) If one of the arguments is an C int, and the other is
// a Halide::Expr, then the int is coerced to the type of the
// expression.
assert((u32 3).type() == UInt(32));
assert((3 s16).type() == Int(16));
// If this rule would cause the integer to overflow, then Halide
// will trigger an error, e.g. uncommenting the following line
// will cause this program to terminate with an error.
// Expr bad = u8 257;
// 6) If both types are unsigned integers, or both types are
// signed integers, then the narrower argument is promoted to
// wider type.
assert((u32 u8).type() == UInt(32));
assert((s16 s64).type() == Int(64));
// 7) If one type is signed and the other is unsigned, both
// arguments are promoted to a signed integer with the greater of
// the two bit widths.
assert((u8 s32).type() == Int(32));
assert((u32 s8).type() == Int(32));
// Note that this may silently overflow the unsigned type in the
// case where the bit widths are the same.
assert((u32 s32).type() == Int(32));
// When an unsigned Expr is converted to a wider signed type in
// this way, it is first widened to a wider unsigned type
// (zero-extended), and then reinterpreted as a signed
// integer. I.e. casting the UInt(8) value 255 to an Int(32)
// produces 255, not -1.
int32_t result32 = evaluate<int>(cast<int32_t>(cast<uint8_t>(255)));
assert(result32 == 255);
// When a signed type is explicitly converted to a wider unsigned
// type with the cast operator (the type promotion rules will
// never do this automatically), it is first converted to the
// wider signed type (sign-extended), and then reinterpreted as
// an unsigned integer. I.e. casting the Int(8) value -1 to a
// UInt(16) produces 65535, not 255.
uint16_t result16 = evaluate<uint16_t>(cast<uint16_t>(cast<int8_t>(-1)));
assert(result16 == 65535);
}
// The type Handle().
{
// Handle is used to represent opaque pointers. Applying
// type_of to any pointer type will return Handle()
assert(type_of<void *>() == Handle());
assert(type_of<const char * const **>() == Handle());
// Handles are always stored as 64-bit, regardless of the compilation
// target.
assert(Handle().bits() == 64);
// The main use of an Expr of type Handle is to pass
// it through Halide to other external code.
}
// Generic code.
{
// The main explicit use of Type in Halide is to write Halide
// code parameterized by a Type. In C you'd do this with
// templates. In Halide there's no need - you can inspect and
// modify the types dynamically at C runtime instead. The
// function defined below averages two expressions of any
// equal numeric type.
Var x;
assert(average(cast<float>(x), 3.0f).type() == Float(32));
assert(average(x, 3).type() == Int(32));
assert(average(cast<uint8_t>(x), cast<uint8_t>(3)).type() == UInt(8));
}
printf("Success!\n");
return 0;
}
// Averages two Exprs of the same numeric type and returns an Expr of
// that same type. This is the "generic code" example: the function
// inspects the argument Type at C++ runtime instead of being a C++
// template.
//
//   a, b : Exprs whose types must be identical (checked with assert).
//   returns: (a + b) / 2, computed so that the intermediate sum cannot
//            overflow the argument type. For integer types the result
//            uses Halide's integer division, which rounds toward
//            negative infinity.
Expr average(Expr a, Expr b) {
    // Types must match.
    assert(a.type() == b.type());

    // For floating point types:
    if (a.type().is_float()) {
        // The '2' will be promoted to the floating point type due to
        // rule 3 above.
        return (a + b) / 2;
    }

    Type narrow = a.type();

    if (narrow.bits() >= 64) {
        // 64 bits is the widest legal Halide integer type, so there is
        // no wider type to compute the sum in. Instead, average the
        // halves: with a = 2p + r and b = 2q + s (r, s in {0, 1}),
        // floor((a + b) / 2) == p + q + r * s. This relies on Halide's
        // integer '/' rounding toward negative infinity and '%' being
        // non-negative for a positive divisor. The literals are coerced
        // to the Expr's type by rule 5, and no intermediate can overflow.
        return a / 2 + b / 2 + (a % 2) * (b % 2);
    }

    // For narrower integer types, we compute the intermediate value in a
    // type of twice the width to avoid overflow, then narrow the result,
    // which always fits back into the original type.
    Type wider = narrow.with_bits(narrow.bits() * 2);
    a = cast(wider, a);
    b = cast(wider, b);
    return cast(narrow, (a + b) / 2);
}