Skip to content

Commit

Permalink
Merge pull request #21 from spcl/improve-adder
Browse files Browse the repository at this point in the history
Improve Adder
  • Loading branch information
definelicht authored Jan 20, 2022
2 parents f58acfc 500eab0 commit da5bae2
Showing 1 changed file with 55 additions and 38 deletions.
93 changes: 55 additions & 38 deletions device/ArithmeticOperations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,31 @@
#include "PipelinedAdd.h"

template <int bits>
inline bool IsMostSignificantBitSet(ap_uint<bits> const &num) {
bool IsMostSignificantBitSet(ap_uint<bits> const &num) {
#pragma HLS INLINE
return num.test(bits - 1);
}

template <int bits>
inline int CountLeadingZeros(ap_uint<bits> const &num) {
ap_uint<bits> DynamicRightShift(ap_uint<bits> num, ap_uint<hlslib::ConstLog2(bits)> const &shift_by) {
#pragma HLS INLINE
int leading_zeros = 0;
for (leading_zeros = 0; leading_zeros < bits; leading_zeros) {
if (num.test(bits - (leading_zeros 1))) {
break;
}
const auto kNumStages = hlslib::ConstLog2(bits);
for (int i = 0; i < kNumStages; i) {
#pragma HLS UNROLL
num = shift_by.test(i) ? (num >> (1 << i)) : num;
}
return leading_zeros;
return num;
}

/// Shifts `num` left by a runtime amount given in `shift_by`.
///
/// The shift is decomposed into a fixed sequence of stages: stage `i` applies
/// a constant left shift of `2^i` when bit `i` of `shift_by` is set and leaves
/// the value untouched otherwise. The number of stages equals the width of
/// `shift_by`, i.e. `hlslib::ConstLog2(bits)`.
///
/// @tparam bits     Width of the value being shifted.
/// @param num       Value to shift (taken by value and updated stage by stage).
/// @param shift_by  Shift amount; each of its bits enables one stage.
/// @return `num` shifted left by `shift_by`; bits moved past the top of the
///         `bits`-wide value are dropped by the conversion back to
///         `ap_uint<bits>`.
template <int bits>
ap_uint<bits> DynamicLeftShift(ap_uint<bits> num, ap_uint<hlslib::ConstLog2(bits)> const &shift_by) {
#pragma HLS INLINE
  const auto kNumStages = hlslib::ConstLog2(bits);
  // The stage index must advance every iteration; without the increment the
  // loop would test bit 0 forever and never terminate.
  for (int i = 0; i < kNumStages; ++i) {
#pragma HLS UNROLL
    // Stage i: constant shift by 2^i, selected by bit i of the shift amount.
    num = shift_by.test(i) ? (num << (1 << i)) : num;
  }
  return num;
}

PackedFloat Multiply(PackedFloat const &a, PackedFloat const &b) {
Expand Down Expand Up @@ -56,11 +66,20 @@ PackedFloat Multiply(PackedFloat const &a, PackedFloat const &b) {
// Does this correctly output the result if a and b are different signs?
// The mantissa of the result should depend on the sign bits of a and b
PackedFloat Add(PackedFloat const &a_in, PackedFloat const &b_in) {
const bool exp_are_equal = (a_in.GetExponent() == b_in.GetExponent());
const bool a_in_exp_strictly_larger = (a_in.GetExponent() > b_in.GetExponent());
const bool a_in_mant_is_zero = a_in.GetMantissa() == 0;
const bool b_in_mant_is_zero = b_in.GetMantissa() == 0;
const bool a_in_mantissa_larger = a_in.GetMantissa() >= b_in.GetMantissa();

// Retrieve once and for all to make sure there's no overhead from unpacking them
const bool _a_sign = a_in.GetSign();
const bool _b_sign = b_in.GetSign();
const Exponent _a_exponent = a_in.GetExponent();
const Exponent _b_exponent = b_in.GetExponent();
const MantissaFlat _a_mantissa = a_in.GetMantissa();
const MantissaFlat _b_mantissa = b_in.GetMantissa();

const bool exp_are_equal = (_a_exponent == _b_exponent);
const bool a_in_exp_strictly_larger = (_a_exponent > _b_exponent);
const bool a_in_mant_is_zero = _a_mantissa == 0;
const bool b_in_mant_is_zero = _b_mantissa == 0;
const bool a_in_mantissa_larger = _a_mantissa >= _b_mantissa;

// Plain comparison of exponent and mantissa
const bool a_larger_if_both_nonzero = a_in_exp_strictly_larger || (exp_are_equal && a_in_mantissa_larger);
Expand All @@ -71,47 +90,47 @@ PackedFloat Add(PackedFloat const &a_in, PackedFloat const &b_in) {

// We always have a >= b to simplify the code
// a is zero iff b is zero
const PackedFloat a = a_is_larger ? a_in : b_in;
const PackedFloat b = a_is_larger ? b_in : a_in;
const bool a_is_zero = a_is_larger ? a_in_mant_is_zero : b_in_mant_is_zero;
const Exponent a_exponent = a_is_larger ? _a_exponent : _b_exponent;
const Exponent b_exponent = a_is_larger ? _b_exponent : _a_exponent;
const MantissaFlat a_mantissa = a_is_larger ? _a_mantissa : _b_mantissa;
const MantissaFlat b_mantissa = a_is_larger ? _b_mantissa : _a_mantissa;
const bool a_sign = a_is_larger ? _a_sign : _b_sign;
const bool b_sign = a_is_larger ? _b_sign : _a_sign;

const ap_uint<kMantissaBits> a_mantissa(a.GetMantissa());
const ap_uint<kMantissaBits> b_mantissa(b.GetMantissa());
const bool a_is_zero = a_is_larger ? a_in_mant_is_zero : b_in_mant_is_zero;

#ifndef HLSLIB_SYNTHESIS
// We better not be getting subnormal inputs
assert(a.IsZero() || IsMostSignificantBitSet(a_mantissa));
assert(b.IsZero() || IsMostSignificantBitSet(b_mantissa));
assert(a_in.IsZero() || IsMostSignificantBitSet(a_in.GetMantissa()));
assert(b_in.IsZero() || IsMostSignificantBitSet(b_in.GetMantissa()));
// a is zero => b is zero
assert(!a_is_zero || (a_is_zero && a_in_mant_is_zero && b_in_mant_is_zero));
#endif

const bool subtraction = a.GetSignBit() != b.GetSignBit();
Exponent res_exponent = a.GetExponent();
const Exponent shift_m = a.GetExponent() - b.GetExponent();
const bool subtraction = a_sign != b_sign;
Exponent res_exponent = a_exponent;
using Shift = ap_uint<8 * sizeof(Exponent) - 2>;
const Shift shift_m = a_exponent - b_exponent;

// Figure out how much we need to shift by
// Xilinx permits signed shifts
// We want to keep an extra bit of precision (LSB) to properly round the output
// We also want an extra bit of range (MSB) to track overflow
// The names in the following code segment have _msb/_lsb suffix if they have the extra msb/lsb respectively
constexpr int lsb_bits = kMantissaBits;
auto a_mantissa_shifted = static_cast<ap_uint<kMantissaBits 1 lsb_bits>>(a_mantissa) << lsb_bits;
auto b_mantissa_shifted = (static_cast<ap_uint<kMantissaBits 1 lsb_bits>>(b_mantissa) << lsb_bits) >> shift_m;
using MantissaExtended = ap_uint<2 * kMantissaBits 1>;
auto a_mantissa_shifted = MantissaExtended(a_mantissa) << kMantissaBits;
auto b_mantissa_shifted = DynamicRightShift(MantissaExtended(b_mantissa) << kMantissaBits, shift_m);

// Now we can add up the aligned mantissas
// ==== Add/Sub mantissas ====
// We cannot truncate yet because of the renormalization step
const ap_uint<kMantissaBits 1 lsb_bits> ab_sum_lsb_msb =
PipelinedAdd<kMantissaBits lsb_bits>(a_mantissa_shifted, b_mantissa_shifted);
const MantissaExtended ab_sum_lsb_msb = PipelinedAdd<2 * kMantissaBits>(a_mantissa_shifted, b_mantissa_shifted);

// This returns an ap_int but the answer is always positive so the MSB is never set
// Xilinx manual states signed <-> unsigned ignores the sign and converts bit for bit
// Widening assignments and right shifts of ap_int are sign extended so we specify the casting route
assert(a_mantissa_shifted >= b_mantissa_shifted);
ap_uint<kMantissaBits 1 lsb_bits> ab_diff_lsb_msb =
static_cast<ap_uint<kMantissaBits 1 lsb_bits>>(PipelinedSub(a_mantissa_shifted, b_mantissa_shifted));
#pragma HLS BIND_OP variable = ab_diff_lsb_msb op = sub impl = fabric latency = 4
MantissaExtended ab_diff_lsb_msb = MantissaExtended(PipelinedSub(a_mantissa_shifted, b_mantissa_shifted));
assert(!IsMostSignificantBitSet(ab_diff_lsb_msb));

// ==== overflow check ====
Expand All @@ -122,20 +141,18 @@ PackedFloat Add(PackedFloat const &a_in, PackedFloat const &b_in) {

const bool addition_overflowed = IsMostSignificantBitSet(_res_mantissa_lsb_msb);
// We're still holding onto the extra lsb
const ap_uint<kMantissaBits lsb_bits> res_mantissa_lsb =
const MantissaExtended res_mantissa_lsb =
addition_overflowed ? (_res_mantissa_lsb_msb >> 1) : _res_mantissa_lsb_msb;
res_exponent = res_exponent (addition_overflowed ? 1 : 0);

// ==== Renormalize / Underflow ====
// Normalize the mantissa
bool res_nonzero = res_mantissa_lsb != 0;
const Exponent leading_zeros =
res_nonzero ? CountLeadingZeros(static_cast<ap_uint<kMantissaBits>>(res_mantissa_lsb >> kMantissaBits))
: kMantissaBits;
const ap_uint<hlslib::ConstLog2(kMantissaBits)> leading_zeros =
MantissaFlat(res_mantissa_lsb >> kMantissaBits).countLeadingZeros();

// Left shift by the number of leading zeros and truncate the lsb now
ap_uint<kMantissaBits> res_mantissa =
(res_nonzero ? (res_mantissa_lsb << leading_zeros) : decltype(res_mantissa_lsb)(0)) >> lsb_bits;
ap_uint<kMantissaBits> res_mantissa = DynamicLeftShift(res_mantissa_lsb, leading_zeros) >> kMantissaBits;

// We need to watch for underflow here
const bool underflow = res_exponent < std::numeric_limits<Exponent>::min() leading_zeros;
Expand All @@ -156,7 +173,7 @@ PackedFloat Add(PackedFloat const &a_in, PackedFloat const &b_in) {
result.SetMantissa(res_mantissa);
result.SetExponent(res_exponent);
// Sign will be the same as whatever is the largest number
result.SetSign(a.GetSignBit());
result.SetSign(a_sign);

return a_is_zero ? PackedFloat::Zero() : result;
}
Expand Down

0 comments on commit da5bae2

Please sign in to comment.