tebdol

Simulation of ultracold atoms in optical lattices
git clone https://miroslavurbanek.com/git/tebdol.git
Log | Files | Refs | README

commit ea1b5401f5594b2ccdd875c860fd3859adc7da16
parent e07a73fb30c01b3040c0e12095e5192fc70b0cbe
Author: Miroslav Urbanek <mu@miroslavurbanek.com>
Date:   Wed, 28 Jun 2017 15:02:07 +0200

Add alternative SVD functions

Allow to select among the functions ZGESVD, ZGESDD, and ZGESVDX to
calculate a matrix SVD. A default action is to call ZGESDD, and fall
back to ZGESVD if it fails.

Diffstat:
tebdol/array.lisp | 207++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------
tebdol/blas.lisp | 45++++++++++++++++++++++++++++++++++++++++++++-
2 files changed, 196 insertions(+), 56 deletions(-)

diff --git a/tebdol/array.lisp b/tebdol/array.lisp @@ -1,5 +1,5 @@ (defpackage :array - (:use :common-lisp :sb-ext :sb-alien :util :blas) + (:use :common-lisp :sb-sys :sb-ext :sb-alien :util :blas) (:export :array-copy :array-scalar-/ :array-conjugate @@ -7,7 +7,8 @@ :array-permute :array-contract :array-decompose - :array-addf)) + :array-addf + :*svd-driver*)) (in-package :array) @@ -90,7 +91,7 @@ (if (eql pb :left) (setf transa "n" lda 'm) (setf transa "t" lda 'k)) - `(sb-sys:with-pinned-objects + `(with-pinned-objects ((array-storage-vector a) (array-storage-vector b) (array-storage-vector c) @@ -158,72 +159,168 @@ ;; array decomposition +(defun array-decompose-zgesvd (m n a s u vt) + (let* ((min (min m n)) + (rwork (make-double-array (* 5 min)))) + (with-pinned-objects + ((array-storage-vector a) + (array-storage-vector s) + (array-storage-vector u) + (array-storage-vector vt) + (array-storage-vector rwork)) + (let ((aa (blas-array-alien a)) + (as (double-array-alien s)) + (au (blas-array-alien u)) + (avt (blas-array-alien vt)) + (arwork (double-array-alien rwork))) + (flet ((f (work lwork) + (with-pinned-objects ((array-storage-vector work)) + (let ((info + (nth-value + 1 + (zgesvd + "S" "S" m n aa m as au m avt min + (blas-array-alien work) lwork arwork)))) + (unless (zerop info) + (if (< info 0) + (error "Illegal value of parameter ~A in ZGESVD." + (- info)) + (error "ZGESVD did not converge (INFO = ~A)." + info))))))) + (let ((work (make-blas-array 1))) + (f work -1) + (let ((lwork (floor (realpart (aref work 0))))) + (f (make-blas-array lwork) lwork)))))))) + +(defun array-decompose-zgesdd (m n a s u vt) + (let* ((min (min m n)) + (max (max m n)) + (rwork (make-double-array + (* min (max (+ (* 5 min) 5) (+ (* 2 max) (* 2 min) 1))))) + (iwork (make-integer-array (* 8 min)))) + (with-pinned-objects + ((array-storage-vector a) + (array-storage-vector s) + (array-storage-vector u) + (array-storage-vector vt) + (array-storage-vector rwork) + (array-storage-vector iwork)) + (let ((aa (blas-array-alien a)) + (as (double-array-alien s)) + (au (blas-array-alien u)) + (avt (blas-array-alien vt)) + (arwork (double-array-alien rwork)) + (aiwork (integer-array-alien iwork))) + (flet ((f (work lwork) + (with-pinned-objects ((array-storage-vector work)) + (let ((info + (nth-value + 1 + (zgesdd + "S" m n aa m as au m avt min + (blas-array-alien work) lwork arwork aiwork)))) + (unless (zerop info) + (if (< info 0) + (error "Illegal value of parameter ~A in ZGESDD." + (- info)) + (error "ZGESDD did not converge (INFO = ~A)." + info))))))) + (let ((work (make-blas-array 1))) + (f work -1) + (let ((lwork (floor (realpart (aref work 0))))) + (f (make-blas-array lwork) lwork)))))))) + +(defun array-decompose-zgesvdx (m n a s u vt) + (let* ((min (min m n)) + (rwork (make-double-array (* 17 min min))) + (iwork (make-integer-array (* 12 min)))) + (with-pinned-objects + ((array-storage-vector a) + (array-storage-vector s) + (array-storage-vector u) + (array-storage-vector vt) + (array-storage-vector rwork) + (array-storage-vector iwork)) + (let ((aa (blas-array-alien a)) + (as (double-array-alien s)) + (au (blas-array-alien u)) + (avt (blas-array-alien vt)) + (arwork (double-array-alien rwork)) + (aiwork (integer-array-alien iwork))) + (flet ((f (work lwork) + (with-pinned-objects ((array-storage-vector work)) + (let ((info + (nth-value + 2 + (zgesvdx + "V" "V" "A" m n aa m 0d0 0d0 0 0 as au m avt min + (blas-array-alien work) lwork arwork aiwork)))) + (unless (zerop info) + (if (< info 0) + (error + "Illegal value of parameter ~A in ZGESVDX." + (- info)) + (if (eql info (1+ (* 2 n))) + (error + "Internal error occured in DBDSVDX/ZGESVDX.") + (error + "ZGESVDX did not converge (INFO = ~A)." + info)))))))) + (let ((work (make-blas-array 1))) + (f work -1) + (let ((lwork (floor (realpart (aref work 0))))) + (f (make-blas-array lwork) lwork)))))))) + +(defun array-decompose-fallback (m n a s u vt) + (handler-case (array-decompose-zgesdd m n (array-copy a) s u vt) + (error (condition) + (format *error-output* "ZGESDD error: ~A~%Using ZGESVD.~%" condition) + (array-decompose-zgesvd m n a s u vt)))) + +(defvar *svd-driver* :fallback) + (defun array-decompose (array indices) - ;; do not destroy the original array - (setf array (array-copy array)) (if (numberp indices) (setf indices (list indices))) - (let ((r (array-rank array))) - ;; permute indices if they are not in the correct order + ;; don't destroy the original array + (let* ((a (array-copy array)) + (r (array-rank a))) + + ;; permute indices if they aren't in the correct order (unless (and (eql (indices-position r indices) :left) (apply #'< indices)) - (let ((nl (loop for i below (length indices) collect i))) - (setf array (array-permute array (make-indices-permutation r indices nl))) - (setf indices nl))) - - (let* ((d (array-dimensions array)) - (dl (loop for i in indices collect (nth i d))) - (dr (loop for i below r unless (member i indices) collect (nth i d))) - (m (reduce #'* dr)) - (n (reduce #'* dl)) + (let ((l (loop for i below (length indices) collect i))) + (setf a (array-permute a (make-indices-permutation r indices l))) + (setf indices l))) + + (let* ((d (array-dimensions a)) + (dl (loop + for i in indices + collect (nth i d))) + (dr (loop + for i below r + unless (member i indices) + collect (nth i d))) + (m (apply #'* dr)) + (n (apply #'* dl)) (min (min m n)) - (max (max m n)) (s (make-double-array min)) (u (make-blas-array (append (list min) dr))) - (vt (make-blas-array (append dl (list min)))) - (work (make-blas-array 1)) - (lwork -1) - (rwork (make-double-array - (* min (max (+ (* 5 min) 5) - (+ (* 2 max) (* 2 min) 1))))) - (iwork (make-integer-array (* 8 min)))) - - ;; svd - (sb-sys:with-pinned-objects - ((array-storage-vector array) - (array-storage-vector s) - (array-storage-vector u) - (array-storage-vector vt) - (array-storage-vector rwork) - (array-storage-vector iwork)) - (let ((ba (blas-array-alien array)) - (bs (blas-array-alien s)) - (bu (blas-array-alien u)) - (bvt (blas-array-alien vt)) - (brwork (blas-array-alien rwork)) - (biwork (blas-array-alien iwork))) - (macrolet ((f () - `(sb-sys:with-pinned-objects ((array-storage-vector work)) - (let ((info - (nth-value - 1 - (zgesdd "s" m n ba m bs bu m bvt min (blas-array-alien work) - lwork brwork biwork)))) - (unless (zerop info) - (if (< info 0) - (error "Illegal value of parameter ~A in ZGESDD." (- info)) - (error "ZGESDD did not converge (INFO = ~A)." info))))))) - (f) - (setf lwork (floor (realpart (aref work 0)))) - (setf work (make-blas-array lwork)) - (f)))) + (vt (make-blas-array (append dl (list min))))) + (funcall + (ecase *svd-driver* + (:zgesvd #'array-decompose-zgesvd) + (:zgesdd #'array-decompose-zgesdd) + (:zgesvdx #'array-decompose-zgesvdx) + (:fallback #'array-decompose-fallback)) + m n a s u vt) (values vt s u)))) (defun array-addf (x y) (unless (= (array-total-size x) (array-total-size y)) (error "Array dimensions do not match.")) - (sb-sys:with-pinned-objects + (with-pinned-objects ((array-storage-vector x) (array-storage-vector y) (array-storage-vector *blas-alien-1*)) diff --git a/tebdol/blas.lisp b/tebdol/blas.lisp @@ -16,7 +16,9 @@ :dlamch :zaxpy :zgemm + :zgesvd :zgesdd + :zgesvdx :zheevr)) (in-package :blas) @@ -53,7 +55,7 @@ (defun double-array-alien (array) (sap-alien (vector-sap (array-storage-vector array)) (* double))) -(declaim (inline dlamch zaxpy zgemm zgesdd zheevr)) +(declaim (inline dlamch zaxpy zgemm zgesvd zgesdd zgesvdx zheevr)) (define-alien-routine ("dlamch_" dlamch) double (cmach c-string)) @@ -81,6 +83,23 @@ (c (* double)) (ldc int :copy)) +(define-alien-routine ("zgesvd_" zgesvd) void + (jobu c-string) + (jobvt c-string) + (m int :copy) + (n int :copy) + (a (* double)) + (lda int :copy) + (s (* double)) + (u (* double)) + (ldu int :copy) + (vt (* double)) + (ldvt int :copy) + (work (* double)) + (lwork int :copy) + (rwork (* double)) + (info int :out)) + (define-alien-routine ("zgesdd_" zgesdd) void (jobz c-string) (m int :copy) @@ -98,6 +117,30 @@ (iwork (* double)) (info int :out)) +(define-alien-routine ("zgesvdx_" zgesvdx) void + (jobu c-string) + (jobvt c-string) + (range c-string) + (m int :copy) + (n int :copy) + (a (* double)) + (lda int :copy) + (vl double :copy) + (vu double :copy) + (il int :copy) + (iu int :copy) + (ns int :out) + (s (* double)) + (u (* double)) + (ldu int :copy) + (vt (* double)) + (ldvt int :copy) + (work (* double)) + (lwork int :copy) + (rwork (* double)) + (iwork (* int)) + (info int :out)) + (define-alien-routine ("zheevr_" zheevr) void (jobz c-string) (range c-string)