mod = 998244353 N, K = [int(x) for x in input().split()] ans = N * K ans *= K - 1 ans %= mod ans *= pow(2, (mod - 2) * N, mod) ans %= mod print(ans)