Skip to content

Commit

Permalink
change the SIMD operator classes to handle comparisons separately from the IfThenElse body
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbkoch committed Oct 18, 2024
1 parent 2b9035e commit f64b70f
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 131 deletions.
61 changes: 22 additions & 39 deletions shared/libebm/compute/avx2_ebm/avx2_32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ inline Avx2_32_Float Log(const Avx2_32_Float& val) noexcept;

struct alignas(k_cAlignment) Avx2_32_Int final {
friend Avx2_32_Float;
friend inline Avx2_32_Float IfEqual(const Avx2_32_Int& cmp1,
const Avx2_32_Int& cmp2,
const Avx2_32_Float& trueVal,
const Avx2_32_Float& falseVal) noexcept;

using T = uint32_t;
using TPack = __m256i;
Expand Down Expand Up @@ -107,6 +103,10 @@ struct alignas(k_cAlignment) Avx2_32_Int final {
return Avx2_32_Int(_mm256_xor_si256(m_data, _mm256_set1_epi32(-1)));
}

// Lane-wise equality: each 32-bit lane becomes all-ones when equal, all-zeros otherwise.
friend inline Avx2_32_Int operator==(const Avx2_32_Int& left, const Avx2_32_Int& right) noexcept {
   const __m256i laneMask = _mm256_cmpeq_epi32(left.m_data, right.m_data);
   return Avx2_32_Int(laneMask);
}

// Lane-wise addition of eight packed 32-bit integers.
inline Avx2_32_Int operator+(const Avx2_32_Int& other) const noexcept {
   const __m256i sum = _mm256_add_epi32(m_data, other.m_data);
   return Avx2_32_Int(sum);
}
Expand Down Expand Up @@ -262,6 +262,14 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
return Avx2_32_Float(val) / other;
}

// Lane-wise float equality, returned as an integer mask usable with IfThenElse.
// _CMP_EQ_OQ is the ordered, quiet predicate, so NaN lanes compare false.
friend inline Avx2_32_Int operator==(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
   const __m256 laneMask = _mm256_cmp_ps(left.m_data, right.m_data, _CMP_EQ_OQ);
   return ReinterpretInt(Avx2_32_Float(laneMask));
}

// Lane-wise float less-than, returned as an integer mask usable with IfThenElse.
// _CMP_LT_OQ is the ordered, quiet predicate, so NaN lanes compare false.
friend inline Avx2_32_Int operator<(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
   const __m256 laneMask = _mm256_cmp_ps(left.m_data, right.m_data, _CMP_LT_OQ);
   return ReinterpretInt(Avx2_32_Float(laneMask));
}

// Lane-wise float less-than-or-equal, returned as an integer mask usable with IfThenElse.
// _CMP_LE_OQ is the ordered, quiet predicate, so NaN lanes compare false.
friend inline Avx2_32_Int operator<=(const Avx2_32_Float& left, const Avx2_32_Float& right) noexcept {
   const __m256 laneMask = _mm256_cmp_ps(left.m_data, right.m_data, _CMP_LE_OQ);
   return ReinterpretInt(Avx2_32_Float(laneMask));
}
Expand Down Expand Up @@ -524,14 +532,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
func(7, a0[7], a1[7], a2[7], a3[7], a4[7]);
}

friend inline Avx2_32_Float IfLess(const Avx2_32_Float& cmp1,
const Avx2_32_Float& cmp2,
const Avx2_32_Float& trueVal,
const Avx2_32_Float& falseVal) noexcept {
const __m256 mask = _mm256_cmp_ps(cmp1.m_data, cmp2.m_data, _CMP_LT_OQ);
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, mask));
}

friend inline Avx2_32_Float IfThenElse(
const Avx2_32_Int& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, ReinterpretFloat(cmp).m_data));
Expand All @@ -542,14 +542,6 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
return base + ReinterpretFloat(cmp & ReinterpretInt(addend));
}

friend inline Avx2_32_Float IfEqual(const Avx2_32_Float& cmp1,
const Avx2_32_Float& cmp2,
const Avx2_32_Float& trueVal,
const Avx2_32_Float& falseVal) noexcept {
const __m256 mask = _mm256_cmp_ps(cmp1.m_data, cmp2.m_data, _CMP_EQ_OQ);
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, mask));
}

friend inline Avx2_32_Float IfNaN(
const Avx2_32_Float& cmp, const Avx2_32_Float& trueVal, const Avx2_32_Float& falseVal) noexcept {
// rely on the fact that a == a can only be false if a is a NaN
Expand All @@ -558,15 +550,7 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
// use an AND with _mm256_and_si256 to select just the NaN bits, then compare to zero with
// _mm256_cmpeq_epi32, but that has an overall latency of 2 and a throughput of 0.83333, which is lower
// throughput, so experiment with this
return IfEqual(cmp, cmp, falseVal, trueVal);
}

friend inline Avx2_32_Float IfEqual(const Avx2_32_Int& cmp1,
const Avx2_32_Int& cmp2,
const Avx2_32_Float& trueVal,
const Avx2_32_Float& falseVal) noexcept {
const __m256i mask = _mm256_cmpeq_epi32(cmp1.m_data, cmp2.m_data);
return Avx2_32_Float(_mm256_blendv_ps(falseVal.m_data, trueVal.m_data, _mm256_castsi256_ps(mask)));
return IfThenElse(cmp == cmp, falseVal, trueVal);
}

static inline Avx2_32_Int ReinterpretInt(const Avx2_32_Float& val) noexcept {
Expand Down Expand Up @@ -658,20 +642,20 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
#endif // EXP_INT_SIMD
Avx2_32_Float result = Avx2_32_Float(_mm256_castsi256_ps(retInt));
if(bSpecialCaseZero) {
result = IfEqual(0.0, val, 1.0, result);
result = IfThenElse(0.0 == val, 1.0, result);
}
if(bOverflowPossible) {
if(bNegateInput) {
result = IfLess(val, static_cast<T>(-k_expOverflowPoint), std::numeric_limits<T>::infinity(), result);
result = IfThenElse(val < static_cast<T>(-k_expOverflowPoint), std::numeric_limits<T>::infinity(), result);
} else {
result = IfLess(static_cast<T>(k_expOverflowPoint), val, std::numeric_limits<T>::infinity(), result);
result = IfThenElse(static_cast<T>(k_expOverflowPoint) < val, std::numeric_limits<T>::infinity(), result);
}
}
if(bUnderflowPossible) {
if(bNegateInput) {
result = IfLess(static_cast<T>(-k_expUnderflowPoint), val, 0.0, result);
result = IfThenElse(static_cast<T>(-k_expUnderflowPoint) < val, 0.0, result);
} else {
result = IfLess(val, static_cast<T>(k_expUnderflowPoint), 0.0, result);
result = IfThenElse(val < static_cast<T>(k_expUnderflowPoint), 0.0, result);
}
}
if(bNaNPossible) {
Expand Down Expand Up @@ -715,13 +699,13 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
Avx2_32_Float result = Avx2_32_Float(_mm256_cvtepi32_ps(retInt));
if(bNaNPossible) {
if(bPositiveInfinityPossible) {
result = IfLess(val, std::numeric_limits<T>::infinity(), result, val);
result = IfThenElse(val < std::numeric_limits<T>::infinity(), result, val);
} else {
result = IfNaN(val, val, result);
}
} else {
if(bPositiveInfinityPossible) {
result = IfEqual(std::numeric_limits<T>::infinity(), val, val, result);
result = IfThenElse(std::numeric_limits<T>::infinity() == val, val, result);
}
}
if(bNegateOutput) {
Expand All @@ -730,13 +714,12 @@ struct alignas(k_cAlignment) Avx2_32_Float final {
result = FusedMultiplyAdd(result, k_logMultiple, addLogSchraudolphTerm);
}
if(bZeroPossible) {
result = IfLess(val,
std::numeric_limits<T>::min(),
result = IfThenElse(val < std::numeric_limits<T>::min(),
bNegateOutput ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(),
result);
}
if(bNegativePossible) {
result = IfLess(val, T{0}, std::numeric_limits<T>::quiet_NaN(), result);
result = IfThenElse(val < T{0}, std::numeric_limits<T>::quiet_NaN(), result);
}
return result;
}
Expand Down
61 changes: 22 additions & 39 deletions shared/libebm/compute/avx512f_ebm/avx512f_32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,6 @@ inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;

struct alignas(k_cAlignment) Avx512f_32_Int final {
friend Avx512f_32_Float;
friend inline Avx512f_32_Float IfEqual(const Avx512f_32_Int& cmp1,
const Avx512f_32_Int& cmp2,
const Avx512f_32_Float& trueVal,
const Avx512f_32_Float& falseVal) noexcept;

using T = uint32_t;
using TPack = __m512i;
Expand Down Expand Up @@ -111,6 +107,10 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
return Avx512f_32_Int(_mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0));
}

// Lane-wise equality of sixteen packed 32-bit integers. AVX-512 compares
// produce a 16-bit predicate mask (one bit per lane) rather than a lane-wide vector mask.
friend inline __mmask16 operator==(const Avx512f_32_Int& left, const Avx512f_32_Int& right) noexcept {
   const __mmask16 eqMask = _mm512_cmpeq_epi32_mask(left.m_data, right.m_data);
   return eqMask;
}

// Lane-wise addition of sixteen packed 32-bit integers.
inline Avx512f_32_Int operator+(const Avx512f_32_Int& other) const noexcept {
   const __m512i sum = _mm512_add_epi32(m_data, other.m_data);
   return Avx512f_32_Int(sum);
}
Expand Down Expand Up @@ -273,6 +273,14 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
return Avx512f_32_Float(val) / other;
}

// Lane-wise float equality as a 16-bit predicate mask.
// _CMP_EQ_OQ is the ordered, quiet predicate, so NaN lanes compare false.
friend inline __mmask16 operator==(const Avx512f_32_Float& left, const Avx512f_32_Float& right) noexcept {
   const __mmask16 eqMask = _mm512_cmp_ps_mask(left.m_data, right.m_data, _CMP_EQ_OQ);
   return eqMask;
}

// Lane-wise float less-than as a 16-bit predicate mask.
// _CMP_LT_OQ is the ordered, quiet predicate, so NaN lanes compare false.
friend inline __mmask16 operator<(const Avx512f_32_Float& left, const Avx512f_32_Float& right) noexcept {
   const __mmask16 ltMask = _mm512_cmp_ps_mask(left.m_data, right.m_data, _CMP_LT_OQ);
   return ltMask;
}

// Lane-wise float less-than-or-equal as a 16-bit predicate mask.
// _CMP_LE_OQ is the ordered, quiet predicate, so NaN lanes compare false.
friend inline __mmask16 operator<=(const Avx512f_32_Float& left, const Avx512f_32_Float& right) noexcept {
   const __mmask16 leMask = _mm512_cmp_ps_mask(left.m_data, right.m_data, _CMP_LE_OQ);
   return leMask;
}
Expand Down Expand Up @@ -583,14 +591,6 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
func(15, a0[15], a1[15], a2[15], a3[15], a4[15]);
}

friend inline Avx512f_32_Float IfLess(const Avx512f_32_Float& cmp1,
const Avx512f_32_Float& cmp2,
const Avx512f_32_Float& trueVal,
const Avx512f_32_Float& falseVal) noexcept {
const __mmask16 mask = _mm512_cmp_ps_mask(cmp1.m_data, cmp2.m_data, _CMP_LT_OQ);
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
}

friend inline Avx512f_32_Float IfThenElse(
const __mmask16& cmp, const Avx512f_32_Float& trueVal, const Avx512f_32_Float& falseVal) noexcept {
return Avx512f_32_Float(_mm512_mask_blend_ps(cmp, falseVal.m_data, trueVal.m_data));
Expand All @@ -602,14 +602,6 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
}


friend inline Avx512f_32_Float IfEqual(const Avx512f_32_Float& cmp1,
const Avx512f_32_Float& cmp2,
const Avx512f_32_Float& trueVal,
const Avx512f_32_Float& falseVal) noexcept {
const __mmask16 mask = _mm512_cmp_ps_mask(cmp1.m_data, cmp2.m_data, _CMP_EQ_OQ);
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
}

friend inline Avx512f_32_Float IfNaN(
const Avx512f_32_Float& cmp, const Avx512f_32_Float& trueVal, const Avx512f_32_Float& falseVal) noexcept {
// rely on the fact that a == a can only be false if a is a NaN
Expand All @@ -618,15 +610,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
// use an AND with _mm256_and_si256 to select just the NaN bits, then compare to zero with
// _mm256_cmpeq_epi32, but that has an overall latency of 2 and a throughput of 0.83333, which is lower
// throughput, so experiment with this
return IfEqual(cmp, cmp, falseVal, trueVal);
}

friend inline Avx512f_32_Float IfEqual(const Avx512f_32_Int& cmp1,
const Avx512f_32_Int& cmp2,
const Avx512f_32_Float& trueVal,
const Avx512f_32_Float& falseVal) noexcept {
const __mmask16 mask = _mm512_cmpeq_epi32_mask(cmp1.m_data, cmp2.m_data);
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
return IfThenElse(cmp == cmp, falseVal, trueVal);
}

static inline Avx512f_32_Int ReinterpretInt(const Avx512f_32_Float& val) noexcept {
Expand Down Expand Up @@ -713,20 +697,20 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
#endif // EXP_INT_SIMD
Avx512f_32_Float result = Avx512f_32_Float(_mm512_castsi512_ps(retInt));
if(bSpecialCaseZero) {
result = IfEqual(0.0, val, 1.0, result);
result = IfThenElse(0.0 == val, 1.0, result);
}
if(bOverflowPossible) {
if(bNegateInput) {
result = IfLess(val, static_cast<T>(-k_expOverflowPoint), std::numeric_limits<T>::infinity(), result);
result = IfThenElse(val < static_cast<T>(-k_expOverflowPoint), std::numeric_limits<T>::infinity(), result);
} else {
result = IfLess(static_cast<T>(k_expOverflowPoint), val, std::numeric_limits<T>::infinity(), result);
result = IfThenElse(static_cast<T>(k_expOverflowPoint) < val, std::numeric_limits<T>::infinity(), result);
}
}
if(bUnderflowPossible) {
if(bNegateInput) {
result = IfLess(static_cast<T>(-k_expUnderflowPoint), val, 0.0, result);
result = IfThenElse(static_cast<T>(-k_expUnderflowPoint) < val, 0.0, result);
} else {
result = IfLess(val, static_cast<T>(k_expUnderflowPoint), 0.0, result);
result = IfThenElse(val < static_cast<T>(k_expUnderflowPoint), 0.0, result);
}
}
if(bNaNPossible) {
Expand Down Expand Up @@ -770,13 +754,13 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
Avx512f_32_Float result = Avx512f_32_Float(_mm512_cvtepi32_ps(retInt));
if(bNaNPossible) {
if(bPositiveInfinityPossible) {
result = IfLess(val, std::numeric_limits<T>::infinity(), result, val);
result = IfThenElse(val < std::numeric_limits<T>::infinity(), result, val);
} else {
result = IfNaN(val, val, result);
}
} else {
if(bPositiveInfinityPossible) {
result = IfEqual(std::numeric_limits<T>::infinity(), val, val, result);
result = IfThenElse(std::numeric_limits<T>::infinity() == val, val, result);
}
}
if(bNegateOutput) {
Expand All @@ -785,13 +769,12 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
result = FusedMultiplyAdd(result, k_logMultiple, addLogSchraudolphTerm);
}
if(bZeroPossible) {
result = IfLess(val,
std::numeric_limits<T>::min(),
result = IfThenElse(val < std::numeric_limits<T>::min(),
bNegateOutput ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(),
result);
}
if(bNegativePossible) {
result = IfLess(val, T{0}, std::numeric_limits<T>::quiet_NaN(), result);
result = IfThenElse(val < T{0}, std::numeric_limits<T>::quiet_NaN(), result);
}
return result;
}
Expand Down
40 changes: 13 additions & 27 deletions shared/libebm/compute/cpu_ebm/cpu_64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ inline Cpu_64_Float Log(const Cpu_64_Float& val) noexcept;

struct Cpu_64_Int final {
friend Cpu_64_Float;
friend inline Cpu_64_Float IfEqual(const Cpu_64_Int& cmp1,
const Cpu_64_Int& cmp2,
const Cpu_64_Float& trueVal,
const Cpu_64_Float& falseVal) noexcept;
friend inline Cpu_64_Float IfThenElse(
const Cpu_64_Int& cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept;
friend inline Cpu_64_Float IfAdd(
Expand Down Expand Up @@ -91,6 +87,10 @@ struct Cpu_64_Int final {

// Bitwise complement of the single 64-bit "lane".
inline Cpu_64_Int operator~() const noexcept {
   const T flipped = ~m_data;
   return Cpu_64_Int(flipped);
}

// Scalar equality, but with SIMD mask semantics: all bits set when equal,
// all bits clear otherwise, so the result can be negated with ~ or AND/OR-ed.
friend inline Cpu_64_Int operator==(const Cpu_64_Int& left, const Cpu_64_Int& right) noexcept {
   if(left.m_data == right.m_data) {
      return Cpu_64_Int{static_cast<uint64_t>(int64_t{-1})};
   }
   return Cpu_64_Int{0};
}

// Scalar addition of the single 64-bit "lane" (unsigned, so overflow wraps).
inline Cpu_64_Int operator+(const Cpu_64_Int& other) const noexcept {
   const T sum = m_data + other.m_data;
   return Cpu_64_Int(sum);
}

// Scalar subtraction of the single 64-bit "lane" (unsigned, so underflow wraps).
inline Cpu_64_Int operator-(const Cpu_64_Int& other) const noexcept {
   const T difference = m_data - other.m_data;
   return Cpu_64_Int(difference);
}
Expand Down Expand Up @@ -222,9 +222,16 @@ struct Cpu_64_Float final {
return Cpu_64_Float(val) / other;
}

// Scalar float equality with SIMD mask semantics: all-ones when equal, zero otherwise.
// NaN operands compare false, matching the ordered SIMD predicates.
friend inline Cpu_64_Int operator==(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
   if(left.m_data == right.m_data) {
      return Cpu_64_Int{static_cast<uint64_t>(int64_t{-1})};
   }
   return Cpu_64_Int{0};
}

// Scalar float less-than with SIMD mask semantics: all-ones when true, zero otherwise.
// NaN operands compare false, matching the ordered SIMD predicates.
friend inline Cpu_64_Int operator<(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
   if(left.m_data < right.m_data) {
      return Cpu_64_Int{static_cast<uint64_t>(int64_t{-1})};
   }
   return Cpu_64_Int{0};
}

// Scalar float less-than-or-equal with SIMD mask semantics.
friend inline Cpu_64_Int operator<=(const Cpu_64_Float& left, const Cpu_64_Float& right) noexcept {
   // use all bits of an equally wide datatype so that we can negate it with ~, or AND/OR it
   // NOTE: the original contained a second, unreachable duplicate of this return statement;
   // the dead line has been removed without any behavior change.
   return left.m_data <= right.m_data ? Cpu_64_Int{static_cast<uint64_t>(int64_t{-1})} : Cpu_64_Int{0};
}

// Load the single scalar "lane" from memory.
inline static Cpu_64_Float Load(const T* const a) noexcept {
   return Cpu_64_Float(a[0]);
}
Expand All @@ -248,13 +255,6 @@ struct Cpu_64_Float final {
func(0, (args.m_data)...);
}

friend inline Cpu_64_Float IfLess(const Cpu_64_Float& cmp1,
const Cpu_64_Float& cmp2,
const Cpu_64_Float& trueVal,
const Cpu_64_Float& falseVal) noexcept {
return cmp1.m_data < cmp2.m_data ? trueVal : falseVal;
}

friend inline Cpu_64_Float IfThenElse(
const Cpu_64_Int& cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept {
return cmp.m_data ? trueVal : falseVal;
Expand All @@ -265,25 +265,11 @@ struct Cpu_64_Float final {
return cmp.m_data ? base + addend : base;
}

friend inline Cpu_64_Float IfEqual(const Cpu_64_Float& cmp1,
const Cpu_64_Float& cmp2,
const Cpu_64_Float& trueVal,
const Cpu_64_Float& falseVal) noexcept {
return cmp1.m_data == cmp2.m_data ? trueVal : falseVal;
}

// Select trueVal when cmp is NaN, falseVal otherwise (scalar stand-in for the SIMD IfNaN).
friend inline Cpu_64_Float IfNaN(
      const Cpu_64_Float& cmp, const Cpu_64_Float& trueVal, const Cpu_64_Float& falseVal) noexcept {
   if(std::isnan(cmp.m_data)) {
      return trueVal;
   }
   return falseVal;
}

friend inline Cpu_64_Float IfEqual(const Cpu_64_Int& cmp1,
const Cpu_64_Int& cmp2,
const Cpu_64_Float& trueVal,
const Cpu_64_Float& falseVal) noexcept {
return cmp1.m_data == cmp2.m_data ? trueVal : falseVal;
}

static inline Cpu_64_Int ReinterpretInt(const Cpu_64_Float& val) noexcept {
typename Cpu_64_Int::T mem;
memcpy(&mem, &val.m_data, sizeof(T));
Expand Down
Loading

0 comments on commit f64b70f

Please sign in to comment.