@@ -2978,77 +2978,53 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
2978
2978
float32x4_t sumv0 = vdupq_n_f32 (0.0f );
2979
2979
float32x4_t sumv1 = vdupq_n_f32 (0.0f );
2980
2980
2981
- float summs = 0.0f ;
2981
+ float summs0 = 0.0f ;
2982
+ float summs1 = 0.0f ;
2982
2983
2983
- for (int i = 0 ; i < nb ; i += 2 ) {
2984
+ for (int i = 0 ; i < nb ; ++ i ) {
2984
2985
const block_q4_3 * restrict x0_0 = & x [2 * (i + 0 ) + 0 ];
2985
2986
const block_q4_3 * restrict x0_1 = & x [2 * (i + 0 ) + 1 ];
2986
- const block_q4_3 * restrict x1_0 = & x [2 * (i + 1 ) + 0 ];
2987
- const block_q4_3 * restrict x1_1 = & x [2 * (i + 1 ) + 1 ];
2988
2987
2989
2988
const block_q8_0 * restrict y0 = & y [i + 0 ];
2990
- const block_q8_0 * restrict y1 = & y [i + 1 ];
2991
2989
2992
- summs += GGML_FP16_TO_FP32 (x0_0 -> m ) * y0 -> s0 + GGML_FP16_TO_FP32 (x0_1 -> m ) * y0 -> s1 ;
2993
- summs += GGML_FP16_TO_FP32 (x1_0 -> m ) * y1 -> s0 + GGML_FP16_TO_FP32 (x1_1 -> m ) * y1 -> s1 ;
2994
-
2995
- const uint8x16_t m4b = vdupq_n_u8 (0xf );
2996
-
2997
- const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
2998
- const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
2999
- const float x1_0d = GGML_FP16_TO_FP32 (x1_0 -> d );
3000
- const float x1_1d = GGML_FP16_TO_FP32 (x1_1 -> d );
2990
+ summs0 += GGML_FP16_TO_FP32 (x0_0 -> m ) * y0 -> s0 ;
2991
+ summs1 += GGML_FP16_TO_FP32 (x0_1 -> m ) * y0 -> s1 ;
3001
2992
3002
2993
const uint8x16_t v0_0 = vcombine_u8 (vld1_u8 (x0_0 -> qs ), vld1_u8 (x0_1 -> qs ));
3003
- const uint8x16_t v0_1 = vcombine_u8 (vld1_u8 (x1_0 -> qs ), vld1_u8 (x1_1 -> qs ));
3004
2994
3005
2995
// 4-bit -> 8-bit
3006
- const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , m4b ));
2996
+ const int8x16_t v0_0l = vreinterpretq_s8_u8 (vandq_u8 (v0_0 , vdupq_n_u8 ( 0xf ) ));
3007
2997
const int8x16_t v0_0h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_0 , 4 ));
3008
- const int8x16_t v0_1l = vreinterpretq_s8_u8 (vandq_u8 (v0_1 , m4b ));
3009
- const int8x16_t v0_1h = vreinterpretq_s8_u8 (vshrq_n_u8 (v0_1 , 4 ));
3010
2998
3011
2999
// interleave
3012
3000
const int8x16_t v0_0lz = vzip1q_s8 (v0_0l , v0_0h );
3013
3001
const int8x16_t v0_0hz = vzip2q_s8 (v0_0l , v0_0h );
3014
- const int8x16_t v0_1lz = vzip1q_s8 (v0_1l , v0_1h );
3015
- const int8x16_t v0_1hz = vzip2q_s8 (v0_1l , v0_1h );
3016
3002
3017
3003
// load y
3018
3004
const int8x16_t v1_0l = vld1q_s8 (y0 -> qs );
3019
3005
const int8x16_t v1_0h = vld1q_s8 (y0 -> qs + 16 );
3020
- const int8x16_t v1_1l = vld1q_s8 (y1 -> qs );
3021
- const int8x16_t v1_1h = vld1q_s8 (y1 -> qs + 16 );
3006
+
3007
+ const float x0_0d = GGML_FP16_TO_FP32 (x0_0 -> d );
3008
+ const float x0_1d = GGML_FP16_TO_FP32 (x0_1 -> d );
3022
3009
3023
3010
#if defined(__ARM_FEATURE_DOTPROD )
3024
3011
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0lz , v1_0l )), x0_0d * y0 -> d );
3025
- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
3026
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1lz , v1_1l )), x1_0d * y1 -> d );
3027
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_1hz , v1_1h )), x1_1d * y1 -> d );
3012
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (vdotq_s32 (vdupq_n_s32 (0 ), v0_0hz , v1_0h )), x0_1d * y0 -> d );
3028
3013
#else
3029
3014
const int16x8_t pl0l = vmull_s8 (vget_low_s8 (v0_0lz ), vget_low_s8 (v1_0l ));
3030
3015
const int16x8_t pl0h = vmull_s8 (vget_high_s8 (v0_0lz ), vget_high_s8 (v1_0l ));
3031
3016
const int16x8_t ph0l = vmull_s8 (vget_low_s8 (v0_0hz ), vget_low_s8 (v1_0h ));
3032
3017
const int16x8_t ph0h = vmull_s8 (vget_high_s8 (v0_0hz ), vget_high_s8 (v1_0h ));
3033
3018
3034
- const int16x8_t pl1l = vmull_s8 (vget_low_s8 (v0_1lz ), vget_low_s8 (v1_1l ));
3035
- const int16x8_t pl1h = vmull_s8 (vget_high_s8 (v0_1lz ), vget_high_s8 (v1_1l ));
3036
- const int16x8_t ph1l = vmull_s8 (vget_low_s8 (v0_1hz ), vget_low_s8 (v1_1h ));
3037
- const int16x8_t ph1h = vmull_s8 (vget_high_s8 (v0_1hz ), vget_high_s8 (v1_1h ));
3038
-
3039
3019
const int32x4_t pl0 = vaddq_s32 (vpaddlq_s16 (pl0l ), vpaddlq_s16 (pl0h ));
3040
3020
const int32x4_t ph0 = vaddq_s32 (vpaddlq_s16 (ph0l ), vpaddlq_s16 (ph0h ));
3041
- const int32x4_t pl1 = vaddq_s32 (vpaddlq_s16 (pl1l ), vpaddlq_s16 (pl1h ));
3042
- const int32x4_t ph1 = vaddq_s32 (vpaddlq_s16 (ph1l ), vpaddlq_s16 (ph1h ));
3043
3021
3044
3022
sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (pl0 ), x0_0d * y0 -> d );
3045
- sumv0 = vmlaq_n_f32 (sumv0 , vcvtq_f32_s32 (ph0 ), x0_1d * y0 -> d );
3046
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (pl1 ), x1_0d * y1 -> d );
3047
- sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph1 ), x1_1d * y1 -> d );
3023
+ sumv1 = vmlaq_n_f32 (sumv1 , vcvtq_f32_s32 (ph0 ), x0_1d * y0 -> d );
3048
3024
#endif
3049
3025
}
3050
3026
3051
- sumf = vaddvq_f32 (vaddq_f32 (sumv0 , sumv1 )) + summs ;
3027
+ sumf = vaddvq_f32 (vaddq_f32 (sumv0 , sumv1 )) + summs0 + summs1 ;
3052
3028
#else
3053
3029
// scalar
3054
3030
for (int i = 0 ; i < nb ; i ++ ) {
0 commit comments