commit ea1b5401f5594b2ccdd875c860fd3859adc7da16
parent e07a73fb30c01b3040c0e12095e5192fc70b0cbe
Author: Miroslav Urbanek <mu@miroslavurbanek.com>
Date: Wed, 28 Jun 2017 15:02:07 +0200
Add alternative SVD functions
Allow to select among the functions ZGESVD, ZGESDD, and ZGESVDX to
calculate a matrix SVD. A default action is to call ZGESDD, and fall
back to ZGESVD if it fails.
Diffstat:
tebdol/array.lisp | | | 207 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------- |
tebdol/blas.lisp | | | 45 | ++++++++++++++++++++++++++++++++++++++++++++- |
2 files changed, 196 insertions(+), 56 deletions(-)
diff --git a/tebdol/array.lisp b/tebdol/array.lisp
@@ -1,5 +1,5 @@
(defpackage :array
- (:use :common-lisp :sb-ext :sb-alien :util :blas)
+ (:use :common-lisp :sb-sys :sb-ext :sb-alien :util :blas)
(:export :array-copy
:array-scalar-/
:array-conjugate
@@ -7,7 +7,8 @@
:array-permute
:array-contract
:array-decompose
- :array-addf))
+ :array-addf
+ :*svd-driver*))
(in-package :array)
@@ -90,7 +91,7 @@
(if (eql pb :left)
(setf transa "n" lda 'm)
(setf transa "t" lda 'k))
- `(sb-sys:with-pinned-objects
+ `(with-pinned-objects
((array-storage-vector a)
(array-storage-vector b)
(array-storage-vector c)
@@ -158,72 +159,168 @@
;; array decomposition
+(defun array-decompose-zgesvd (m n a s u vt)
+ (let* ((min (min m n))
+ (rwork (make-double-array (* 5 min))))
+ (with-pinned-objects
+ ((array-storage-vector a)
+ (array-storage-vector s)
+ (array-storage-vector u)
+ (array-storage-vector vt)
+ (array-storage-vector rwork))
+ (let ((aa (blas-array-alien a))
+ (as (double-array-alien s))
+ (au (blas-array-alien u))
+ (avt (blas-array-alien vt))
+ (arwork (double-array-alien rwork)))
+ (flet ((f (work lwork)
+ (with-pinned-objects ((array-storage-vector work))
+ (let ((info
+ (nth-value
+ 1
+ (zgesvd
+ "S" "S" m n aa m as au m avt min
+ (blas-array-alien work) lwork arwork))))
+ (unless (zerop info)
+ (if (< info 0)
+ (error "Illegal value of parameter ~A in ZGESVD."
+ (- info))
+ (error "ZGESVD did not converge (INFO = ~A)."
+ info)))))))
+ (let ((work (make-blas-array 1)))
+ (f work -1)
+ (let ((lwork (floor (realpart (aref work 0)))))
+ (f (make-blas-array lwork) lwork))))))))
+
+(defun array-decompose-zgesdd (m n a s u vt)
+ (let* ((min (min m n))
+ (max (max m n))
+ (rwork (make-double-array
+ (* min (max (+ (* 5 min) 5) (+ (* 2 max) (* 2 min) 1)))))
+ (iwork (make-integer-array (* 8 min))))
+ (with-pinned-objects
+ ((array-storage-vector a)
+ (array-storage-vector s)
+ (array-storage-vector u)
+ (array-storage-vector vt)
+ (array-storage-vector rwork)
+ (array-storage-vector iwork))
+ (let ((aa (blas-array-alien a))
+ (as (double-array-alien s))
+ (au (blas-array-alien u))
+ (avt (blas-array-alien vt))
+ (arwork (double-array-alien rwork))
+ (aiwork (integer-array-alien iwork)))
+ (flet ((f (work lwork)
+ (with-pinned-objects ((array-storage-vector work))
+ (let ((info
+ (nth-value
+ 1
+ (zgesdd
+ "S" m n aa m as au m avt min
+ (blas-array-alien work) lwork arwork aiwork))))
+ (unless (zerop info)
+ (if (< info 0)
+ (error "Illegal value of parameter ~A in ZGESDD."
+ (- info))
+ (error "ZGESDD did not converge (INFO = ~A)."
+ info)))))))
+ (let ((work (make-blas-array 1)))
+ (f work -1)
+ (let ((lwork (floor (realpart (aref work 0)))))
+ (f (make-blas-array lwork) lwork))))))))
+
+(defun array-decompose-zgesvdx (m n a s u vt)
+ (let* ((min (min m n))
+ (rwork (make-double-array (* 17 min min)))
+ (iwork (make-integer-array (* 12 min))))
+ (with-pinned-objects
+ ((array-storage-vector a)
+ (array-storage-vector s)
+ (array-storage-vector u)
+ (array-storage-vector vt)
+ (array-storage-vector rwork)
+ (array-storage-vector iwork))
+ (let ((aa (blas-array-alien a))
+ (as (double-array-alien s))
+ (au (blas-array-alien u))
+ (avt (blas-array-alien vt))
+ (arwork (double-array-alien rwork))
+ (aiwork (integer-array-alien iwork)))
+ (flet ((f (work lwork)
+ (with-pinned-objects ((array-storage-vector work))
+ (let ((info
+ (nth-value
+ 2
+ (zgesvdx
+ "V" "V" "A" m n aa m 0d0 0d0 0 0 as au m avt min
+ (blas-array-alien work) lwork arwork aiwork))))
+ (unless (zerop info)
+ (if (< info 0)
+ (error
+ "Illegal value of parameter ~A in ZGESVDX."
+ (- info))
+ (if (eql info (1+ (* 2 n)))
+ (error
+ "Internal error occured in DBDSVDX/ZGESVDX.")
+ (error
+ "ZGESVDX did not converge (INFO = ~A)."
+ info))))))))
+ (let ((work (make-blas-array 1)))
+ (f work -1)
+ (let ((lwork (floor (realpart (aref work 0)))))
+ (f (make-blas-array lwork) lwork))))))))
+
+(defun array-decompose-fallback (m n a s u vt)
+ (handler-case (array-decompose-zgesdd m n (array-copy a) s u vt)
+ (error (condition)
+ (format *error-output* "ZGESDD error: ~A~%Using ZGESVD.~%" condition)
+ (array-decompose-zgesvd m n a s u vt))))
+
+(defvar *svd-driver* :fallback)
+
(defun array-decompose (array indices)
- ;; do not destroy the original array
- (setf array (array-copy array))
(if (numberp indices)
(setf indices (list indices)))
- (let ((r (array-rank array)))
- ;; permute indices if they are not in the correct order
+ ;; don't destroy the original array
+ (let* ((a (array-copy array))
+ (r (array-rank a)))
+
+ ;; permute indices if they aren't in the correct order
(unless (and (eql (indices-position r indices) :left)
(apply #'< indices))
- (let ((nl (loop for i below (length indices) collect i)))
- (setf array (array-permute array (make-indices-permutation r indices nl)))
- (setf indices nl)))
-
- (let* ((d (array-dimensions array))
- (dl (loop for i in indices collect (nth i d)))
- (dr (loop for i below r unless (member i indices) collect (nth i d)))
- (m (reduce #'* dr))
- (n (reduce #'* dl))
+ (let ((l (loop for i below (length indices) collect i)))
+ (setf a (array-permute a (make-indices-permutation r indices l)))
+ (setf indices l)))
+
+ (let* ((d (array-dimensions a))
+ (dl (loop
+ for i in indices
+ collect (nth i d)))
+ (dr (loop
+ for i below r
+ unless (member i indices)
+ collect (nth i d)))
+ (m (apply #'* dr))
+ (n (apply #'* dl))
(min (min m n))
- (max (max m n))
(s (make-double-array min))
(u (make-blas-array (append (list min) dr)))
- (vt (make-blas-array (append dl (list min))))
- (work (make-blas-array 1))
- (lwork -1)
- (rwork (make-double-array
- (* min (max (+ (* 5 min) 5)
- (+ (* 2 max) (* 2 min) 1)))))
- (iwork (make-integer-array (* 8 min))))
-
- ;; svd
- (sb-sys:with-pinned-objects
- ((array-storage-vector array)
- (array-storage-vector s)
- (array-storage-vector u)
- (array-storage-vector vt)
- (array-storage-vector rwork)
- (array-storage-vector iwork))
- (let ((ba (blas-array-alien array))
- (bs (blas-array-alien s))
- (bu (blas-array-alien u))
- (bvt (blas-array-alien vt))
- (brwork (blas-array-alien rwork))
- (biwork (blas-array-alien iwork)))
- (macrolet ((f ()
- `(sb-sys:with-pinned-objects ((array-storage-vector work))
- (let ((info
- (nth-value
- 1
- (zgesdd "s" m n ba m bs bu m bvt min (blas-array-alien work)
- lwork brwork biwork))))
- (unless (zerop info)
- (if (< info 0)
- (error "Illegal value of parameter ~A in ZGESDD." (- info))
- (error "ZGESDD did not converge (INFO = ~A)." info)))))))
- (f)
- (setf lwork (floor (realpart (aref work 0))))
- (setf work (make-blas-array lwork))
- (f))))
+ (vt (make-blas-array (append dl (list min)))))
+ (funcall
+ (ecase *svd-driver*
+ (:zgesvd #'array-decompose-zgesvd)
+ (:zgesdd #'array-decompose-zgesdd)
+ (:zgesvdx #'array-decompose-zgesvdx)
+ (:fallback #'array-decompose-fallback))
+ m n a s u vt)
(values vt s u))))
(defun array-addf (x y)
(unless (= (array-total-size x) (array-total-size y))
(error "Array dimensions do not match."))
- (sb-sys:with-pinned-objects
+ (with-pinned-objects
((array-storage-vector x)
(array-storage-vector y)
(array-storage-vector *blas-alien-1*))
diff --git a/tebdol/blas.lisp b/tebdol/blas.lisp
@@ -16,7 +16,9 @@
:dlamch
:zaxpy
:zgemm
+ :zgesvd
:zgesdd
+ :zgesvdx
:zheevr))
(in-package :blas)
@@ -53,7 +55,7 @@
(defun double-array-alien (array)
(sap-alien (vector-sap (array-storage-vector array)) (* double)))
-(declaim (inline dlamch zaxpy zgemm zgesdd zheevr))
+(declaim (inline dlamch zaxpy zgemm zgesvd zgesdd zgesvdx zheevr))
(define-alien-routine ("dlamch_" dlamch) double
(cmach c-string))
@@ -81,6 +83,23 @@
(c (* double))
(ldc int :copy))
+(define-alien-routine ("zgesvd_" zgesvd) void
+ (jobu c-string)
+ (jobvt c-string)
+ (m int :copy)
+ (n int :copy)
+ (a (* double))
+ (lda int :copy)
+ (s (* double))
+ (u (* double))
+ (ldu int :copy)
+ (vt (* double))
+ (ldvt int :copy)
+ (work (* double))
+ (lwork int :copy)
+ (rwork (* double))
+ (info int :out))
+
(define-alien-routine ("zgesdd_" zgesdd) void
(jobz c-string)
(m int :copy)
@@ -98,6 +117,30 @@
(iwork (* double))
(info int :out))
+(define-alien-routine ("zgesvdx_" zgesvdx) void
+ (jobu c-string)
+ (jobvt c-string)
+ (range c-string)
+ (m int :copy)
+ (n int :copy)
+ (a (* double))
+ (lda int :copy)
+ (vl double :copy)
+ (vu double :copy)
+ (il int :copy)
+ (iu int :copy)
+ (ns int :out)
+ (s (* double))
+ (u (* double))
+ (ldu int :copy)
+ (vt (* double))
+ (ldvt int :copy)
+ (work (* double))
+ (lwork int :copy)
+ (rwork (* double))
+ (iwork (* int))
+ (info int :out))
+
(define-alien-routine ("zheevr_" zheevr) void
(jobz c-string)
(range c-string)