tebdol

Simulation of ultracold atoms in optical lattices
git clone https://miroslavurbanek.com/git/tebdol.git
Log | Files | Refs | README

array.lisp (10736B)


      1 (defpackage :array
      2   (:use :common-lisp :sb-sys :sb-ext :sb-alien :util :blas)
      3   (:export :array-copy
      4 	   :array-scalar-/
      5 	   :array-conjugate
      6 	   :array-norm
      7 	   :array-permute
      8 	   :array-contract
      9 	   :array-decompose
     10 	   :array-addf
     11 	   :*svd-driver*))
     12 
     13 (in-package :array)
     14 
     15 (defun array-copy (array)
     16   (let ((new-array (make-blas-array (array-dimensions array))))
     17     (dotimes (i (array-total-size array) new-array)
     18       (setf (row-major-aref new-array i) (row-major-aref array i)))))
     19 
     20 (defmacro define-unary-element-wise-operation (name function)
     21   `(defun ,name (array &rest args)
     22      (let ((new-array (make-blas-array (array-dimensions array))))
     23        (dotimes (i (array-total-size array) new-array)
     24 	 (setf (row-major-aref new-array i)
     25 	       (apply (function ,function)
     26 		      (row-major-aref array i)
     27 		      args))))))
     28 
     29 (define-unary-element-wise-operation array-scalar-/ /)
     30 (define-unary-element-wise-operation array-conjugate conjugate)
     31 
     32 (defun array-norm (array)
     33   (let ((acc 0))
     34     (dotimes (i (array-total-size array) (sqrt acc))
     35       (incf acc (+ (expt (realpart (row-major-aref array i)) 2)
     36                    (expt (imagpart (row-major-aref array i)) 2))))))
     37 
     38 ;; compiler can produce optimized code for array operations
     39 
     40 ;; array permutation generator
     41 
     42 (defmacro array-permute-macro (array permutation)
     43   (let* ((a (gensym))
     44 	 (b (gensym))
     45 	 (rank (length permutation))
     46 	 (mi (loop for i below rank collect (gensym)))
     47 	 (md (loop for i below rank collect (gensym)))
     48 	 (mip (list-permute mi permutation))
     49 	 (mdp (list-permute md permutation)))
     50     (labels ((nest (&optional (i 0))
     51     	       (if (eql i rank)
     52     		   `(setf (aref ,b ,@mip) (aref ,a ,@mi))
     53 		   `(dotimes (,(nth i mi) ,(nth i md)
     54 			      ,@(when (eql i 0) `(,b)))
     55 		      ,(nest (1+ i))))))
     56       `(let* ((,a ,array)
     57 	      ,@(loop for i below rank
     58 		   collect `(,(nth i md) (array-dimension ,a ,i)))
     59 	      (,b (make-array (list ,@mdp) :element-type 'blas-float)))
     60 	 (declare (type (simple-array blas-float) ,a ,b))
     61 	 ,(nest)))))
     62 
     63 ;; generate array permutation functions dynamically
     64 
     65 (let ((cache (make-hash-table :test 'equal)))
     66   (defun array-permute (array permutation)
     67     (funcall (or (gethash permutation cache)
     68 		 (setf (gethash permutation cache)
     69 		       (symbol-function
     70 			(compile (intern (format nil "ARRAY-PERMUTE-~{~D~^-~}" permutation))
     71 				 `(lambda (a) (array-permute-macro a ,permutation))))))
     72 	     array)))
     73 
     74 ;;; contraction function generator
     75 
     76 (defun make-array-contract-fn (ra ia rb ib)
     77   (labels ((perm (a rx ix s)
     78 	     `(setf ,a (array-permute-macro ,a ,(make-indices-permutation rx ix s))))
     79 	   (syms (rx ix)
     80 	     (let* ((d (loop for i below rx for g = (gensym) collect g))
     81 		    (k (loop for i in ix collect (nth i d)))
     82 		    (l (loop for i below rx unless (member i ix) collect (nth i d))))
     83 	       (values d k l)))
     84 	   (alien (a)
     85 	     `(sap-alien (vector-sap (array-storage-vector ,a)) (* double)))
     86 	   (contract (pa pb)
     87 	     (let (transa transb lda ldb)
     88 	       (if (eql pa :left)
     89 		   (setf transb "t" ldb 'n)
     90 		   (setf transb "n" ldb 'k))
     91 	       (if (eql pb :left)
     92 		   (setf transa "n" lda 'm)
     93 		   (setf transa "t" lda 'k))
     94 	       `(with-pinned-objects
     95 		    ((array-storage-vector a)
     96 		     (array-storage-vector b)
     97 		     (array-storage-vector c)
     98 		     (array-storage-vector *blas-alien-0*)
     99 		     (array-storage-vector *blas-alien-1*))
    100 		  (zgemm ,transa ,transb m n k ,(alien *blas-alien-1*) ,(alien 'b) ,lda
    101 			 ,(alien 'a) ,ldb ,(alien '*blas-alien-0*) ,(alien 'c) m)
    102 		  c))))
    103 
    104     (multiple-value-bind (da ka la) (syms ra ia)
    105       (multiple-value-bind (db kb lb) (syms rb ib)
    106 	`(lambda (a b)
    107 	   (declare (type (simple-array blas-float) a b))
    108 	   (destructuring-bind (,@da) (array-dimensions a)
    109 	     (destructuring-bind (,@db) (array-dimensions b)
    110 	       (unless (and ,@(loop for i in ka for j in kb collect `(= ,i ,j)))
    111 		 (error "Array dimensions do not match."))
    112 	       (let ((k (* ,@ka))
    113 		     (n (* ,@la))
    114 		     (m (* ,@lb))
    115 		     (c (make-array (list ,@la ,@lb) :element-type 'blas-float)))
    116 		 (declare (type (simple-array blas-float) c))
    117 		 ,@(let ((pa (indices-position ra ia))
    118 			 (pb (indices-position rb ib))
    119 			 (sa (indices-shape ia))
    120 			 (sb (indices-shape ib)))
    121 			(cond ((and pa pb)
    122 			       (if (equal sa sb)
    123 				   `(,(contract pa pb))
    124 				   `((if (< (array-total-size a) (array-total-size b))
    125 					 (progn
    126 					   ,(perm 'a ra ia sb)
    127 					   ,(contract :left pb))
    128 					 (progn
    129 					   ,(perm 'b rb ib sa)
    130 					   ,(contract pa :left))))))
    131 			      ((and pa (not pb))
    132 			       `(,(perm 'b rb ib sa) ,(contract pa :left)))
    133 			      ((and (not pa) pb)
    134 			       `(,(perm 'a ra ia sb) ,(contract :left pb)))
    135 			      (t
    136 			       (let ((s (loop for i below (length ia) collect i)))
    137 				 `(,(perm 'a ra ia s)
    138 				   ,(perm 'b rb ib s)
    139 				   ,(contract :left :left))))))))))))))
    140 
    141 ;;; generate array contraction functions dynamically
    142 
    143 (let ((cache (make-hash-table :test 'equal)))
    144   (defun array-contract (a ia b ib)
    145     (unless (listp ia)
    146       (setf ia (list ia)))
    147     (unless (listp ib)
    148       (setf ib (list ib)))
    149     (let ((key (list (array-rank a) ia (array-rank b) ib)))
    150       (funcall (or (gethash key cache)
    151 		   (setf (gethash key cache)
    152 			 (symbol-function
    153 			  (compile
    154 			   (intern
    155 			    (format nil "ARRAY-CONTRACT-~{R~D-~{~D~^-~}-R~D-~{~D~^-~}~}" key))
    156 			   (apply #'make-array-contract-fn key)))))
    157 	       a
    158 	       b))))
    159 
    160 ;; array decomposition
    161 
    162 (defun array-decompose-zgesvd (m n a s u vt)
    163   (let* ((min (min m n))
    164 	 (rwork (make-double-array (* 5 min))))
    165     (with-pinned-objects
    166 	((array-storage-vector a)
    167 	 (array-storage-vector s)
    168 	 (array-storage-vector u)
    169 	 (array-storage-vector vt)
    170 	 (array-storage-vector rwork))
    171       (let ((aa (blas-array-alien a))
    172 	    (as (double-array-alien s))
    173 	    (au (blas-array-alien u))
    174 	    (avt (blas-array-alien vt))
    175 	    (arwork (double-array-alien rwork)))
    176 	(flet ((f (work lwork)
    177 		 (with-pinned-objects ((array-storage-vector work))
    178 		   (let ((info
    179 			  (nth-value
    180 			   1
    181 			   (zgesvd
    182 			    "S" "S" m n aa m as au m avt min
    183 			    (blas-array-alien work) lwork arwork))))
    184 		     (unless (zerop info)
    185 		       (if (< info 0)
    186 			   (error "Illegal value of parameter ~A in ZGESVD."
    187 				  (- info))
    188 			   (error "ZGESVD did not converge (INFO = ~A)."
    189 				  info)))))))
    190 	  (let ((work (make-blas-array 1)))
    191 	    (f work -1)
    192 	    (let ((lwork (floor (realpart (aref work 0)))))
    193 	      (f (make-blas-array lwork) lwork))))))))
    194 
    195 (defun array-decompose-zgesdd (m n a s u vt)
    196   (let* ((min (min m n))
    197 	 (max (max m n))
    198 	 (rwork (make-double-array
    199 		 (* min (max (+ (* 5 min) 5) (+ (* 2 max) (* 2 min) 1)))))
    200 	 (iwork (make-integer-array (* 8 min))))
    201     (with-pinned-objects
    202 	((array-storage-vector a)
    203 	 (array-storage-vector s)
    204 	 (array-storage-vector u)
    205 	 (array-storage-vector vt)
    206 	 (array-storage-vector rwork)
    207 	 (array-storage-vector iwork))
    208       (let ((aa (blas-array-alien a))
    209 	    (as (double-array-alien s))
    210 	    (au (blas-array-alien u))
    211 	    (avt (blas-array-alien vt))
    212 	    (arwork (double-array-alien rwork))
    213 	    (aiwork (integer-array-alien iwork)))
    214 	(flet ((f (work lwork)
    215 		 (with-pinned-objects ((array-storage-vector work))
    216 		   (let ((info
    217 			  (nth-value
    218 			   1
    219 			   (zgesdd
    220 			    "S" m n aa m as au m avt min
    221 			    (blas-array-alien work) lwork arwork aiwork))))
    222 		     (unless (zerop info)
    223 		       (if (< info 0)
    224 			   (error "Illegal value of parameter ~A in ZGESDD."
    225 				  (- info))
    226 			   (error "ZGESDD did not converge (INFO = ~A)."
    227 				  info)))))))
    228 	  (let ((work (make-blas-array 1)))
    229 	    (f work -1)
    230 	    (let ((lwork (floor (realpart (aref work 0)))))
    231 	      (f (make-blas-array lwork) lwork))))))))
    232 
    233 (defun array-decompose-zgesvdx (m n a s u vt)
    234   (let* ((min (min m n))
    235 	 (rwork (make-double-array (* 17 min min)))
    236 	 (iwork (make-integer-array (* 12 min))))
    237     (with-pinned-objects
    238 	((array-storage-vector a)
    239 	 (array-storage-vector s)
    240 	 (array-storage-vector u)
    241 	 (array-storage-vector vt)
    242 	 (array-storage-vector rwork)
    243 	 (array-storage-vector iwork))
    244       (let ((aa (blas-array-alien a))
    245 	    (as (double-array-alien s))
    246 	    (au (blas-array-alien u))
    247 	    (avt (blas-array-alien vt))
    248 	    (arwork (double-array-alien rwork))
    249 	    (aiwork (integer-array-alien iwork)))
    250 	(flet ((f (work lwork)
    251 		 (with-pinned-objects ((array-storage-vector work))
    252 		   (let ((info
    253 			  (nth-value
    254 			   2
    255 			   (zgesvdx
    256 			    "V" "V" "A" m n aa m 0d0 0d0 0 0 as au m avt min
    257 			    (blas-array-alien work) lwork arwork aiwork))))
    258 		     (unless (zerop info)
    259 		       (if (< info 0)
    260 			   (error
    261 			    "Illegal value of parameter ~A in ZGESVDX."
    262 			    (- info))
    263 			   (if (eql info (1+ (* 2 n)))
    264 			       (error
    265 				"Internal error occured in DBDSVDX/ZGESVDX.")
    266 			       (error
    267 				"ZGESVDX did not converge (INFO = ~A)."
    268 				info))))))))
    269 	  (let ((work (make-blas-array 1)))
    270 	    (f work -1)
    271 	    (let ((lwork (floor (realpart (aref work 0)))))
    272 	      (f (make-blas-array lwork) lwork))))))))
    273 
    274 (defun array-decompose-fallback (m n a s u vt)
    275   (handler-case (array-decompose-zgesdd m n (array-copy a) s u vt)
    276     (error (condition)
    277       (format *error-output* "ZGESDD error: ~A~%Using ZGESVD.~%" condition)
    278       (array-decompose-zgesvd m n a s u vt))))
    279 
    280 (defvar *svd-driver* :fallback)
    281 
    282 (defun array-decompose (array indices)
    283   (if (numberp indices)
    284       (setf indices (list indices)))
    285   ;; don't destroy the original array
    286   (let* ((a (array-copy array))
    287 	 (r (array-rank a)))
    288 
    289     ;; permute indices if they aren't in the correct order
    290     (unless (and (eql (indices-position r indices) :left)
    291 		 (apply #'< indices))
    292       (let ((l (loop for i below (length indices) collect i)))
    293 	(setf a (array-permute a (make-indices-permutation r indices l)))
    294 	(setf indices l)))
    295 
    296     (let* ((d (array-dimensions a))
    297 	   (dl (loop
    298 		  for i in indices
    299 		  collect (nth i d)))
    300 	   (dr (loop
    301 		  for i below r
    302 		  unless (member i indices)
    303 		  collect (nth i d)))
    304 	   (m (apply #'* dr))
    305 	   (n (apply #'* dl))
    306 	   (min (min m n))
    307 	   (s (make-double-array min))
    308 	   (u (make-blas-array (append (list min) dr)))
    309 	   (vt (make-blas-array (append dl (list min)))))
    310 
    311       (funcall
    312        (ecase *svd-driver*
    313 	 (:zgesvd #'array-decompose-zgesvd)
    314 	 (:zgesdd #'array-decompose-zgesdd)
    315 	 (:zgesvdx #'array-decompose-zgesvdx)
    316 	 (:fallback #'array-decompose-fallback))
    317        m n a s u vt)
    318       (values vt s u))))
    319 
    320 (defun array-addf (x y)
    321   (unless (= (array-total-size x) (array-total-size y))
    322     (error "Array dimensions do not match."))
    323   (with-pinned-objects
    324       ((array-storage-vector x)
    325        (array-storage-vector y)
    326        (array-storage-vector *blas-alien-1*))
    327     (zaxpy (array-total-size x) (blas-array-alien *blas-alien-1*)
    328 	   (blas-array-alien y) 1 (blas-array-alien x) 1)))