# Modulus applied to the final answer before printing.
MOD = 998244353


def solve1(H: int, W: int) -> int:
    """Return 2 * W * sum_{k=1}^{n} (H - k), where n = (H - 1) // 2.

    The sum is evaluated in closed form as an arithmetic series running from
    (H - n) up to (H - 1), with n terms. First + last terms = (H - 1) + (H - n).
    The product of that sum and n is always even, so the `// 2` is exact.

    NOTE(review): what this quantity counts in the original problem is not
    visible from this file; presumably it is a per-axis contribution that
    ``solve`` adds once for each orientation (H, W) and (W, H).
    """
    n = (H - 1) // 2
    return (H - 1 + (H - n)) * n // 2 * W * 2


def solve(H: int, W: int) -> int:
    """Return the (unreduced) answer for an H x W input.

    The result is symmetric in H and W:
        solve1(H, W) + solve1(W, H)
        + H*(H-1)*W*(W-1) - (h-1)*h*(w-1)*w
    with h = H // 2 + 1 and w = W // 2 + 1.

    Python ints are arbitrary precision, so no intermediate overflow occurs.
    The caller reduces the result modulo ``MOD``.
    """
    ret = solve1(H, W) + solve1(W, H)
    h = H // 2 + 1
    w = W // 2 + 1
    ret += H * (H - 1) * W * (W - 1) - (h - 1) * h * (w - 1) * w
    return ret


def main() -> None:
    """Read "H W" from one line of stdin and print solve(H, W) mod MOD."""
    H, W = map(int, input().split())
    print(solve(H, W) % MOD)


# Guard the I/O so the module can be imported (e.g. for testing) without
# blocking on stdin; running it as a script behaves exactly as before.
if __name__ == "__main__":
    main()