@@ -4,6 +4,41 @@ using LinearAlgebra: AbstractTriangular, StridedMaybeAdjOrTransMat, UpperOrLower
44 RealHermSymComplexHerm, HermOrSym, checksquare, sym_uplo, wrap
55using Random: rand!
66
7+ _fix_size (M, nrow, ncol) = M
8+
9+ # An immutable fixed size wrapper for matrices to work around
10+ # the performance issue caused by https://github.com/JuliaLang/julia/issues/60409
11+ # This is more-of-less a stripped down version of FixedSizeArrays
12+ # which we can't easily use without pulling that into the standard library.
13+ struct _FixedSizeMatrix{Trans,R}
14+ ref:: R
15+ nrow:: Int
16+ ncol:: Int
17+ function _FixedSizeMatrix {Trans} (ref:: R , nrow, ncol) where {Trans,R}
18+ new {Trans,R} (ref, nrow, ncol)
19+ end
20+ end
21+ @inline Base. getindex (A:: _FixedSizeMatrix{'N'} , i, j) =
22+ @inbounds Core. memoryrefnew (A. ref, A. nrow * (j - 1 ) + i, false )[]
23+ @inline Base. setindex! (A:: _FixedSizeMatrix{'N'} , v, i, j) =
24+ @inbounds Core. memoryrefnew (A. ref, A. nrow * (j - 1 ) + i, false )[] = v
25+
26+ @inline Base. getindex (A:: _FixedSizeMatrix{'T'} , i, j) =
27+ @inbounds transpose (Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[])
28+ @inline Base. setindex! (A:: _FixedSizeMatrix{'T'} , v, i, j) =
29+ @inbounds Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[] = transpose (v)
30+
31+ @inline Base. getindex (A:: _FixedSizeMatrix{'C'} , i, j) =
32+ @inbounds adjoint (Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[])
33+ @inline Base. setindex! (A:: _FixedSizeMatrix{'C'} , v, i, j) =
34+ @inbounds Core. memoryrefnew (A. ref, A. ncol * (i - 1 ) + j, false )[] = adjoint (v)
35+
36+ @inline _fix_size (A:: Matrix , nrow, ncol) = _FixedSizeMatrix {'N'} (A. ref, nrow, ncol)
37+ @inline _fix_size (A:: Transpose{<:Any,<:Matrix} , nrow, ncol) =
38+ _FixedSizeMatrix {'T'} (A. parent. ref, nrow, ncol)
39+ @inline _fix_size (A:: Adjoint{<:Any,<:Matrix} , nrow, ncol) =
40+ _FixedSizeMatrix {'C'} (A. parent. ref, nrow, ncol)
41+
742const tilebufsize = 10800 # Approximately 32k/3
843
944# In matrix-vector multiplication, the correct orientation of the vector is assumed.
@@ -69,52 +104,99 @@ Base.@constprop :aggressive function spdensemul!(C, tA, tB, A, B, alpha, beta)
69104 T = eltype (C)
70105 _mul! (rangefun, diagop, odiagop, C, A, wrap (B, tB), T (alpha), T (beta))
71106 else
72- @stable_muladdmul LinearAlgebra. _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), MulAddMul ( alpha, beta) )
107+ LinearAlgebra. _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
73108 end
74109 return C
75110end
76111
112+ # Slow non-inlined functions for throwing the error without messing up the caller
113+ @noinline function _matmul_size_error (mC, nC, mA, nA, mB, nB, At, Bt)
114+ if At == ' N'
115+ Anames = " first" , " second"
116+ else
117+ Anames = " second" , " first"
118+ end
119+ if Bt == ' N'
120+ Bnames = " first" , " second"
121+ else
122+ Bnames = " second" , " first"
123+ end
124+ nA == mB ||
125+ throw (DimensionMismatch (" $(Anames[2 ]) dimension of A, $nA , does not match the $(Bnames[1 ]) dimension of B, $mB " ))
126+ mA == mC ||
127+ throw (DimensionMismatch (" $(Anames[1 ]) dimension of A, $mA , does not match the first dimension of C, $mC " ))
128+ nB == nC ||
129+ throw (DimensionMismatch (" $(Bnames[2 ]) dimension of B, $nB , does not match the second dimension of C, $nC " ))
130+ # unreachable
131+ throw (DimensionMismatch (" Unknown dimension mismatch" ))
132+ end
133+
134+ @inline function _matmul_size (C, A, B, :: Val{At} , :: Val{Bt} ) where {At,Bt}
135+ mC = size (C, 1 )
136+ nC = size (C, 2 )
137+ mA = size (A, 1 )
138+ nA = size (A, 2 )
139+ mB = size (B, 1 )
140+ nB = size (B, 2 )
141+
142+ _mA, _nA = At == ' N' ? (mA, nA) : (nA, mA)
143+ _mB, _nB = Bt == ' N' ? (mB, nB) : (nB, mB)
144+
145+ if (_nA != _mB) | (_mA != mC) | (_nB != nC)
146+ _matmul_size_error (mC, nC, _mA, _nA, _mB, _nB, At, Bt)
147+ end
148+ return mC, nC, mA, nA, mB, nB
149+ end
150+
151+ @inline _matmul_size_AB (C, A, B) = _matmul_size (C, A, B, Val (' N' ), Val (' N' ))
152+ @inline _matmul_size_AtB (C, A, B) = _matmul_size (C, A, B, Val (' T' ), Val (' N' ))
153+ @inline _matmul_size_ABt (C, A, B) = _matmul_size (C, A, B, Val (' N' ), Val (' T' ))
154+
77155function _spmatmul! (C, A, B, α, β)
78- size (A, 2 ) == size (B, 1 ) ||
79- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the first dimension of B, $(size (B,1 )) " ))
80- size (A, 1 ) == size (C, 1 ) ||
81- throw (DimensionMismatch (" first dimension of A, $(size (A,1 )) , does not match the first dimension of C, $(size (C,1 )) " ))
82- size (B, 2 ) == size (C, 2 ) ||
83- throw (DimensionMismatch (" second dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
156+ Cax2 = axes (C, 2 )
157+ Aax2 = axes (A, 2 )
158+ mC, nC, mA, nA, mB, nB = _matmul_size_AB (C, A, B)
84159 nzv = nonzeros (A)
85160 rv = rowvals (A)
86- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
87- for k in axes (C, 2 )
88- @inbounds for col in axes (A,2 )
89- αxj = B[col,k] * α
161+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
162+ if α isa Bool && ! α
163+ return
164+ end
165+ B = _fix_size (B, mB, nB)
166+ C = _fix_size (C, mC, nC)
167+ for k in Cax2
168+ @inbounds for col in Aax2
169+ αxj = α isa Bool ? B[col,k] : B[col,k] * α
90170 for j in nzrange (A, col)
91- C[rv[j], k] += nzv[j]* αxj
171+ rvj = rv[j]
172+ C[rvj, k] = muladd (nzv[j], αxj, C[rvj, k])
92173 end
93174 end
94175 end
95- C
96176end
97177
98178function _At_or_Ac_mul_B! (tfun:: Function , C, A, B, α, β)
99- size (A, 2 ) == size (C, 1 ) ||
100- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the first dimension of C, $(size (C,1 )) " ))
101- size (A, 1 ) == size (B, 1 ) ||
102- throw (DimensionMismatch (" first dimension of A, $(size (A,1 )) , does not match the first dimension of B, $(size (B,1 )) " ))
103- size (B, 2 ) == size (C, 2 ) ||
104- throw (DimensionMismatch (" second dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
179+ Cax2 = axes (C, 2 )
180+ Aax2 = axes (A, 2 )
181+ mC, nC, mA, nA, mB, nB = _matmul_size_AtB (C, A, B)
105182 nzv = nonzeros (A)
106183 rv = rowvals (A)
107- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
108- for k in axes (C, 2 )
109- @inbounds for col in axes (A,2 )
110- tmp = zero (eltype (C))
184+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
185+ if α isa Bool && ! α
186+ return
187+ end
188+ C0 = zero (eltype (C)) # Pre-allocate for BigFloat/BigInt etc
189+ B = _fix_size (B, mB, nB)
190+ C = _fix_size (C, mC, nC)
191+ for k in Cax2
192+ @inbounds for col in Aax2
193+ tmp = C0
111194 for j in nzrange (A, col)
112- tmp += tfun (nzv[j])* B[rv[j],k]
195+ tmp = muladd ( tfun (nzv[j]), B[rv[j], k], tmp)
113196 end
114- C[col,k] += tmp * α
197+ C[col, k] = α isa Bool ? tmp + C[col, k] : muladd (tmp, α, C[col, k])
115198 end
116199 end
117- C
118200end
119201
120202Base. @constprop :aggressive function generic_matmatmul_wrapper! (C:: StridedMatrix , tA, tB, A:: DenseMatrixUnion , B:: SparseMatrixCSCUnion2 , alpha:: Number , beta:: Number , :: LinearAlgebra.BlasFlag.SyrkHerkGemm )
@@ -132,63 +214,71 @@ Base.@constprop :aggressive generic_matmatmul_wrapper!(C::StridedMatrix, tA, tB,
132214 LinearAlgebra. _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
133215
134216function _spmul! (C:: StridedMatrix , X:: DenseMatrixUnion , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
135- mX, nX = size (X)
136- nX == size (A, 1 ) ||
137- throw (DimensionMismatch (" second dimension of X, $nX , does not match the first dimension of A, $(size (A,1 )) " ))
138- mX == size (C, 1 ) ||
139- throw (DimensionMismatch (" first dimension of X, $mX , does not match the first dimension of C, $(size (C,1 )) " ))
140- size (A, 2 ) == size (C, 2 ) ||
141- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
217+ Aax2 = axes (A, 2 )
218+ Xax1 = axes (X, 1 )
219+ mC, nC, mX, nX, mA, nA = _matmul_size_AB (C, X, A)
142220 rv = rowvals (A)
143221 nzv = nonzeros (A)
144- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
145- @inbounds for col in axes (A,2 ), k in nzrange (A, col)
146- Aiα = nzv[k] * α
222+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
223+ if α isa Bool && ! α
224+ return
225+ end
226+ C = _fix_size (C, mC, nC)
227+ X = _fix_size (X, mX, nX)
228+ @inbounds for col in Aax2, k in nzrange (A, col)
229+ Aiα = α isa Bool ? nzv[k] : nzv[k] * α
147230 rvk = rv[k]
148- @simd for multivec_row in axes (X,1 )
149- C[multivec_row, col] += X[multivec_row, rvk] * Aiα
231+ @simd for multivec_row in Xax1
232+ C[multivec_row, col] = muladd (X[multivec_row, rvk], Aiα,
233+ C[multivec_row, col])
150234 end
151235 end
152- C
153236end
154237function _spmul! (C:: StridedMatrix , X:: AdjOrTrans{<:Any,<:DenseMatrixUnion} , A:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
155- mX, nX = size (X)
156- nX == size (A, 1 ) ||
157- throw (DimensionMismatch (" second dimension of X, $nX , does not match the first dimension of A, $(size (A,1 )) " ))
158- mX == size (C, 1 ) ||
159- throw (DimensionMismatch (" first dimension of X, $mX , does not match the first dimension of C, $(size (C,1 )) " ))
160- size (A, 2 ) == size (C, 2 ) ||
161- throw (DimensionMismatch (" second dimension of A, $(size (A,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
238+ Xax1 = axes (X, 1 )
239+ Cax2 = axes (C, 2 )
240+ mC, nC, mX, nX, mA, nA = _matmul_size_AB (C, X, A)
162241 rv = rowvals (A)
163242 nzv = nonzeros (A)
164- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
165- for multivec_row in axes (X,1 ), col in axes (C, 2 )
166- @inbounds for k in nzrange (A, col)
167- C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
243+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
244+ if α isa Bool && ! α
245+ return
246+ end
247+ C = _fix_size (C, mC, nC)
248+ X = _fix_size (X, mX, nX)
249+ @inbounds for multivec_row in Xax1, col in Cax2
250+ nzrng = nzrange (A, col)
251+ if isempty (nzrng)
252+ continue
253+ end
254+ tmp = C[multivec_row, col]
255+ for k in nzrng
256+ tmp = muladd (X[multivec_row, rv[k]],
257+ (α isa Bool ? nzv[k] : nzv[k] * α), tmp)
168258 end
259+ C[multivec_row, col] = tmp
169260 end
170- C
171261end
172262
173263function _A_mul_Bt_or_Bc! (tfun:: Function , C:: StridedMatrix , A:: AbstractMatrix , B:: SparseMatrixCSCUnion2 , α:: Number , β:: Number )
174- mA, nA = size (A)
175- nA == size (B, 2 ) ||
176- throw (DimensionMismatch (" second dimension of A, $nA , does not match the second dimension of B, $(size (B,2 )) " ))
177- mA == size (C, 1 ) ||
178- throw (DimensionMismatch (" first dimension of A, $mA , does not match the first dimension of C, $(size (C,1 )) " ))
179- size (B, 1 ) == size (C, 2 ) ||
180- throw (DimensionMismatch (" first dimension of B, $(size (B,2 )) , does not match the second dimension of C, $(size (C,2 )) " ))
264+ Bax2 = axes (B, 2 )
265+ Aax1 = axes (A, 1 )
266+ mC, nC, mA, nA, mB, nB = _matmul_size_ABt (C, A, B)
181267 rv = rowvals (B)
182268 nzv = nonzeros (B)
183- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
184- @inbounds for col in axes (B, 2 ), k in nzrange (B, col)
185- Biα = tfun (nzv[k]) * α
269+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
270+ if α isa Bool && ! α
271+ return
272+ end
273+ C = _fix_size (C, mC, nC)
274+ A = _fix_size (A, mA, nA)
275+ @inbounds for col in Bax2, k in nzrange (B, col)
276+ Biα = α isa Bool ? tfun (nzv[k]) : tfun (nzv[k]) * α
186277 rvk = rv[k]
187- @simd for multivec_col in axes (A, 1 )
188- C[multivec_col, rvk] += A[multivec_col, col] * Biα
278+ @simd for multivec_col in Aax1
279+ C[multivec_col, rvk] = muladd ( A[multivec_col, col], Biα, C[multivec_col, rvk])
189280 end
190281 end
191- C
192282end
193283
194284function * (A:: Diagonal , b:: AbstractSparseVector )
@@ -1243,7 +1333,7 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
12431333 rv = rowvals (A)
12441334 nzv = nonzeros (A)
12451335 let z = T (0 ), sumcol= z, αxj= z, aarc= z, α = α
1246- β != one (β) && LinearAlgebra. _rmul_or_fill! (C, β)
1336+ isone (β) || LinearAlgebra. _rmul_or_fill! (C, β)
12471337 @inbounds for k in axes (B,2 )
12481338 for col in axes (B,1 )
12491339 αxj = B[col,k] * α
@@ -1262,7 +1352,6 @@ function _mul!(nzrang::Function, diagop::Function, odiagop::Function, C::Strided
12621352 end
12631353 end
12641354 end
1265- C
12661355end
12671356
12681357# row range up to (and including if excl=false) diagonal
0 commit comments