Skip to content

Commit 5fd9924

Browse files
authored
lstsq: return correct array size (#818)
2 parents 42182b0 + 6d90c25 commit 5fd9924

File tree

3 files changed

+48
-14
lines changed

3 files changed

+48
-14
lines changed

doc/specs/stdlib_linalg.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -767,7 +767,7 @@ Result vector `x` returns the approximate solution that minimizes the 2-norm \(
767767

768768
`b`: Shall be a rank-1 or rank-2 array of the same kind as `a`, containing one or more right-hand-side vector(s), each in its leading dimension. It is an `intent(in)` argument.
769769

770-
`x`: Shall be an array of same kind and rank as `b`, containing the solution(s) to the least squares system. It is an `intent(inout)` argument.
770+
`x`: Shall be an array of same kind and rank as `b`, and leading dimension of at least `n`, containing the solution(s) to the least squares system. It is an `intent(inout)` argument.
771771

772772
`real_storage` (optional): Shall be a `real` rank-1 array of the same kind `a`, providing working storage for the solver. It minimum size can be determined with a call to [[stdlib_linalg(module):lstsq_space(interface)]]. It is an `intent(inout)` argument.
773773

src/stdlib_linalg_least_squares.fypp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
8585
pure module subroutine stdlib_linalg_${ri}$_lstsq_space_${ndsuf}$(a,b,lrwork,liwork#{if rt.startswith('c')}#,lcwork#{endif}#)
8686
!> Input matrix a[m,n]
8787
${rt}$, intent(in), target :: a(:,:)
88-
!> Right hand side vector or array, b[n] or b[n,nrhs]
88+
!> Right hand side vector or array, b[m] or b[m,nrhs]
8989
${rt}$, intent(in) :: b${nd}$
9090
!> Size of the working space arrays
9191
integer(ilp), intent(out) :: lrwork,liwork
@@ -111,7 +111,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
111111
!! This function computes the least-squares solution of a linear matrix problem.
112112
!!
113113
!! param: a Input matrix of size [m,n].
114-
!! param: b Right-hand-side vector of size [n] or matrix of size [n,nrhs].
114+
!! param: b Right-hand-side vector of size [m] or matrix of size [m,nrhs].
115115
!! param: cond [optional] Real input threshold indicating that singular values `s_i <= cond*maxval(s)`
116116
!! do not contribute to the matrix rank.
117117
!! param: overwrite_a [optional] Flag indicating if the input matrix can be overwritten.
@@ -121,7 +121,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
121121
!!
122122
!> Input matrix a[m,n]
123123
${rt}$, intent(inout), target :: a(:,:)
124-
!> Right hand side vector or array, b[n] or b[n,nrhs]
124+
!> Right hand side vector or array, b[m] or b[m,nrhs]
125125
${rt}$, intent(in) :: b${nd}$
126126
!> [optional] cutoff for rank evaluation: singular values s(i)<=cond*maxval(s) are considered 0.
127127
real(${rk}$), optional, intent(in) :: cond
@@ -134,9 +134,19 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
134134
!> Result array/matrix x[n] or x[n,nrhs]
135135
${rt}$, allocatable, target :: x${nd}$
136136

137-
! Initialize solution with the shape of the rhs
138-
allocate(x,mold=b)
137+
integer(ilp) :: n,nrhs,ldb
138+
139+
n = size(a,2,kind=ilp)
140+
ldb = size(b,1,kind=ilp)
141+
nrhs = size(b,kind=ilp)/ldb
139142

143+
! Initialize solution with the shape of the rhs
144+
#:if ndsuf=="one"
145+
allocate(x(n))
146+
#:else
147+
allocate(x(n,nrhs))
148+
#:endif
149+
140150
call stdlib_linalg_${ri}$_solve_lstsq_${ndsuf}$(a,b,x,&
141151
cond=cond,overwrite_a=overwrite_a,rank=rank,err=err)
142152

@@ -155,7 +165,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
155165
!!
156166
!! param: a Input matrix of size [m,n].
157167
!! param: b Right-hand-side vector of size [n] or matrix of size [n,nrhs].
158-
!! param: x Solution vector of size [n] or solution matrix of size [n,nrhs].
168+
!! param: x Solution vector of size at [>=n] or solution matrix of size [>=n,nrhs].
159169
!! param: real_storage [optional] Real working space
160170
!! param: int_storage [optional] Integer working space
161171
#:if rt.startswith('c')
@@ -198,7 +208,7 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
198208
integer(ilp) :: m,n,lda,ldb,nrhs,ldx,nrhsx,info,mnmin,mnmax,arank,lrwork,liwork,lcwork
199209
integer(ilp) :: nrs,nis,ncs,nsvd
200210
integer(ilp), pointer :: iwork(:)
201-
logical(lk) :: copy_a
211+
logical(lk) :: copy_a,large_enough_x
202212
real(${rk}$) :: acond,rcond
203213
real(${rk}$), pointer :: rwork(:),singular(:)
204214
${rt}$, pointer :: xmat(:,:),amat(:,:),cwork(:)
@@ -214,8 +224,8 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
214224
mnmin = min(m,n)
215225
mnmax = max(m,n)
216226

217-
if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx/=m) then
218-
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'invalid sizes: a=',[lda,n], &
227+
if (lda<1 .or. n<1 .or. ldb<1 .or. ldb/=m .or. ldx<n) then
228+
err0 = linalg_state_type(this,LINALG_VALUE_ERROR,'insufficient sizes: a=',[lda,n], &
219229
'b=',[ldb,nrhs],' x=',[ldx,nrhsx])
220230
call linalg_error_handling(err0,err)
221231
if (present(rank)) rank = 0
@@ -236,9 +246,19 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
236246
amat => a
237247
endif
238248

239-
! Initialize solution with the rhs
240-
x = b
241-
xmat(1:n,1:nrhs) => x
249+
! If x is large enough to store b, use it as temporary rhs storage.
250+
large_enough_x = ldx>=m
251+
if (large_enough_x) then
252+
xmat(1:ldx,1:nrhs) => x
253+
else
254+
allocate(xmat(m,nrhs))
255+
endif
256+
257+
#:if ndsuf=="one"
258+
xmat(1:m,1) = b
259+
#:else
260+
xmat(1:m,1:nrhs) = b
261+
#:endif
242262

243263
! Singular values array (in decreasing order)
244264
if (present(singvals)) then
@@ -316,7 +336,16 @@ submodule (stdlib_linalg) stdlib_linalg_least_squares
316336
endif
317337

318338
! Process output and return
339+
if (.not.large_enough_x) then
340+
#:if ndsuf=="one"
341+
x(1:n) = xmat(1:n,1)
342+
#:else
343+
x(1:n,1:nrhs) = xmat(1:n,1:nrhs)
344+
#:endif
345+
deallocate(xmat)
346+
endif
319347
if (copy_a) deallocate(amat)
348+
320349
if (present(rank)) rank = arank
321350
if (.not.present(real_storage)) deallocate(rwork)
322351
if (.not.present(int_storage)) deallocate(iwork)

test/linalg/test_linalg_lstsq.fypp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ module test_linalg_least_squares
7171
type(linalg_state_type) :: state
7272
integer(ilp), parameter :: n = 12, m = 3
7373
real :: Arnd(n,m),xrnd(m)
74-
${rt}$ :: xsol(m),x(m),y(n),A(n,m)
74+
${rt}$, allocatable :: x(:)
75+
${rt}$ :: xsol(m),y(n),A(n,m)
7576

7677
! Random coefficient matrix and solution
7778
call random_number(Arnd)
@@ -88,6 +89,10 @@ module test_linalg_least_squares
8889
call check(error,state%ok(),state%print())
8990
if (allocated(error)) return
9091

92+
! Check size
93+
call check(error,size(x)==m)
94+
if (allocated(error)) return
95+
9196
call check(error, all(abs(x-xsol)<1.0e-4_${rk}$), 'data converged')
9297
if (allocated(error)) return
9398

0 commit comments

Comments
 (0)