;; Compile-time setup.  OPT is an optimization declaration that hot
;; functions splice in with #.OPT: (safety 2) when SWANK is present,
;; (safety 0) (debug 0) otherwise.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (sb-int:defconstant-eqx opt
    #+swank '(optimize (speed 3) (safety 2))
    #-swank '(optimize (speed 3) (safety 0) (debug 0))
    #'equal)
  ;; Under SWANK: load the debug-print reader syntax and the FiveAM tests.
  #+swank (ql:quickload '(:cl-debug-print :fiveam) :silent t)
  ;; Without SWANK, #>FORM reads as (VALUES FORM), so code written with the
  ;; #> syntax still reads.  NOTE(review): presumably a no-op stand-in for
  ;; cl-debug-print's #> printer -- confirm.
  #-swank (set-dispatch-macro-character
           #\# #\>
           (lambda (s c p)
             (declare (ignore c p))
             `(values ,(read s nil nil t)))))
#+swank (cl-syntax:use-syntax cl-debug-print:debug-print-syntax)
#-swank (disable-debugger) ; for CS Academy

;; BEGIN_INSERTED_CONTENTS

(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "Reads the next decimal integer from IN and returns it as a fixnum.

Bytes before the first digit are skipped; a minus sign anywhere among them
makes the result negative.  Digits are accumulated until the first
non-digit byte, which is consumed.  Signals an error if EOF (or a NUL byte,
which doubles as the EOF marker) is reached before any digit.  There is no
explicit overflow check: the THE assertion on the accumulator is only as
strong as the ambient safety policy.

NOTE: cannot read -2^62"
  ;; %READ-BYTE yields the next byte as an (unsigned-byte 8), with 0 on EOF.
  ;; Under SWANK it goes through READ-CHAR; otherwise it calls SBCL's
  ;; internal ANSI-STREAM-READ-BYTE directly.
  (macrolet ((%read-byte ()
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    (let* ((minus nil)
           ;; Skip to the first digit, remembering whether a #\- was seen.
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  (error "Read EOF or #\Nul."))
                                 ((= byte #.(char-code #\-))
                                  (setq minus t)))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      ;; Accumulate the remaining digits; the first non-digit ends the number.
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10)) result))))
              (return (if minus (- result) result))))))))

;;;
;;; Fast Number Theoretic Transform
;;; Reference: https://kopricky.github.io/code/FFTs/ntt.html
;;;

;; Prime modulus of the transform and the base from which NTT! derives its
;; roots of unity.
(defconstant +ntt-mod+ 998244353)
(defconstant +ntt-root+ 3)
;; Element type and vector type used throughout the NTT code.
(deftype ntt-int () '(unsigned-byte 31))
(deftype ntt-vector () '(simple-array ntt-int (*)))
;; Sanity check: every residue modulo +NTT-MOD+ must fit in NTT-INT.
(eval-when (:compile-toplevel :load-toplevel :execute)
  (assert (typep +ntt-mod+ 'ntt-int)))

;; FIXME: Here I resort to SBCL's behaviour. Actually ADJUST-ARRAY isn't
;; guaranteed to preserve the given VECTOR.
(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "Returns an NTT-VECTOR holding the elements of VECTOR, truncated or padded
with zeros to LENGTH.

Because NTT! destroys its argument, the result must not share storage with
VECTOR.  An exact-length vector is copied with COPY-SEQ.  Otherwise this
relies on ADJUST-ARRAY returning a new simple array for a vector that is not
actually adjustable, which is what SBCL does."
  (declare (vector vector))
  (let ((vector (coerce vector 'ntt-vector)))
    (if (= (length vector) length)
        (copy-seq vector)
        (adjust-array vector length :initial-element 0))))

(declaim (ftype (function * (values ntt-vector &optional)) ntt!))
(defun ntt! (vector &optional inverse)
  "Computes the number theoretic transform of VECTOR modulo +NTT-MOD+ and
returns the result as an NTT-VECTOR.  When INVERSE is true, computes the
inverse transform instead, including the division by the length.

The length of VECTOR must be zero or a power of two.  Each stage writes into
the other of two buffers: VECTOR (after coercion to NTT-VECTOR) and a scratch
array.  The returned vector may therefore be either one, so always use the
return value, and treat the contents of an NTT-VECTOR argument as destroyed.
Elements are assumed to be already reduced modulo +NTT-MOD+."
  (declare #.OPT
           (vector vector))
  (labels ((power2-p (x)
             "Returns true iff X is zero or a power of 2."
             (zerop (logand x (- x 1))))
           (mod* (x y)
             (declare (ntt-int x y))
             (mod (* x y) +ntt-mod+))
           (mod+ (x y)
             (declare (ntt-int x y))
             ;; Both operands are < +NTT-MOD+, so one subtraction suffices.
             (let ((res (+ x y)))
               (if (>= res +ntt-mod+)
                   (- res +ntt-mod+)
                   res)))
           (mod- (x y)
             (declare (ntt-int x y))
             (mod+ x (- +ntt-mod+ y)))
           (mod-power (base exp)
             "BASE^EXP modulo +NTT-MOD+ by binary exponentiation."
             (declare (ntt-int base)
                      ((integer 0 #.most-positive-fixnum) exp))
             (let ((res 1))
               (declare (ntt-int res))
               (loop while (> exp 0)
                     when (oddp exp)
                       do (setq res (mod* res base))
                     do (setq base (mod* base base)
                              exp (ash exp -1)))
               res))
           (mod-inverse (x)
             ;; Fermat's little theorem; +NTT-MOD+ is prime.
             (mod-power x (- +ntt-mod+ 2))))
    (declare (inline mod* mod+ mod-))
    (let ((len (length vector)))
      (assert (power2-p len))
      (check-type len ntt-int))
    (let* ((vector (coerce vector 'ntt-vector))
           (len (length vector))
           (roots (make-array (+ 1 (ash len -1))
                              :element-type 'ntt-int
                              :initial-element 1))
           (tmp (make-array len :element-type 'ntt-int)))
      (declare ((simple-array ntt-int (*)) vector tmp)
               (ntt-int len))
      (when (<= len 1)
        (return-from ntt! vector))
      ;; ROOTS[k] = ROOT^k, where ROOT is +NTT-ROOT+^((p-1)/len) for the
      ;; forward transform and its inverse for the inverse transform.
      (let ((root (mod-power +ntt-root+
                             (if inverse
                                 (- +ntt-mod+ 1 (floor (- +ntt-mod+ 1) len))
                                 (floor (- +ntt-mod+ 1) len)))))
        (dotimes (i (ash len -1))
          (setf (aref roots (+ i 1))
                (mod* (aref roots i) root))))
      (loop with half of-type ntt-int = (ash len -1)
            for i of-type ntt-int = 1 then (ash i 1)
            for l of-type ntt-int = half then (ash l -1)
            while (< i len)
            do (loop for j from 0 below l
                     for r of-type ntt-int = 0 then (+ r i)
                     ;; The twiddle factor depends only on I and J, so it is
                     ;; fetched once per J rather than once per K.
                     for root of-type ntt-int = (aref roots (* i j))
                     do (loop for k below i
                              for p = (aref vector (+ k r))
                              for q = (aref vector (+ k r half))
                              do (setf (aref tmp (+ k (* 2 r)))
                                       (mod+ p q)
                                       (aref tmp (+ k (* 2 r) i))
                                       (mod* (mod- p q) root))))
               ;; The freshly written buffer becomes the input of the next stage.
               (rotatef vector tmp))
      (when inverse
        (let ((inv (mod-inverse len)))
          (dotimes (i len)
            (setf (aref vector i)
                  (mod* (aref vector i) inv)))))
      vector)))

(declaim (ftype (function * (values ntt-vector &optional)) ntt-convolute))
(defun ntt-convolute (vector1 vector2 &optional fixed)
  "Returns the convolution (polynomial product) of the coefficient vectors
VECTOR1 and VECTOR2 modulo +NTT-MOD+ as a fresh NTT-VECTOR.

Without FIXED, the result length is the smallest power of two that is at
least LEN1 + LEN2 - 1, and at least 1.  Entries beyond the product's last
coefficient are zero.  With FIXED, both lengths must be equal and a power of
two, as NTT! requires, and the cyclic convolution of that length is
returned.  The inputs are copied by %ADJUST-ARRAY before being transformed.
Elements are assumed to be already reduced modulo +NTT-MOD+."
  (declare (optimize (speed 3))
           (vector vector1 vector2))
  (let ((len1 (length vector1))
        (len2 (length vector2)))
    (when fixed
      (assert (= len1 len2)))
    (let* ((mul-len (max 0 (- (+ len1 len2) 1)))
           ;; power of two ceiling
           (required-len (if fixed
                             len1
                             (ash 1 (integer-length (max 0 (- mul-len 1))))))
           (vector1 (ntt! (%adjust-array vector1 required-len)))
           (vector2 (ntt! (%adjust-array vector2 required-len)))
           (res (make-array required-len :element-type 'ntt-int :initial-element 0)))
      ;; Pointwise product in the transformed domain, then transform back.
      (dotimes (i required-len)
        (setf (aref res i)
              (mod (* (aref vector1 i) (aref vector2 i)) +ntt-mod+)))
      (ntt! res t))))

;;;
;;; Arithmetic operations with static modulus
;;;

;; FIXME: Currently MOD* and MOD+ don't apply MOD when given exactly one
;; argument.
(defmacro define-mod-operations (divisor)
  "Defines, for the modulus DIVISOR:
  MOD*, MOD+         -- variadic product / sum, reducing modulo DIVISOR after
                        each binary step (identity 1 / 0 with no arguments);
  INCFMOD, DECFMOD,
  MULFMOD            -- modify macros that add / subtract / multiply a place
                        modulo DIVISOR.
On SBCL, calls to MOD* and MOD+ are also rewritten at compile time by source
transforms into nested MOD forms."
  `(progn
     ;; With no arguments, return the identity to match the source transforms
     ;; below.  REDUCE on an empty list would otherwise call the
     ;; two-argument lambda with no arguments and signal an error.  A single
     ;; argument is returned as-is, without applying MOD.
     (defun mod* (&rest args)
       (if args
           (reduce (lambda (x y) (mod (* x y) ,divisor)) args)
           1))
     (defun mod+ (&rest args)
       (if args
           (reduce (lambda (x y) (mod (+ x y) ,divisor)) args)
           0))
     #+sbcl
     (eval-when (:compile-toplevel :load-toplevel :execute)
       (locally (declare (muffle-conditions warning))
         (sb-c:define-source-transform mod* (&rest args)
           (if (null args)
               1
               (reduce (lambda (x y) `(mod (* ,x ,y) ,',divisor)) args)))
         (sb-c:define-source-transform mod+ (&rest args)
           (if (null args)
               0
               (reduce (lambda (x y) `(mod (+ ,x ,y) ,',divisor)) args)))))
     ;; NOTE(review): CLHS specifies a symbol as the function argument of
     ;; DEFINE-MODIFY-MACRO; a lambda expression works on SBCL.
     (define-modify-macro incfmod (delta)
       (lambda (x y) (mod (+ x y) ,divisor)))
     (define-modify-macro decfmod (delta)
       (lambda (x y) (mod (- x y) ,divisor)))
     (define-modify-macro mulfmod (multiplier)
       (lambda (x y) (mod (* x y) ,divisor)))))

(in-package :cl-user)

(defmacro dbg (&rest forms)
  "Under SWANK, prints each form and its value(s) to *ERROR-OUTPUT*.
Without SWANK, expands to NIL."
  #+swank
  (if (= (length forms) 1)
      `(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
      `(format *error-output* "~A => ~A~%" ',forms `(,,@forms)))
  #-swank (declare (ignore forms)))

(defmacro define-int-types (&rest bits)
  "Defines UINTn = (unsigned-byte n) and INTn = (signed-byte n) for each n in BITS."
  `(progn
     ,@(mapcar (lambda (b) `(deftype ,(intern (format nil "UINT~A" b)) () '(unsigned-byte ,b))) bits)
     ,@(mapcar (lambda (b) `(deftype ,(intern (format nil "INT~A" b)) () '(signed-byte ,b))) bits)))
(define-int-types 2 4 7 8 15 16 31 32 62 63 64)

(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "PRINCs OBJ to STREAM followed by a newline and returns OBJ.  Binds
*READ-DEFAULT-FLOAT-FORMAT* to DOUBLE-FLOAT so double-floats print without
an exponent marker."
  (let ((*read-default-float-format* 'double-float))
    (prog1 (princ obj stream) (terpri stream))))

(defconstant +mod+ 998244353)

;;;
;;; Body
;;;

(define-mod-operations +mod+)

(defun main ()
  "Reads N, Q, then N integers A_i, then Q queries B from *STANDARD-INPUT*.
Multiplies the linear polynomials 1 + (A_i - 1)x modulo +MOD+ and, for each
query B, prints the coefficient of x^(N - B).  Assumes 0 <= B <= N."
  (let* ((n (read))
         ;; Pad the factor count to a power of two so factors can be
         ;; multiplied pairwise, level by level.
         (n2 (sb-int:power-of-two-ceiling n))
         (q (read))
         (polys (make-array n2 :element-type '(simple-array uint31 (*)))))
    (declare (uint31 n q))
    ;; POLYS[i] holds the coefficient vector (index = power of x) of the i-th
    ;; factor; padding factors are the constant polynomial 1.
    (dotimes (i n2)
      (let ((poly (make-array 2 :element-type 'uint31)))
        (setf (aref poly 0) 1
              (aref poly 1) (if (< i n)
                                (mod (- (read-fixnum) 1) +mod+)
                                0))
        (setf (aref polys i) poly)))
    ;; After the pass with width W, POLYS[i] (i a multiple of 2W) is the
    ;; product of the original factors i .. i+2W-1.
    (loop for width = 1 then (ash width 1)
          while (< width n2)
          do (loop for i from 0 below n2 by (* width 2)
                   do (setf (aref polys i)
                            (ntt-convolute (aref polys i) (aref polys (+ i width))))))
    (let ((poly (aref polys 0)))
      (dotimes (i q)
        (let ((b (read-fixnum)))
          (println (aref poly (- n b))))))))

#-swank (main)

;;;
;;; Test and benchmark
;;;

#+swank
(defun io-equal (in-string out-string &key (function #'main) (test #'equal))
  "Passes IN-STRING to *STANDARD-INPUT*, executes FUNCTION, and returns true if
the string output to *STANDARD-OUTPUT* is equal to OUT-STRING."
  (labels ((ensure-last-lf (s)
             (if (eql (uiop:last-char s) #\Linefeed)
                 s
                 (uiop:strcat s uiop:+lf+))))
    (funcall test
             (ensure-last-lf out-string)
             (with-output-to-string (out)
               (let ((*standard-output* out))
                 (with-input-from-string (*standard-input* (ensure-last-lf in-string))
                   (funcall function)))))))

#+swank
(defun get-clipbrd ()
  "Returns the clipboard text, obtained by running PowerShell's Get-Clipboard."
  (with-output-to-string (out)
    (run-program "powershell.exe" '("-Command" "Get-Clipboard") :output out :search t)))

#+swank (defparameter *this-pathname* (uiop:current-lisp-file-pathname))
#+swank (defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *this-pathname*))

#+swank
(defun run (&optional thing (out *standard-output*))
  "THING := null | string | symbol | pathname

null: run #'MAIN using the text on clipboard as input.
string: run #'MAIN using the string as input.
symbol: alias of FIVEAM:RUN!.
pathname: run #'MAIN using the text file as input."
  (let ((*standard-output* out))
    (etypecase thing
      ;; NULL must precede SYMBOL because NIL is also a symbol.
      (null
       ;; The clipboard string is freshly made, so DELETE is safe here.
       (with-input-from-string (*standard-input* (delete #\Return (get-clipbrd)))
         (main)))
      (string
       ;; REMOVE, not DELETE: THING belongs to the caller and may be a literal.
       (with-input-from-string (*standard-input* (remove #\Return thing))
         (main)))
      (symbol (5am:run! thing))
      (pathname
       (with-open-file (*standard-input* thing)
         (main))))))

#+swank
(defun gen-dat ()
  "Overwrites *DAT-PATHNAME* with an empty file."
  (uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
    (format out "")))

#+swank
(defun bench (&optional (out (make-broadcast-stream)))
  "Times RUN on *DAT-PATHNAME*, discarding output by default."
  (time (run *dat-pathname* out)))

;; To run: (5am:run! :sample)
;; MAIN prints one answer per line, so the expected outputs are
;; newline-separated.
#+swank
(it.bese.fiveam:test :sample
  (it.bese.fiveam:is
   (common-lisp-user::io-equal "3 4
3 4 5
0 1 2 3
"
    "24
26
9
1
"))
  (it.bese.fiveam:is
   (common-lisp-user::io-equal "4 3
1 3 3 3
3 1 4
"
    "6
8
1
"))
  (it.bese.fiveam:is
   (common-lisp-user::io-equal "3 2
5 7 5
1 3
"
    "64
1
")))