Skip to content

Commit 25b41a3

Browse files
committed
ggml : slight improvement of Q4_3 - no need for loop unrolling
1 parent a465988 commit 25b41a3

File tree

1 file changed

+12
-36
lines changed

1 file changed

+12
-36
lines changed

ggml.c

Lines changed: 12 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -2978,77 +2978,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
29782978
float32x4_t sumv0 = vdupq_n_f32(0.0f);
29792979
float32x4_t sumv1 = vdupq_n_f32(0.0f);
29802980

2981-
float summs = 0.0f;
2981+
float summs0 = 0.0f;
2982+
float summs1 = 0.0f;
29822983

2983-
for (int i = 0; i < nb; i += 2) {
2984+
for (int i = 0; i < nb; ++i) {
29842985
const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0];
29852986
const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1];
2986-
const block_q4_3 * restrict x1_0 = &x[2*(i + 1) + 0];
2987-
const block_q4_3 * restrict x1_1 = &x[2*(i + 1) + 1];
29882987

29892988
const block_q8_0 * restrict y0 = &y[i + 0];
2990-
const block_q8_0 * restrict y1 = &y[i + 1];
29912989

2992-
summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
2993-
summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1;
2994-
2995-
const uint8x16_t m4b = vdupq_n_u8(0xf);
2996-
2997-
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
2998-
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
2999-
const float x1_0d = GGML_FP16_TO_FP32(x1_0->d);
3000-
const float x1_1d = GGML_FP16_TO_FP32(x1_1->d);
2990+
summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
2991+
summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
30012992

30022993
const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
3003-
const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs));
30042994

30052995
// 4-bit -> 8-bit
3006-
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
2996+
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, vdupq_n_u8(0xf)));
30072997
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
3008-
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
3009-
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
30102998

30112999
// interleave
30123000
const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
30133001
const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
3014-
const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
3015-
const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
30163002

30173003
// load y
30183004
const int8x16_t v1_0l = vld1q_s8(y0->qs);
30193005
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
3020-
const int8x16_t v1_1l = vld1q_s8(y1->qs);
3021-
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
3006+
3007+
const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
3008+
const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
30223009

30233010
#if defined(__ARM_FEATURE_DOTPROD)
30243011
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
3025-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
3026-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), x1_0d*y1->d);
3027-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), x1_1d*y1->d);
3012+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
30283013
#else
30293014
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
30303015
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
30313016
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
30323017
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
30333018

3034-
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
3035-
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
3036-
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
3037-
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
3038-
30393019
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
30403020
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
3041-
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
3042-
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
30433021

30443022
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
3045-
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(ph0), x0_1d*y0->d);
3046-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(pl1), x1_0d*y1->d);
3047-
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph1), x1_1d*y1->d);
3023+
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
30483024
#endif
30493025
}
30503026

3051-
sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs;
3027+
sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
30523028
#else
30533029
// scalar
30543030
for (int i = 0; i < nb; i++) {

0 commit comments

Comments
 (0)