#!/usr/bin/env python3
#
# No.1304 "Do you know what the basics are? I do."
# (problem title translated from the original Japanese)
#
# Counts sequences of length n over the distinct values of A in which no two
# adjacent elements are equal and whose total XOR lies in [x, y], modulo
# 998244353.
#
import math
import os
import sys


def read_ints():
    """Read one line from stdin and return its whitespace-separated ints."""
    return list(map(int, input().split()))


MOD = 998244353


def solve(n, a, x, y, mod=MOD):
    """Count valid sequences modulo *mod*.

    A sequence is valid when it has length *n*, each element is one of the
    distinct values in *a*, no two adjacent elements are equal, and the XOR
    of all its elements lies in the inclusive slice range [x, y].

    For n == 0 the only sequence is the empty one (XOR 0), so the result is
    1 when 0 lies in [x, y] and 0 otherwise. An empty *a* yields 0 for n >= 1.
    """
    values = set(a)  # duplicate input values do not add new choices
    k = len(values)
    # XOR never sets a bit above the highest bit of its operands, so every
    # reachable XOR fits in [0, size). bit_length() is exact integer math,
    # unlike ceil(log2(...)) on floats; default=0 handles an empty value set.
    size = 1 << max(values, default=0).bit_length()

    # Rolling rows: row[v] = number of valid sequences of the current length
    # whose XOR equals v.
    two_back = None
    one_back = [0] * size
    one_back[0] = 1  # length 0: only the empty sequence, with XOR 0
    for length in range(1, n + 1):
        # Append any value t to a valid sequence of length-1 ...
        cur = [sum(one_back[v ^ t] for t in values) for v in range(size)]
        if length >= 2:
            # ... then remove the ones whose last two elements are equal.
            # Such a sequence is a valid prefix of length-2 (same XOR, since
            # t ^ t == 0) followed by t, t. If that prefix is empty, t can be
            # any of the k values; otherwise t must differ from the prefix's
            # last element, leaving k - 1 choices.
            repeats = k if length == 2 else k - 1
            cur = [(c - b * repeats) % mod for c, b in zip(cur, two_back)]
        else:
            cur = [c % mod for c in cur]
        two_back, one_back = one_back, cur
    # Slicing clips y to the table size; XOR values >= size are unreachable.
    return sum(one_back[x:y + 1]) % mod


def main():
    # K from the first line is not needed: it is recomputed as the number of
    # distinct values on the second line.
    n, _k, x, y = read_ints()
    a = read_ints()
    print(solve(n, a, x, y))


if __name__ == "__main__":
    main()