(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; Optimization policy that is spliced into functions through #.OPT.
  (sb-int:defconstant-eqx opt
    #+swank '(optimize (speed 3) (safety 2))
    #-swank '(optimize (speed 3) (safety 0) (debug 0))
    #'equal)
  #+swank (ql:quickload '(:cl-debug-print :fiveam) :silent t)
  ;; Without swank, #>FORM simply reads FORM and wraps it in VALUES.
  #-swank (set-dispatch-macro-character
           #\# #\> (lambda (s c p) (declare (ignore c p)) `(values ,(read s nil nil t)))))
#+swank (cl-syntax:use-syntax cl-debug-print:debug-print-syntax)
#-swank (disable-debugger) ; for CS Academy

;; BEGIN_INSERTED_CONTENTS
;;;
;;; Fast Number Theoretic Transform
;;; Reference: https://github.com/ei1333/library/blob/master/math/fft/number-theoretic-transform-friendly-mod-int.cpp
;;;

(defconstant +ntt-mod+ 998244353)
(defconstant +ntt-root+ 3)

(deftype ntt-int () '(unsigned-byte 31))
(deftype ntt-vector () '(simple-array ntt-int (*)))

(eval-when (:compile-toplevel :load-toplevel :execute)
  (assert (typep +ntt-mod+ 'ntt-int)))

(declaim (inline %tzcount))
(defun %tzcount (x)
  "Returns the index of the lowest set bit of X (-1 when X is zero)."
  (- (integer-length (logand x (- x))) 1))

(declaim (inline %mod-power))
(defun %mod-power (base exp)
  "Returns BASE^EXP modulo +NTT-MOD+ by binary exponentiation."
  (declare (ntt-int base)
           ((integer 0 #.most-positive-fixnum) exp))
  (let ((res 1))
    (declare (ntt-int res))
    (loop while (> exp 0)
          when (oddp exp)
            do (setq res (mod (* res base) +ntt-mod+))
          do (setq base (mod (* base base) +ntt-mod+)
                   exp (ash exp -1)))
    res))

(defun check-ntt-vector (vector)
  "Signals an error unless the length of VECTOR passes the power-of-two bit
test (a length of zero also passes) and is of type NTT-INT."
  (declare (optimize (speed 3))
           (vector vector))
  (let ((len (length vector)))
    (assert (zerop (logand len (- len 1)))) ;; power of two
    (check-type len ntt-int)))

(defun make-ntt-base ()
  "Returns two NTT-INT vectors BASE and INV-BASE, each as long as the number of
trailing zero bits of +NTT-MOD+ - 1. BASE[i] is the negation (mod +NTT-MOD+) of
+NTT-ROOT+ raised to ((+NTT-MOD+ - 1) >> (i + 2)); INV-BASE[i] is BASE[i]
raised to +NTT-MOD+ - 2, i.e. its modular inverse."
  (labels ((mod-inverse (x)
             (%mod-power x (- +ntt-mod+ 2))))
    (let* ((base-size (%tzcount (- +ntt-mod+ 1)))
           (base (make-array base-size :element-type 'ntt-int))
           (inv-base (make-array base-size :element-type 'ntt-int)))
      (dotimes (i base-size)
        (setf (aref base i)
              (mod (- (%mod-power +ntt-root+ (ash (- +ntt-mod+ 1) (- (+ i 2)))))
                   +ntt-mod+)
              (aref inv-base i)
              (mod-inverse (aref base i))))
      (values base inv-base))))

;; Twiddle tables consumed by NTT! and INVERSE-NTT!.
(multiple-value-bind (base inv-base) (make-ntt-base)
  (defparameter *ntt-base* base)
  (defparameter *ntt-inv-base* inv-base))

;; FIXME: Here I resort to SBCL's behaviour. Actually ADJUST-ARRAY isn't
;; guaranteed to preserve the given VECTOR.
(declaim (ftype (function * (values ntt-vector &optional)) %adjust-array))
(defun %adjust-array (vector length)
  "Returns an NTT-VECTOR of the given LENGTH holding the leading elements of
VECTOR, padded with zeros when LENGTH exceeds the length of VECTOR. When the
lengths already agree a copy is returned."
  (declare (vector vector))
  (let ((vector (coerce vector 'ntt-vector)))
    (if (= (length vector) length)
        (copy-seq vector)
        (adjust-array vector length :initial-element 0))))

(declaim (ftype (function * (values ntt-vector &optional)) ntt!))
(defun ntt! (vector)
  "Applies the forward number theoretic transform modulo +NTT-MOD+ and returns
the transformed vector. The length of VECTOR must pass CHECK-NTT-VECTOR. VECTOR
is first coerced to NTT-VECTOR; the transform is done in place on the coerced
vector, which is VECTOR itself when it already is an NTT-VECTOR.

NOTE(review): the output does not appear to be in natural frequency order; it
is meant to be consumed by INVERSE-NTT!."
  (declare #.OPT
           (vector vector))
  (check-ntt-vector vector)
  (labels ((mod* (x y) (mod (* x y) +ntt-mod+))
           (mod+ (x y)
             (let ((res (+ x y)))
               (if (>= res +ntt-mod+)
                   (- res +ntt-mod+)
                   res)))
           (mod- (x y) (mod+ x (- +ntt-mod+ y))))
    (declare (inline mod* mod+ mod-))
    (let* ((vector (coerce vector 'ntt-vector))
           (len (length vector))
           (base *ntt-base*))
      (declare ((simple-array ntt-int (*)) vector base)
               (ntt-int len))
      (when (<= len 1)
        (return-from ntt! vector))
      ;; M is the half-width of a butterfly block. W and K are reset to 1 and
      ;; 0 at the start of every pass because their FOR clauses have no THEN.
      (loop for m of-type ntt-int = (ash len -1) then (ash m -1)
            while (> m 0)
            for w of-type ntt-int = 1
            for k of-type ntt-int = 0
            do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                     do (loop for i from s below (+ s m)
                              for j from (+ s m)
                              for x = (aref vector i)
                              for y = (mod* (aref vector j) w)
                              do (setf (aref vector i) (mod+ x y)
                                       (aref vector j) (mod- x y)))
                        ;; Advance the twiddle factor for the next block.
                        (incf k)
                        (setq w (mod* w (aref base (%tzcount k))))))
      vector)))

(declaim (ftype (function * (values ntt-vector &optional)) inverse-ntt!))
(defun inverse-ntt! (vector &optional inverse)
  "Applies the butterfly passes that undo NTT!, using *NTT-INV-BASE*, and
returns the resulting vector. The length of VECTOR must pass CHECK-NTT-VECTOR.
VECTOR is coerced to NTT-VECTOR and that vector is modified in place.

If INVERSE is true, every element is additionally multiplied by the modular
inverse of the length; without it the result stays scaled by the length."
  (declare #.OPT
           (vector vector))
  (check-ntt-vector vector)
  (labels ((mod* (x y)
             (declare (ntt-int x y))
             (mod (* x y) +ntt-mod+))
           (mod+ (x y)
             (declare (ntt-int x y))
             (let ((res (+ x y)))
               (if (>= res +ntt-mod+)
                   (- res +ntt-mod+)
                   res)))
           (mod- (x y)
             (declare (ntt-int x y))
             (mod+ x (- +ntt-mod+ y))))
    (declare (inline mod* mod+ mod-))
    (let* ((vector (coerce vector 'ntt-vector))
           (len (length vector))
           (base *ntt-inv-base*))
      (declare ((simple-array ntt-int (*)) vector base)
               (ntt-int len))
      (when (<= len 1)
        (return-from inverse-ntt! vector))
      ;; Same structure as NTT! but with growing block width; W and K are
      ;; reset at the start of every pass.
      (loop for m of-type ntt-int = 1 then (ash m 1)
            while (< m len)
            for w of-type ntt-int = 1
            for k of-type ntt-int = 0
            do (loop for s of-type ntt-int from 0 below len by (* 2 m)
                     do (loop for i from s below (+ s m)
                              for j from (+ s m)
                              for x = (aref vector i)
                              for y = (aref vector j)
                              do (setf (aref vector i) (mod+ x y)
                                       (aref vector j) (mod* (mod- x y) w)))
                        (incf k)
                        (setq w (mod* w (aref base (%tzcount k))))))
      (when inverse
        (let ((inv-len (%mod-power len (- +ntt-mod+ 2))))
          (dotimes (i len)
            (setf (aref vector i) (mod* inv-len (aref vector i))))))
      vector)))

(declaim (ftype (function * (values ntt-vector &optional)) ntt-convolute!))
(defun ntt-convolute! (vector1 vector2)
  "Returns the product of the polynomials VECTOR1 and VECTOR2 modulo +NTT-MOD+
as an NTT-VECTOR whose length is the power of two ceiling of
len1 + len2 - 1. Both arguments go through ADJUST-ARRAY and are then
transformed in place, so they must be regarded as destroyed."
  (declare (ntt-vector vector1 vector2))
  (let* ((len1 (length vector1))
         (len2 (length vector2))
         (mul-len (- (+ len1 len2) 1))
         (required-len (sb-int:power-of-two-ceiling mul-len))
         ;; Pad explicitly with zeros instead of relying on what the
         ;; implementation happens to put into the new cells.
         (vector1 (ntt! (adjust-array vector1 required-len :initial-element 0)))
         (vector2 (ntt! (adjust-array vector2 required-len :initial-element 0))))
    (dotimes (i required-len)
      (setf (aref vector1 i)
            (mod (* (aref vector1 i) (aref vector2 i)) +ntt-mod+)))
    (inverse-ntt! vector1 t)))

(declaim (ftype (function * (values ntt-vector &optional)) ntt-convolute))
(defun ntt-convolute (vector1 vector2 &optional fixed)
  "Non-destructive counterpart of NTT-CONVOLUTE!: works on copies made by
%ADJUST-ARRAY and returns a fresh NTT-VECTOR.

If FIXED is false, the transform length is the power of two ceiling of
len1 + len2 - 1 (at least 1). If FIXED is true, both vectors must have the same
length, which must be a power of two; that length is used as is, so the
product wraps around cyclically."
  (declare (optimize (speed 3))
           (vector vector1 vector2))
  (let ((len1 (length vector1))
        (len2 (length vector2)))
    (when fixed
      (assert (= len1 len2)))
    (let* ((mul-len (max 0 (- (+ len1 len2) 1)))
           ;; power of two ceiling
           (required-len (if fixed
                             len1
                             (ash 1 (integer-length (max 0 (- mul-len 1))))))
           (vector1 (ntt! (%adjust-array vector1 required-len)))
           (vector2 (ntt! (%adjust-array vector2 required-len))))
      (dotimes (i required-len)
        (setf (aref vector1 i)
              (mod (* (aref vector1 i) (aref vector2 i)) +ntt-mod+)))
      ;; Pass T so that the result is divided by the transform length.
      (inverse-ntt! vector1 t))))

(declaim (ftype (function * (values fixnum &optional)) read-fixnum))
(defun read-fixnum (&optional (in *standard-input*))
  "Reads one decimal integer from IN and returns it. Bytes before the first
digit are skipped, except that a minus sign among them makes the result
negative. Signals an error when a zero byte (EOF is mapped to #\\Nul) is met
before any digit.

NOTE: cannot read -2^62"
  (declare #.opt)
  (macrolet ((%read-byte ()
               `(the (unsigned-byte 8)
                     #+swank (char-code (read-char in nil #\Nul))
                     #-swank (sb-impl::ansi-stream-read-byte in nil #.(char-code #\Nul) nil))))
    (let* ((minus nil)
           ;; Skip to the first digit, remembering a minus sign on the way.
           (result (loop (let ((byte (%read-byte)))
                           (cond ((<= 48 byte 57)
                                  (return (- byte 48)))
                                 ((zerop byte) ; #\Nul
                                  (error "Read EOF or #\Nul."))
                                 ((= byte #.(char-code #\-))
                                  (setq minus t)))))))
      (declare ((integer 0 #.most-positive-fixnum) result))
      ;; Accumulate digits until the first non-digit byte.
      (loop
        (let* ((byte (%read-byte)))
          (if (<= 48 byte 57)
              (setq result (+ (- byte 48)
                              (* 10 (the (integer 0 #.(floor most-positive-fixnum 10))
                                         result))))
              (return (if minus (- result) result))))))))

;;;
;;; Arithmetic operations with static modulus
;;;

;; FIXME: MOD* and MOD+ must apply MOD even when the number of parameters is
;; one; make sure DEFINE-MOD-OPERATIONS does so.
(defmacro define-mod-operations (divisor)
  "Defines the functions MOD* and MOD+ (plus SBCL source transforms that expand
calls inline) and the modify macros INCFMOD, DECFMOD and MULFMOD, all working
modulo DIVISOR.

With two or more arguments MOD* and MOD+ reduce after every step. With exactly
one argument they return that argument reduced modulo DIVISOR. With no
arguments they return 1 and 0 respectively."
  `(progn
     ;; :INITIAL-VALUE makes the zero-argument call well defined and forces a
     ;; single argument through MOD as well.
     (defun mod* (&rest args)
       (reduce (lambda (x y) (mod (* x y) ,divisor)) args :initial-value 1))
     (defun mod+ (&rest args)
       (reduce (lambda (x y) (mod (+ x y) ,divisor)) args :initial-value 0))
     #+sbcl
     (eval-when (:compile-toplevel :load-toplevel :execute)
       (locally (declare (muffle-conditions warning))
         (sb-c:define-source-transform mod* (&rest args)
           (cond ((null args) 1)
                 ;; A lone argument still has to be reduced.
                 ((null (cdr args)) `(mod ,(car args) ,',divisor))
                 (t (reduce (lambda (x y) `(mod (* ,x ,y) ,',divisor)) args))))
         (sb-c:define-source-transform mod+ (&rest args)
           (cond ((null args) 0)
                 ((null (cdr args)) `(mod ,(car args) ,',divisor))
                 (t (reduce (lambda (x y) `(mod (+ ,x ,y) ,',divisor)) args))))))
     (define-modify-macro incfmod (delta)
       (lambda (x y) (mod (+ x y) ,divisor)))
     (define-modify-macro decfmod (delta)
       (lambda (x y) (mod (- x y) ,divisor)))
     (define-modify-macro mulfmod (multiplier)
       (lambda (x y) (mod (* x y) ,divisor)))))

(in-package :cl-user)

(defmacro dbg (&rest forms)
  "Under swank, prints FORMS and their values to *ERROR-OUTPUT*; otherwise
expands to NIL."
  #+swank
  (if (= (length forms) 1)
      `(format *error-output* "~A => ~A~%" ',(car forms) ,(car forms))
      `(format *error-output* "~A => ~A~%" ',forms `(,,@forms)))
  #-swank (declare (ignore forms)))

(defmacro define-int-types (&rest bits)
  "For every B in BITS defines UINTB as (UNSIGNED-BYTE B) and INTB as
(SIGNED-BYTE B)."
  `(progn
     ,@(mapcar (lambda (b) `(deftype ,(intern (format nil "UINT~A" b)) () '(unsigned-byte ,b))) bits)
     ,@(mapcar (lambda (b) `(deftype ,(intern (format nil "INT~A" b)) () '(signed-byte ,b))) bits)))
(define-int-types 2 4 7 8 15 16 31 32 62 63 64)

(declaim (inline println))
(defun println (obj &optional (stream *standard-output*))
  "Writes OBJ with PRINC followed by a newline to STREAM and returns OBJ."
  (let ((*read-default-float-format* 'double-float))
    (prog1 (princ obj stream) (terpri stream))))

(defconstant +mod+ 998244353)

;;;
;;; Body
;;;

(define-mod-operations +mod+)

(defun main ()
  "Reads N and Q, then N integers a_i, and multiplies the polynomials
1 + (a_i - 1)x modulo +MOD+ with a pairwise product tree. Then reads Q integers
b and prints, one per line, the coefficient of x^(N - b) of the product."
  (declare #.OPT)
  (let* ((n (read))
         (n2 (sb-int:power-of-two-ceiling n))
         (q (read))
         (polys (make-array n2 :element-type 'ntt-vector)))
    (declare (uint31 n q))
    ;; Leaves of the product tree; slots beyond N hold the constant 1.
    (dotimes (i n2)
      (let ((poly (make-array 2 :element-type 'uint31)))
        (setf (aref poly 0) 1
              (aref poly 1) (if (< i n)
                                (mod (- (read-fixnum) 1) +mod+)
                                0))
        (setf (aref polys i) poly)))
    ;; Merge neighbours level by level; the full product ends up in slot 0.
    (loop for width = 1 then (ash width 1)
          while (< width n2)
          do (loop for i from 0 below n2 by (* width 2)
                   do (setf (aref polys i)
                            (ntt-convolute! (aref polys i) (aref polys (+ i width))))))
    (let ((poly (aref polys 0)))
      (declare (ntt-vector poly))
      ;; Collect all answers in a string and emit them with a single write.
      (write-string
       (with-output-to-string (*standard-output* nil :element-type 'base-char)
         (dotimes (i q)
           (let ((b (read-fixnum)))
             (println (aref poly (- n b))))))))))

#-swank (main)

;;;
;;; Test and benchmark
;;;

#+swank
(defun io-equal (in-string out-string &key (function #'main) (test #'equal))
  "Passes IN-STRING to *STANDARD-INPUT*, executes FUNCTION, and returns true if
the string output to *STANDARD-OUTPUT* is equal to OUT-STRING."
  (labels ((ensure-last-lf (s)
             (if (eql (uiop:last-char s) #\Linefeed)
                 s
                 (uiop:strcat s uiop:+lf+))))
    (funcall test
             (ensure-last-lf out-string)
             (with-output-to-string (out)
               (let ((*standard-output* out))
                 (with-input-from-string (*standard-input* (ensure-last-lf in-string))
                   (funcall function)))))))

#+swank
(defun get-clipbrd ()
  "Returns the output of PowerShell's Get-Clipboard as a string."
  (with-output-to-string (out)
    (run-program "powershell.exe" '("-Command" "Get-Clipboard") :output out :search t)))

#+swank (defparameter *this-pathname* (uiop:current-lisp-file-pathname))
#+swank (defparameter *dat-pathname* (uiop:merge-pathnames* "test.dat" *this-pathname*))

#+swank
(defun run (&optional thing (out *standard-output*))
  "THING := null | string | symbol | pathname

null: run #'MAIN using the text on clipboard as input.
string: run #'MAIN using the string as input.
symbol: alias of FIVEAM:RUN!.
pathname: run #'MAIN using the text file as input."
  (let ((*standard-output* out))
    (etypecase thing
      (null
       (with-input-from-string (*standard-input* (delete #\Return (get-clipbrd)))
         (main)))
      (string
       (with-input-from-string (*standard-input* (delete #\Return thing))
         (main)))
      (symbol (5am:run! thing))
      (pathname
       (with-open-file (*standard-input* thing)
         (main))))))

#+swank
(defun gen-dat ()
  "Creates (or truncates) the benchmark data file at *DAT-PATHNAME*."
  (uiop:with-output-file (out *dat-pathname* :if-exists :supersede)
    (format out "")))

#+swank
(defun bench (&optional (out (make-broadcast-stream)))
  "Times MAIN on the data file at *DAT-PATHNAME*, sending its output to OUT."
  (time (run *dat-pathname* out)))

;; To run: (5am:run! :sample)
;; The expected outputs are one value per line because MAIN prints every
;; answer with PRINTLN and IO-EQUAL compares with EQUAL by default.
#+swank
(it.bese.fiveam:test :sample
  (it.bese.fiveam:is
   (common-lisp-user::io-equal "3 4
3 4 5
0
1
2
3
"
    "24
26
9
1
"))
  (it.bese.fiveam:is
   (common-lisp-user::io-equal "4 3
1 3 3 3
3
1
4
"
    "6
8
1
"))
  (it.bese.fiveam:is
   (common-lisp-user::io-equal "3 2
5 7 5
1
3
"
    "64
1
")))