"""Read N, M and a sequence A from stdin and print one count modulo 998244353.

The printed value is assembled from two parts, following the derivation the
original code was written against:

* ``initial``: sum_i (M - A_i) * M^(N-1), described in the original as the
  number of initial matches summed over all sequences B;
* for N >= 2, ``M^N - balanced``, where ``balanced`` is the number of
  sequences B counted by the two-state DP in ``_solve`` (the original names
  this the "no swap gain" count).
"""
import sys

_MOD = 998244353


def _solve(n, m, a):
    """Return the answer for N = ``n``, M = ``m`` and sequence A = ``a``.

    ``a`` is expected to hold ``n`` integers. All results are reduced modulo
    998244353.
    """
    base = pow(m, n - 1, _MOD)  # M^(N-1)
    total = base * m % _MOD  # M^N

    # Each position contributes (M - A_i), multiplied by M^(N-1) choices
    # for the remaining positions.
    initial = sum((m - x) % _MOD for x in a) % _MOD * base % _MOD

    # With a single element there are no adjacent pairs to consider.
    if n == 1:
        return initial

    # For each adjacent pair (A_k, A_{k+1}): interval bounds (lo, hi) and
    # its width hi - lo. The overlap formula below treats these as
    # half-open ranges holding exactly hi - lo values.
    spans = []
    for i in range(n - 1):
        lo, hi = sorted((a[i], a[i + 1]))
        spans.append((lo, hi, hi - lo))

    # Two-state DP. After step k, (outside, inside) count prefixes
    # B_0..B_{k+1}, split by whether B_{k+1} falls outside/inside the next
    # interval, and subject to this constraint for every earlier pair j:
    # B_j and B_{j+1} are either both inside I_j or both outside it.
    first_width = spans[0][2]
    outside = (m - first_width) % _MOD
    inside = first_width % _MOD
    for (lo_a, hi_a, w_a), (lo_b, hi_b, w_b) in zip(spans, spans[1:]):
        both = max(0, min(hi_a, hi_b) - max(lo_a, lo_b))  # in I_a and I_b
        only_a = w_a - both
        only_b = w_b - both
        neither = m - (both + only_a + only_b)
        # The next value must agree with the current state on I_a; its
        # membership in I_b becomes the new state.
        outside, inside = (
            (outside * neither + inside * only_a) % _MOD,
            (outside * only_b + inside * both) % _MOD,
        )

    # The last value must agree with the current state on the last interval.
    last_width = spans[-1][2]
    balanced = (outside * (m - last_width) + inside * last_width) % _MOD
    return (initial + total - balanced) % _MOD


def main():
    """Parse ``N M A_1 .. A_N`` from stdin and print the answer."""
    tokens = sys.stdin.read().split()
    n, m = int(tokens[0]), int(tokens[1])
    a = [int(tok) for tok in tokens[2:2 + n]]
    print(_solve(n, m, a))


if __name__ == "__main__":
    main()