array.lisp (10736B)
(defpackage :array
  (:use :common-lisp :sb-sys :sb-ext :sb-alien :util :blas)
  (:export :array-copy
           :array-scalar-/
           :array-conjugate
           :array-norm
           :array-permute
           :array-contract
           :array-decompose
           :array-addf
           :*svd-driver*))

(in-package :array)

(defun array-copy (array)
  "Return a fresh array created by MAKE-BLAS-ARRAY with the dimensions of
ARRAY, filled element by element (in row-major order) from ARRAY."
  (let ((copy (make-blas-array (array-dimensions array))))
    (loop for index from 0 below (array-total-size array)
          do (setf (row-major-aref copy index)
                   (row-major-aref array index)))
    copy))

(defmacro define-unary-element-wise-operation (name function)
  "Define NAME as a function (ARRAY &REST ARGS) that returns a fresh array
created by MAKE-BLAS-ARRAY, of the same dimensions as ARRAY, whose elements
are (FUNCTION element . ARGS) for each element of ARRAY."
  ;; The variable names bound inside the expansion are kept stable on purpose:
  ;; FUNCTION is spliced in as (function ...), so a lambda form would see them.
  `(defun ,name (array &rest args)
     (let ((new-array (make-blas-array (array-dimensions array))))
       (loop for i from 0 below (array-total-size array)
             do (setf (row-major-aref new-array i)
                      (apply (function ,function)
                             (row-major-aref array i)
                             args)))
       new-array)))

(define-unary-element-wise-operation array-scalar-/ /)
(define-unary-element-wise-operation array-conjugate conjugate)

(defun array-norm (array)
  "Return the square root of the sum, over all elements of ARRAY, of the
squared real part plus the squared imaginary part."
  (let ((sum-of-squares 0))
    (dotimes (index (array-total-size array))
      (let ((element (row-major-aref array index)))
        (incf sum-of-squares
              (+ (expt (realpart element) 2)
                 (expt (imagpart element) 2)))))
    (sqrt sum-of-squares)))

;; compiler can produce optimized code for array operations

;; array permutation generator

(defmacro array-permute-macro (array permutation)
  "Expand into code that copies ARRAY into a new BLAS-FLOAT array with its
axes rearranged.

PERMUTATION must be a literal list; its length is taken as the rank.  The
expansion binds one index variable and one extent variable per axis, reorders
both lists with LIST-PERMUTE, allocates the result with the reordered extents
and fills it with a nest of DOTIMES loops.  The value of the expansion is the
new array (for rank 0 there is no loop and the value is that of the single
SETF)."
  (let* ((source (gensym))
         (target (gensym))
         (rank (length permutation))
         (index-vars (loop repeat rank collect (gensym)))
         (extent-vars (loop repeat rank collect (gensym)))
         (permuted-index-vars (list-permute index-vars permutation))
         (permuted-extent-vars (list-permute extent-vars permutation))
         ;; Innermost form: copy one element to its permuted position.
         (loop-nest `(setf (aref ,target ,@permuted-index-vars)
                           (aref ,source ,@index-vars))))
    ;; Wrap the copy in one DOTIMES per axis, innermost axis first.  Only the
    ;; outermost loop (axis 0) returns the target array.
    (loop for axis from (1- rank) downto 0
          do (setf loop-nest
                   `(dotimes (,(nth axis index-vars) ,(nth axis extent-vars)
                              ,@(when (eql axis 0) `(,target)))
                      ,loop-nest)))
    `(let* ((,source ,array)
            ,@(loop for extent-var in extent-vars
                    for axis from 0
                    collect `(,extent-var (array-dimension ,source ,axis)))
            (,target (make-array (list ,@permuted-extent-vars)
                                 :element-type 'blas-float)))
       (declare (type (simple-array blas-float) ,source ,target))
       ,loop-nest)))

;; generate
;; array permutation functions dynamically

(let ((cache (make-hash-table :test 'equal)))
  (defun array-permute (array permutation)
    "Return a new array holding the elements of ARRAY with its axes rearranged
according to PERMUTATION (a list, interpreted by ARRAY-PERMUTE-MACRO).

A specialised function is compiled the first time a given PERMUTATION is seen
and cached under that list (EQUAL hash table), so later calls with an equal
permutation reuse the compiled code."
    (funcall (or (gethash permutation cache)
                 (setf (gethash permutation cache)
                       (symbol-function
                        (compile (intern (format nil "ARRAY-PERMUTE-~{~D~^-~}" permutation))
                                 `(lambda (a) (array-permute-macro a ,permutation))))))
             array)))

;;; contraction function generator

(defun make-array-contract-fn (ra ia rb ib)
  "Return a LAMBDA expression of two arrays (A B) that contracts them with a
single ZGEMM call.

RA and RB are the ranks of A and B; IA and IB are lists of the axis numbers of
A and B that are contracted pairwise.  The generated function signals an error
when paired extents differ, allocates the result C with the extents of the
uncontracted axes of A followed by those of B, permutes A and/or B first when
INDICES-POSITION / INDICES-SHAPE indicate that this is needed, and returns C.

NOTE(review): the exact meaning of the values returned by INDICES-POSITION,
INDICES-SHAPE and MAKE-INDICES-PERMUTATION is defined outside this file; the
code only relies on INDICES-POSITION returning :LEFT, another true value, or
NIL."
  (labels ((perm (a rx ix s)
             ;; Form that replaces the variable A by a permuted copy.
             `(setf ,a (array-permute-macro ,a ,(make-indices-permutation rx ix s))))
           (syms (rx ix)
             ;; One fresh variable per axis; K are those named by IX (in IX
             ;; order), L the remaining ones in axis order.
             (let* ((d (loop repeat rx collect (gensym)))
                    (k (loop for i in ix collect (nth i d)))
                    (l (loop for i below rx unless (member i ix) collect (nth i d))))
               (values d k l)))
           (alien (a)
             ;; Form yielding an alien pointer to the storage of the array
             ;; that the form A evaluates to at run time.
             `(sap-alien (vector-sap (array-storage-vector ,a)) (* double)))
           (contract (pa pb)
             ;; PA / PB say on which side the contracted axes of A / B sit and
             ;; select the transpose flags and leading dimensions for ZGEMM.
             (let (transa transb lda ldb)
               (if (eql pa :left)
                   (setf transb "t" ldb 'n)
                   (setf transb "n" ldb 'k))
               (if (eql pb :left)
                   (setf transa "n" lda 'm)
                   (setf transa "t" lda 'k))
               `(with-pinned-objects
                    ((array-storage-vector a)
                     (array-storage-vector b)
                     (array-storage-vector c)
                     (array-storage-vector *blas-alien-0*)
                     (array-storage-vector *blas-alien-1*))
                  ;; B is passed before A: the operands are swapped relative
                  ;; to the Lisp argument order.  The scalar operands are
                  ;; referenced by symbol (quoted here) so that the generated
                  ;; code reads the same objects it pins, at call time.
                  (zgemm ,transa ,transb m n k ,(alien '*blas-alien-1*) ,(alien 'b) ,lda
                         ,(alien 'a) ,ldb ,(alien '*blas-alien-0*) ,(alien 'c) m)
                  c))))

    (multiple-value-bind (da ka la) (syms ra ia)
      (multiple-value-bind (db kb lb) (syms rb ib)
        `(lambda (a b)
           (declare (type (simple-array blas-float) a b))
           (destructuring-bind (,@da) (array-dimensions a)
             (destructuring-bind (,@db) (array-dimensions b)
               (unless (and ,@(loop for i in ka for j in kb collect `(= ,i ,j)))
                 (error "Array dimensions do not match."))
               (let ((k (* ,@ka))
                     (n (* ,@la))
                     (m (* ,@lb))
                     (c (make-array (list ,@la ,@lb) :element-type 'blas-float)))
                 (declare (type (simple-array blas-float) c))
                 ,@(let ((pa (indices-position ra ia))
                         (pb (indices-position rb ib))
                         (sa (indices-shape ia))
                         (sb (indices-shape ib)))
                     (cond ((and pa pb)
                            (if (equal sa sb)
                                ;; Both operands usable as they are.
                                `(,(contract pa pb))
                                ;; Shapes differ: permute whichever operand
                                ;; is smaller at run time.
                                `((if (< (array-total-size a) (array-total-size b))
                                      (progn
                                        ,(perm 'a ra ia sb)
                                        ,(contract :left pb))
                                      (progn
                                        ,(perm 'b rb ib sa)
                                        ,(contract pa :left))))))
                           ((and pa (not pb))
                            `(,(perm 'b rb ib sa) ,(contract pa :left)))
                           ((and (not pa) pb)
                            `(,(perm 'a ra ia sb) ,(contract :left pb)))
                           (t
                            ;; Neither operand is usable: move the contracted
                            ;; axes of both to positions 0, 1, ...
                            (let ((s (loop for i below (length ia) collect i)))
                              `(,(perm 'a ra ia s)
                                ,(perm 'b rb ib s)
                                ,(contract :left :left))))))))))))))

;;; generate array contraction functions dynamically

(let ((cache (make-hash-table :test 'equal)))
  (defun array-contract (a ia b ib)
    "Contract the axes IA of array A with the axes IB of array B and return
the resulting array.

IA and IB may each be a list of axis numbers or a single axis number.  They
must name the same number of axes; an error is signalled otherwise.  A
specialised function is generated by MAKE-ARRAY-CONTRACT-FN and compiled the
first time a given combination of ranks and axis lists is seen, then cached."
    (unless (listp ia)
      (setf ia (list ia)))
    (unless (listp ib)
      (setf ib (list ib)))
    ;; The generated per-pair extent check stops at the shorter list, so a
    ;; length mismatch has to be rejected here.
    (unless (= (length ia) (length ib))
      (error "Numbers of contracted indices do not match: ~A and ~A." ia ib))
    (let ((key (list (array-rank a) ia (array-rank b) ib)))
      (funcall (or (gethash key cache)
                   (setf (gethash key cache)
                         (symbol-function
                          (compile
                           (intern
                            (format nil "ARRAY-CONTRACT-~{R~D-~{~D~^-~}-R~D-~{~D~^-~}~}" key))
                           (apply #'make-array-contract-fn key)))))
               a
               b))))

;; array decomposition

(defun array-decompose-zgesvd (m n a s u vt)
  "Run ZGESVD (job options \"S\" \"S\") on A, writing into S, U and VT.

M and N are passed as the matrix dimensions, with M as the leading dimension
of A and U and (min M N) as that of VT.  ZGESVD is first called with LWORK -1
and a one-element work array; the real part of that element is then used as
the size of the work array for the second call.  Signals an error when the
reported INFO is non-zero.  A is handed to ZGESVD directly, so callers should
treat its contents as clobbered."
  (let* ((min (min m n))
         (rwork (make-double-array (* 5 min))))
    (with-pinned-objects
        ((array-storage-vector a)
         (array-storage-vector s)
         (array-storage-vector u)
         (array-storage-vector vt)
         (array-storage-vector rwork))
      (let ((aa (blas-array-alien a))
            (as (double-array-alien s))
            (au (blas-array-alien u))
            (avt (blas-array-alien vt))
            (arwork (double-array-alien rwork)))
        (flet ((f (work lwork)
                 (with-pinned-objects ((array-storage-vector work))
                   ;; INFO is taken from the second value returned by the
                   ;; ZGESVD binding.
                   (let ((info
                           (nth-value
                            1
                            (zgesvd
                             "S" "S" m n aa m as au m avt min
                             (blas-array-alien work) lwork arwork))))
                     (unless (zerop info)
                       (if (< info 0)
                           (error "Illegal value of parameter ~A in ZGESVD."
                                  (- info))
                           (error "ZGESVD did not converge (INFO = ~A)."
                                  info)))))))
          ;; Workspace query first, real computation second.
          (let ((work (make-blas-array 1)))
            (f work -1)
            (let ((lwork (floor (realpart (aref work 0)))))
              (f (make-blas-array lwork) lwork))))))))

(defun array-decompose-zgesdd (m n a s u vt)
  "Run ZGESDD (job option \"S\") on A, writing into S, U and VT.

Same calling convention as ARRAY-DECOMPOSE-ZGESVD, with an additional real
work array and an integer work array of 8 * (min M N) elements.  Performs a
workspace query (LWORK -1) followed by the actual call and signals an error
when the reported INFO is non-zero.  A is handed to ZGESDD directly, so
callers should treat its contents as clobbered."
  (let* ((min (min m n))
         (max (max m n))
         (rwork (make-double-array
                 (* min (max (+ (* 5 min) 5) (+ (* 2 max) (* 2 min) 1)))))
         (iwork (make-integer-array (* 8 min))))
    (with-pinned-objects
        ((array-storage-vector a)
         (array-storage-vector s)
         (array-storage-vector u)
         (array-storage-vector vt)
         (array-storage-vector rwork)
         (array-storage-vector iwork))
      (let ((aa (blas-array-alien a))
            (as (double-array-alien s))
            (au (blas-array-alien u))
            (avt (blas-array-alien vt))
            (arwork (double-array-alien rwork))
            (aiwork (integer-array-alien iwork)))
        (flet ((f (work lwork)
                 (with-pinned-objects ((array-storage-vector work))
                   ;; INFO is taken from the second value returned by the
                   ;; ZGESDD binding.
                   (let ((info
                           (nth-value
                            1
                            (zgesdd
                             "S" m n aa m as au m avt min
                             (blas-array-alien work) lwork arwork aiwork))))
                     (unless (zerop info)
                       (if (< info 0)
                           (error "Illegal value of parameter ~A in ZGESDD."
                                  (- info))
                           (error "ZGESDD did not converge (INFO = ~A)."
                                  info)))))))
          ;; Workspace query first, real computation second.
          (let ((work (make-blas-array 1)))
            (f work -1)
            (let ((lwork (floor (realpart (aref work 0)))))
              (f (make-blas-array lwork) lwork))))))))

(defun array-decompose-zgesvdx (m n a s u vt)
  "Run ZGESVDX (options \"V\" \"V\" \"A\") on A, writing into S, U and VT.

Same calling convention as ARRAY-DECOMPOSE-ZGESVD, with a real work array of
17 * (min M N)^2 elements and an integer work array of 12 * (min M N)
elements.  Performs a workspace query (LWORK -1) followed by the actual call
and signals an error when the reported INFO is non-zero.  A is handed to
ZGESVDX directly, so callers should treat its contents as clobbered."
  (let* ((min (min m n))
         (rwork (make-double-array (* 17 min min)))
         (iwork (make-integer-array (* 12 min))))
    (with-pinned-objects
        ((array-storage-vector a)
         (array-storage-vector s)
         (array-storage-vector u)
         (array-storage-vector vt)
         (array-storage-vector rwork)
         (array-storage-vector iwork))
      (let ((aa (blas-array-alien a))
            (as (double-array-alien s))
            (au (blas-array-alien u))
            (avt (blas-array-alien vt))
            (arwork (double-array-alien rwork))
            (aiwork (integer-array-alien iwork)))
        (flet ((f (work lwork)
                 (with-pinned-objects ((array-storage-vector work))
                   ;; INFO is the third value returned by the ZGESVDX binding
                   ;; (one position later than for the other two drivers).
                   (let ((info
                           (nth-value
                            2
                            (zgesvdx
                             "V" "V" "A" m n aa m 0d0 0d0 0 0 as au m avt min
                             (blas-array-alien work) lwork arwork aiwork))))
                     (unless (zerop info)
                       (cond ((< info 0)
                              (error
                               "Illegal value of parameter ~A in ZGESVDX."
                               (- info)))
                             ((eql info (1+ (* 2 n)))
                              (error
                               "Internal error occured in DBDSVDX/ZGESVDX."))
                             (t
                              (error
                               "ZGESVDX did not converge (INFO = ~A)."
                               info))))))))
          ;; Workspace query first, real computation second.
          (let ((work (make-blas-array 1)))
            (f work -1)
            (let ((lwork (floor (realpart (aref work 0)))))
              (f (make-blas-array lwork) lwork))))))))

(defun array-decompose-fallback (m n a s u vt)
  "Try ARRAY-DECOMPOSE-ZGESDD on a copy of A; if it signals an ERROR, report
the condition on *ERROR-OUTPUT* and run ARRAY-DECOMPOSE-ZGESVD on A itself."
  ;; ZGESDD gets a copy so that A is still intact for the ZGESVD retry.
  (handler-case (array-decompose-zgesdd m n (array-copy a) s u vt)
    (error (condition)
      (format *error-output* "ZGESDD error: ~A~%Using ZGESVD.~%" condition)
      (array-decompose-zgesvd m n a s u vt))))

(defvar *svd-driver* :fallback
  "Selects the routine used by ARRAY-DECOMPOSE: one of :ZGESVD, :ZGESDD,
:ZGESVDX or :FALLBACK (ZGESDD first, ZGESVD if that signals an error).")

(defun array-decompose (array indices)
  "Decompose ARRAY with the SVD routine selected by *SVD-DRIVER*.

INDICES is an axis number or a list of axis numbers of ARRAY.  ARRAY itself is
not modified; the routine works on a copy.  Returns three values:
  1. an array with the extents of the INDICES axes followed by one axis of
     extent K,
  2. a double array of K elements,
  3. an array with one axis of extent K followed by the extents of the
     remaining axes,
where K is the smaller of the product of the INDICES extents and the product
of the remaining extents."
  (when (numberp indices)
    (setf indices (list indices)))
  ;; don't destroy the original array
  (let* ((a (array-copy array))
         (r (array-rank a)))

    ;; permute indices if they aren't in the correct order
    (unless (and (eql (indices-position r indices) :left)
                 (apply #'< indices))
      (let ((l (loop for i below (length indices) collect i)))
        (setf a (array-permute a (make-indices-permutation r indices l)))
        (setf indices l)))

    (let* ((d (array-dimensions a))
           ;; extents of the selected axes / of all other axes
           (dl (loop
                 for i in indices
                 collect (nth i d)))
           (dr (loop
                 for i below r
                 unless (member i indices)
                   collect (nth i d)))
           (m (apply #'* dr))
           (n (apply #'* dl))
           (min (min m n))
           (s (make-double-array min))
           (u (make-blas-array (append (list min) dr)))
           (vt (make-blas-array (append dl (list min)))))

      (funcall
       (ecase *svd-driver*
         (:zgesvd #'array-decompose-zgesvd)
         (:zgesdd #'array-decompose-zgesdd)
         (:zgesvdx #'array-decompose-zgesvdx)
         (:fallback #'array-decompose-fallback))
       m n a s u vt)
      (values vt s u))))

(defun array-addf (x y)
  "Call ZAXPY on the storage of X and Y, with *BLAS-ALIEN-1* as the scalar and
stride 1 for both arrays.  Signals an error unless X and Y have the same total
size.

NOTE(review): presumably this adds Y into X in place (X is passed in the
position of ZAXPY's updated vector) -- confirm against the ZAXPY binding."
  (unless (= (array-total-size x) (array-total-size y))
    (error "Array dimensions do not match."))
  (with-pinned-objects
      ((array-storage-vector x)
       (array-storage-vector y)
       (array-storage-vector *blas-alien-1*))
    (zaxpy (array-total-size x) (blas-array-alien *blas-alien-1*)
           (blas-array-alien y) 1 (blas-array-alien x) 1)))