MOD=998244353 def solve(a): n=len(a) if n==0: return 0 m=len(a[0]) if True: dp=[0]*(1<>k)&1: ndp[j-(1<>j)&1)==0: if j==m-1 or a[i][j+1]!=a[0][0]: dp[m][k+(1<