Skip to content

Commit 9ed559f

Browse files
committed
add static size information for hvcat and hvncat
This significantly improves the performance of multi-dimensional array creation via scalar numbers using methods of [a b; c d], [a b;; c d] and typed T[a b; c d], T[a b;; c d] For small numeric array creation(length <= 16), manual loop unroll is used to further minimize the overhead, and it now has zero overhead and is as fast as the array initialization method.
1 parent d0a521f commit 9ed559f

File tree

12 files changed

+987
-6
lines changed

12 files changed

+987
-6
lines changed

base/abstractarray.jl

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2255,6 +2255,77 @@ end
22552255

22562256
typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T = typed_hvncat(T, rows_to_dimshape(rows), true, as...)
22572257

2258+
# A fast version of hvcat for the case where we have static size information of xs
2259+
# and the number of rows is known at compile time -- we can eliminate all the runtime
2260+
# size checks. For cases that static size information is not beneficial, we fall back to
2261+
# the general hvcat/typed_hvcat methods.
2262+
@generated function typed_hvcat_static(::Type{T}, ::Val{rows}, xs::Number...) where {T<:Number, rows}
2263+
nr = length(rows)
2264+
nc = rows[1]
2265+
for i = 2:nr
2266+
if nc != rows[i]
2267+
return quote
2268+
msg = "row " * string($i) * " has mismatched number of columns (expected " * string($nc) * ", got " * string($rows[$i]) * ")"
2269+
throw(DimensionMismatch(msg))
2270+
end
2271+
end
2272+
end
2273+
2274+
len = length(xs)
2275+
if nr*nc != len
2276+
return quote
2277+
msg = "argument count " * string($len) * " does not match specified shape " * string(($nr, $nc))
2278+
throw(ArgumentError(msg))
2279+
end
2280+
end
2281+
2282+
if len <= 16
2283+
# For small array construction, manually unroll the loop for better performance
2284+
assigns = Expr[]
2285+
k = 1
2286+
for i in 1:nr
2287+
for j in 1:nc
2288+
ex = :(a[$i, $j] = xs[$k])
2289+
push!(assigns, ex)
2290+
k += 1
2291+
end
2292+
end
2293+
2294+
return quote
2295+
a = Matrix{$T}(undef, $nr, $nc)
2296+
$(assigns...)
2297+
return a
2298+
end
2299+
end
2300+
2301+
quote
2302+
a = Matrix{$T}(undef, $nr, $nc)
2303+
k = 1
2304+
@inbounds for i in 1:$nr
2305+
for j in 1:$nc
2306+
a[i,j] = xs[k]
2307+
k += 1
2308+
end
2309+
end
2310+
a
2311+
end
2312+
end
2313+
@inline function hvcat_static(::Val{rows}, x::T, xs::Vararg{T}) where {rows, T<:Number}
2314+
typed_hvcat_static(T, Val{rows}(), x, xs...)
2315+
end
2316+
@inline function hvcat_static(::Val{rows}, xs::Number...) where {rows}
2317+
typed_hvcat_static(promote_typeof(xs...), Val{rows}(), xs...)
2318+
end
2319+
@inline function typed_hvcat_static(::Type{T}, ::Val{rows}, xs...) where {T, rows}
2320+
# fallback to the general case
2321+
typed_hvcat(T, rows, xs...)
2322+
end
2323+
@inline function hvcat_static(::Val{rows}, xs...) where {rows}
2324+
# fallback to the general case
2325+
hvcat(rows, xs...)
2326+
end
2327+
2328+
22582329
## N-dimensional concatenation ##
22592330

22602331
"""
@@ -2750,6 +2821,94 @@ end
27502821
Ai
27512822
end
27522823

2824+
# Static version of hvncat for better performance with scalar numbers
2825+
# See the comments for hvcat_static for more details.
2826+
@generated function typed_hvncat_static(::Type{T}, ::Val{dims}, ::Val{row_first}, xs::Number...) where {T<:Number, dims, row_first}
2827+
for d in dims
2828+
if d <= 0
2829+
return quote
2830+
throw(ArgumentError("`dims` argument must contain positive integers"))
2831+
end
2832+
end
2833+
end
2834+
2835+
N = length(dims)
2836+
lengtha = prod(dims)
2837+
lengthx = length(xs)
2838+
if lengtha != lengthx
2839+
return quote
2840+
msg = "argument count does not match specified shape (expected " * string($lengtha) * ", got " * string($lengthx) * ")"
2841+
throw(ArgumentError(msg))
2842+
end
2843+
end
2844+
2845+
if lengthx <= 16
2846+
# For small array construction, manually unroll the loop
2847+
assigns = Expr[]
2848+
nr, nc = dims[1], dims[2]
2849+
na = if N > 2
2850+
n = 1
2851+
for d in 3:N
2852+
n *= dims[d]
2853+
end
2854+
n
2855+
else
2856+
1
2857+
end
2858+
nrc = nr * nc
2859+
2860+
if row_first
2861+
k = 1
2862+
for d in 1:na
2863+
dd = nrc * (d - 1)
2864+
for i in 1:nr
2865+
Ai = dd + i
2866+
for j in 1:nc
2867+
ex = :(A[$Ai] = xs[$k])
2868+
push!(assigns, ex)
2869+
k += 1
2870+
Ai += nr
2871+
end
2872+
end
2873+
end
2874+
else
2875+
k = 1
2876+
for i in 1:lengtha
2877+
ex = :(A[$i] = xs[$k])
2878+
push!(assigns, ex)
2879+
k += 1
2880+
end
2881+
end
2882+
2883+
return quote
2884+
A = Array{$T, $N}(undef, $dims...)
2885+
$(assigns...)
2886+
return A
2887+
end
2888+
end
2889+
2890+
# For larger arrays, use the regular loop
2891+
quote
2892+
A = Array{$T, $N}(undef, $dims...)
2893+
hvncat_fill!(A, $row_first, xs)
2894+
return A
2895+
end
2896+
end
2897+
@inline function hvncat_static(::Val{dims}, ::Val{row_first}, x::T, xs::Vararg{T}) where {dims, row_first, T<:Number}
2898+
typed_hvncat_static(T, Val{dims}(), Val{row_first}(), x, xs...)
2899+
end
2900+
@inline function hvncat_static(::Val{dims}, ::Val{row_first}, xs::Number...) where {dims, row_first}
2901+
typed_hvncat_static(promote_typeof(xs...), Val{dims}(), Val{row_first}(), xs...)
2902+
end
2903+
@inline function typed_hvncat_static(::Type{T}, ::Val{dims}, ::Val{row_first}, xs...) where {T, dims, row_first}
2904+
# fallback to the general case
2905+
typed_hvncat(T, dims, row_first, xs...)
2906+
end
2907+
@inline function hvncat_static(::Val{dims}, ::Val{row_first}, xs...) where {dims, row_first}
2908+
# fallback to the general case
2909+
hvncat(dims, row_first, xs...)
2910+
end
2911+
27532912
"""
27542913
stack(iter; [dims])
27552914

cat.md

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
当拼接对象为标量数值时,优化了矩阵拼接方式创建数值矩阵的性能。
2+
3+
优化思路:
4+
5+
* 通过在语法解析时保留静态尺寸信息,优化了通过标量数值拼接二维与高维数组时的性能。
6+
* 对于小矩阵场景,引入额外的手动循环展开,进一步提升了性能。
7+
8+
主要结果:
9+
10+
* 对于小矩阵创建来说,使用 `[]` 数组拼接的方式与使用矩阵初始化的方式的性能相当。过去存在较为显著的性能差异。
11+
* 对于大矩阵创建来说,有显著的性能提升,但相比于矩阵初始化的方式来说,依然存在一定的性能差异。
12+
13+
## [x x; y y] 二维矩阵创建
14+
15+
相同数据类型
16+
17+
```julia
18+
using BenchmarkTools
19+
20+
f(x, y) = [x x; y y]
21+
@btime f(x, y) setup=(x = rand(); y = rand())
22+
```
23+
24+
2x2 矩阵
25+
26+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
27+
| ------------- | ---------- | -------- |
28+
| 1.9.3 | 18.021 | 37.258 |
29+
| 1.10.9 | 18.634 | 263.918 |
30+
| 1.10-dev | 19.660 | 19.749 |
31+
32+
6x6 矩阵
33+
34+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
35+
| ------------- | ---------- | -------- |
36+
| 1.9.3 | 30.865 | 144.490 |
37+
| 1.10.9 | 32.856 | 2198 |
38+
| 1.10-dev | 32.278 | 31.363 |
39+
40+
混合数据类型
41+
42+
```julia
43+
g(x, y) = [x x;;; y y]
44+
@btime g(x, y) setup=(x = rand(); y = rand(1:10))
45+
```
46+
47+
2x2 矩阵
48+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
49+
| ------------- | ---------- | -------- |
50+
| 1.9.3 | 17.794 | 72.243 |
51+
| 1.10.9 | 18.587 | 844.879 |
52+
| 1.10-dev | 19.467 | 19.394 |
53+
54+
6x6 矩阵
55+
56+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
57+
| ------------- | ---------- | -------- |
58+
| 1.9.3 | 35.994 | 11284 |
59+
| 1.10.9 | 36.748 | 23835 |
60+
| 1.10-dev | 32.352 | 646.955 |
61+
62+
## [x x;;; y y] 高维数组创建
63+
64+
相同数据类型
65+
66+
```julia
67+
h(x, y) = [x x;;; y y]
68+
@btime h(x, y) setup=(x = rand(); y = rand())
69+
```
70+
71+
1x2x2 矩阵
72+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
73+
| ------------- | ---------- | -------- |
74+
| 1.9.3 | 19.754 | 112.885 |
75+
| 1.10.9 | 20.175 | 108.488 |
76+
| 1.10-dev | 21.437 | 19.700 |
77+
78+
3x3x3 矩阵
79+
80+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
81+
| ------------- | ---------- | -------- |
82+
| 1.9.3 | 43.327 | 445.046 |
83+
| 1.10.9 | 44.832 | 473.566 |
84+
| 1.10-dev | 44.633 | 48.476 |
85+
86+
混合数据类型
87+
88+
```julia
89+
i(x, y) = [x x;;; y y]
90+
@btime i(x, y) setup=(x = rand(); y = rand(1:10))
91+
```
92+
93+
1x2x2 矩阵
94+
95+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
96+
| ------------- | ---------- | -------- |
97+
| 1.9.3 | 19.192 | 147.851 |
98+
| 1.10.9 | 20.126 | 147.297 |
99+
| 1.10-dev | 20.719 | 19.509 |
100+
101+
3x3x3 矩阵
102+
103+
| Julia version | 初始化方法 (ns) | 拼接方法 (ns) |
104+
| ------------- | ---------- | -------- |
105+
| 1.9.3 | 44.267 | 753.900 |
106+
| 1.10.9 | 43.800 | 750.264 |
107+
| 1.10-dev | 43.620 | 363.485 |

0 commit comments

Comments
 (0)