Add Kahan summation to tinyBLAS Q8/Q4 on ARM
jart committed Apr 30, 2024
1 parent 540dc16 commit 2af3b88
Showing 2 changed files with 32 additions and 25 deletions.
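For readers unfamiliar with the technique: Kahan (compensated) summation keeps a second accumulator that captures the low-order bits rounded off each time a new term is folded into a running sum, so long dot products lose far less precision than naive accumulation. Here is a minimal scalar sketch of the idea in C; the names kahan_add and err are illustrative and not part of this commit:

    static inline float kahan_add(float sum, float x, float *err) {
        float y = x - *err;    // remove the error carried over from the previous add
        float t = sum + y;     // the low-order bits of y are lost in this add
        *err = (t - sum) - y;  // algebraically zero; captures the rounding error just made
        return t;
    }

The trick relies on strict IEEE evaluation order, so it must not be compiled with -ffast-math. The new badder() helper below is the NEON analogue: the same three steps applied lane-wise to float32x4_t vectors, with one compensation vector Ce[j][i] per accumulator Cv[j][i]. Note in the dispatch switch that the compensated path (KAHAN = true) is enabled only for the smaller tile shapes; the largest tiles (3x3, 3x2, 2x3) pass false, presumably because pairing a Ce register with every Cv accumulator would double their register pressure.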
1 change: 0 additions & 1 deletion llama.cpp/ggml.c
@@ -1882,7 +1882,6 @@ void ggml_once(atomic_uint * once, void init(void)) {
 
 inline static float ggml_silu_f32(float x) {
     // SiLU is the favored by LLaMA, Mistral, Phi, Rocket, etc.
-    // fprintf(stderr, "silu(%g)\n", x);
     return x/(1.f + expf(-x));
 }
 
56 changes: 32 additions & 24 deletions llamafile/tinyblas_cpu.inc
@@ -192,6 +192,15 @@ inline U madder(T a, T b, U c, U *e) {
     return t;
 }
 
+#ifdef __ARM_NEON
+inline float32x4_t badder(float32x4_t a, float b, float32x4_t c, float32x4_t *e) {
+    float32x4_t y = sub(vmulq_n_f32(a, b), *e);
+    float32x4_t t = add(c, y);
+    *e = sub(sub(t, c), y);
+    return t;
+}
+#endif
+
 #if defined(__FMA__)
 #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 template <>
@@ -609,47 +618,47 @@ class tinyBLAS_Q0_ARM {
         case 0x33:
             mc = 3;
             nc = 3;
-            gemm<3, 3>(m0, m, n0, n);
+            gemm<3, 3, false>(m0, m, n0, n);
             break;
         case 0x32:
             mc = 3;
             nc = 2;
-            gemm<3, 2>(m0, m, n0, n);
+            gemm<3, 2, false>(m0, m, n0, n);
             break;
         case 0x23:
             mc = 2;
             nc = 3;
-            gemm<2, 3>(m0, m, n0, n);
+            gemm<2, 3, false>(m0, m, n0, n);
             break;
         case 0x22:
             mc = 2;
             nc = 2;
-            gemm<2, 2>(m0, m, n0, n);
+            gemm<2, 2, true>(m0, m, n0, n);
             break;
         case 0x31:
             mc = 3;
             nc = 1;
-            gemm<3, 1>(m0, m, n0, n);
+            gemm<3, 1, true>(m0, m, n0, n);
             break;
         case 0x13:
             mc = 1;
             nc = 3;
-            gemm<1, 3>(m0, m, n0, n);
+            gemm<1, 3, true>(m0, m, n0, n);
             break;
         case 0x21:
             mc = 2;
             nc = 1;
-            gemm<2, 1>(m0, m, n0, n);
+            gemm<2, 1, true>(m0, m, n0, n);
             break;
         case 0x12:
             mc = 1;
             nc = 2;
-            gemm<1, 2>(m0, m, n0, n);
+            gemm<1, 2, true>(m0, m, n0, n);
             break;
         case 0x11:
             mc = 1;
             nc = 1;
-            gemm<1, 1>(m0, m, n0, n);
+            gemm<1, 1, true>(m0, m, n0, n);
             break;
         default:
             return;
@@ -660,7 +669,7 @@ class tinyBLAS_Q0_ARM {
         mnpack(m0, m, np, n);
     }
 
-    template <int RM, int RN>
+    template <int RM, int RN, int KAHAN>
     NOINLINE void gemm(long m0, long m, long n0, long n) {
         long ytiles = RM > 1 ? (m - m0) / RM : 1;
         long xtiles = RN > 1 ? (n - n0) / RN : 1;
@@ -674,18 +683,21 @@
         long ii = m0 + job / xtiles * RM;
         long jj = n0 + job % xtiles * RN;
         float32x4_t Cv[RN][RM] = {};
+        float32x4_t Ce[RN][RM] = {};
         for (int l = 0; l < k; ++l)
             for (int j = 0; j < RN; ++j)
-                for (int i = 0; i < RM; ++i)
-                    Cv[j][i] = vmlaq_n_f32(
-                        Cv[j][i],
-                        vcvtq_f32_s32(vdotq_s32(vdotq_s32(vdupq_n_s32(0),
-                                                          load_lo(INDEX(A, lda, ii + i, l)),
-                                                          load_lo(INDEX(B, ldb, jj + j, l))),
-                                                load_hi(INDEX(A, lda, ii + i, l)),
-                                                load_hi(INDEX(B, ldb, jj + j, l)))),
-                        (unhalf(INDEX(A, lda, ii + i, l)->d) *
-                         unhalf(INDEX(B, ldb, jj + j, l)->d)));
+                for (int i = 0; i < RM; ++i) {
+                    float32x4_t a = vcvtq_f32_s32(vdotq_s32(
+                        vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)),
+                                  load_lo(INDEX(B, ldb, jj + j, l))),
+                        load_hi(INDEX(A, lda, ii + i, l)), load_hi(INDEX(B, ldb, jj + j, l))));
+                    float b = unhalf(INDEX(A, lda, ii + i, l)->d) *
+                              unhalf(INDEX(B, ldb, jj + j, l)->d);
+                    if (KAHAN)
+                        Cv[j][i] = badder(a, b, Cv[j][i], &Ce[j][i]);
+                    else
+                        Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b);
+                }
         for (int j = 0; j < RN; ++j)
             for (int i = 0; i < RM; ++i)
                 *INDEX(C, ldc, jj + j, ii + i) = hsum(Cv[j][i]);
@@ -981,8 +993,6 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
         tb.matmul(m, n, task);
         return true;
 #elif defined(__ARM_NEON)
-        if (n < 4)
-            return false;
         if (k % 4)
             return false;
         tinyBLAS<0, 4, float32x4_t, float32x4_t, float, float, float> tb{
@@ -1078,8 +1088,6 @@ bool llamafile_sgemm(long m, long n, long k, const void *A, long lda, const void
         return false;
     }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
-        if (n < 2)
-            return false;
         if (k % 8)
             return false;
         if (Btype != GGML_TYPE_F16)
