@@ -2255,6 +2255,77 @@ end
2255
2255
2256
2256
typed_hvcat (:: Type{T} , rows:: Tuple{Vararg{Int}} , as... ) where T = typed_hvncat (T, rows_to_dimshape (rows), true , as... )
2257
2257
2258
+ # A fast version of hvcat for the case where we have static size information of xs
2259
+ # and the number of rows is known at compile time -- we can eliminate all the runtime
2260
+ # size checks. For cases that static size information is not beneficial, we fall back to
2261
+ # the general hvcat/typed_hvcat methods.
2262
+ @generated function typed_hvcat_static (:: Type{T} , :: Val{rows} , xs:: Number... ) where {T<: Number , rows}
2263
+ nr = length (rows)
2264
+ nc = rows[1 ]
2265
+ for i = 2 : nr
2266
+ if nc != rows[i]
2267
+ return quote
2268
+ msg = " row " * string ($ i) * " has mismatched number of columns (expected " * string ($ nc) * " , got " * string ($ rows[$ i]) * " )"
2269
+ throw (DimensionMismatch (msg))
2270
+ end
2271
+ end
2272
+ end
2273
+
2274
+ len = length (xs)
2275
+ if nr* nc != len
2276
+ return quote
2277
+ msg = " argument count " * string ($ len) * " does not match specified shape " * string (($ nr, $ nc))
2278
+ throw (ArgumentError (msg))
2279
+ end
2280
+ end
2281
+
2282
+ if len <= 16
2283
+ # For small array construction, manually unroll the loop for better performance
2284
+ assigns = Expr[]
2285
+ k = 1
2286
+ for i in 1 : nr
2287
+ for j in 1 : nc
2288
+ ex = :(a[$ i, $ j] = xs[$ k])
2289
+ push! (assigns, ex)
2290
+ k += 1
2291
+ end
2292
+ end
2293
+
2294
+ return quote
2295
+ a = Matrix {$T} (undef, $ nr, $ nc)
2296
+ $ (assigns... )
2297
+ return a
2298
+ end
2299
+ end
2300
+
2301
+ quote
2302
+ a = Matrix {$T} (undef, $ nr, $ nc)
2303
+ k = 1
2304
+ @inbounds for i in 1 : $ nr
2305
+ for j in 1 : $ nc
2306
+ a[i,j] = xs[k]
2307
+ k += 1
2308
+ end
2309
+ end
2310
+ a
2311
+ end
2312
+ end
2313
+ @inline function hvcat_static (:: Val{rows} , x:: T , xs:: Vararg{T} ) where {rows, T<: Number }
2314
+ typed_hvcat_static (T, Val {rows} (), x, xs... )
2315
+ end
2316
+ @inline function hvcat_static (:: Val{rows} , xs:: Number... ) where {rows}
2317
+ typed_hvcat_static (promote_typeof (xs... ), Val {rows} (), xs... )
2318
+ end
2319
+ @inline function typed_hvcat_static (:: Type{T} , :: Val{rows} , xs... ) where {T, rows}
2320
+ # fallback to the general case
2321
+ typed_hvcat (T, rows, xs... )
2322
+ end
2323
+ @inline function hvcat_static (:: Val{rows} , xs... ) where {rows}
2324
+ # fallback to the general case
2325
+ hvcat (rows, xs... )
2326
+ end
2327
+
2328
+
2258
2329
# # N-dimensional concatenation ##
2259
2330
2260
2331
"""
@@ -2750,6 +2821,94 @@ end
2750
2821
Ai
2751
2822
end
2752
2823
2824
+ # Static version of hvncat for better performance with scalar numbers
2825
+ # See the comments for hvcat_static for more details.
2826
+ @generated function typed_hvncat_static (:: Type{T} , :: Val{dims} , :: Val{row_first} , xs:: Number... ) where {T<: Number , dims, row_first}
2827
+ for d in dims
2828
+ if d <= 0
2829
+ return quote
2830
+ throw (ArgumentError (" `dims` argument must contain positive integers" ))
2831
+ end
2832
+ end
2833
+ end
2834
+
2835
+ N = length (dims)
2836
+ lengtha = prod (dims)
2837
+ lengthx = length (xs)
2838
+ if lengtha != lengthx
2839
+ return quote
2840
+ msg = " argument count does not match specified shape (expected " * string ($ lengtha) * " , got " * string ($ lengthx) * " )"
2841
+ throw (ArgumentError (msg))
2842
+ end
2843
+ end
2844
+
2845
+ if lengthx <= 16
2846
+ # For small array construction, manually unroll the loop
2847
+ assigns = Expr[]
2848
+ nr, nc = dims[1 ], dims[2 ]
2849
+ na = if N > 2
2850
+ n = 1
2851
+ for d in 3 : N
2852
+ n *= dims[d]
2853
+ end
2854
+ n
2855
+ else
2856
+ 1
2857
+ end
2858
+ nrc = nr * nc
2859
+
2860
+ if row_first
2861
+ k = 1
2862
+ for d in 1 : na
2863
+ dd = nrc * (d - 1 )
2864
+ for i in 1 : nr
2865
+ Ai = dd + i
2866
+ for j in 1 : nc
2867
+ ex = :(A[$ Ai] = xs[$ k])
2868
+ push! (assigns, ex)
2869
+ k += 1
2870
+ Ai += nr
2871
+ end
2872
+ end
2873
+ end
2874
+ else
2875
+ k = 1
2876
+ for i in 1 : lengtha
2877
+ ex = :(A[$ i] = xs[$ k])
2878
+ push! (assigns, ex)
2879
+ k += 1
2880
+ end
2881
+ end
2882
+
2883
+ return quote
2884
+ A = Array {$T, $N} (undef, $ dims... )
2885
+ $ (assigns... )
2886
+ return A
2887
+ end
2888
+ end
2889
+
2890
+ # For larger arrays, use the regular loop
2891
+ quote
2892
+ A = Array {$T, $N} (undef, $ dims... )
2893
+ hvncat_fill! (A, $ row_first, xs)
2894
+ return A
2895
+ end
2896
+ end
2897
+ @inline function hvncat_static (:: Val{dims} , :: Val{row_first} , x:: T , xs:: Vararg{T} ) where {dims, row_first, T<: Number }
2898
+ typed_hvncat_static (T, Val {dims} (), Val {row_first} (), x, xs... )
2899
+ end
2900
+ @inline function hvncat_static (:: Val{dims} , :: Val{row_first} , xs:: Number... ) where {dims, row_first}
2901
+ typed_hvncat_static (promote_typeof (xs... ), Val {dims} (), Val {row_first} (), xs... )
2902
+ end
2903
+ @inline function typed_hvncat_static (:: Type{T} , :: Val{dims} , :: Val{row_first} , xs... ) where {T, dims, row_first}
2904
+ # fallback to the general case
2905
+ typed_hvncat (T, dims, row_first, xs... )
2906
+ end
2907
+ @inline function hvncat_static (:: Val{dims} , :: Val{row_first} , xs... ) where {dims, row_first}
2908
+ # fallback to the general case
2909
+ hvncat (dims, row_first, xs... )
2910
+ end
2911
+
2753
2912
"""
2754
2913
stack(iter; [dims])
2755
2914
0 commit comments