def solve(): N = H//2 M = W//2 A = N+2 B = M+2 ans = H*W%MOD*N%MOD*M%MOD ans -= A*B%MOD*N%MOD*M%MOD ans %= MOD ans += M*(M+1)//2*A%MOD*N%MOD ans %= MOD ans += N*(N+1)//2*B%MOD*M%MOD ans %= MOD ans -= (N*(N+1)//2)*(M*(M+1)//2)%MOD ans %= MOD ans *= 4 ans %= MOD return ans def solve2(): N = H//2 M = W//2 A = N+2 B = M+2 ans = H*W%MOD*N%MOD*M%MOD ans -= A*B%MOD*N%MOD*M%MOD ans %= MOD ans += M*(M+1)//2*A%MOD*N%MOD ans %= MOD ans += N*(N+1)//2*B%MOD*M%MOD ans %= MOD ans -= (N*(N+1)//2)*(M*(M+1)//2)%MOD ans %= MOD ans *= 4 ans %= MOD tmp = H*W%MOD*N%MOD tmp %= MOD tmp -= A*N%MOD tmp %= MOD tmp += N*(N+1)//2 tmp %= MOD tmp *= 2 tmp %= MOD return (ans+tmp)%MOD def solve3(): N = H//2 M = W//2 A = N+2 B = M+2 ans = H*W%MOD*N%MOD*M%MOD ans -= A*B%MOD*N%MOD*M%MOD ans %= MOD ans += M*(M+1)//2*A%MOD*N%MOD ans %= MOD ans += N*(N+1)//2*B%MOD*M%MOD ans %= MOD ans -= (N*(N+1)//2)*(M*(M+1)//2)%MOD ans %= MOD ans *= 4 ans %= MOD tmp = H*W%MOD*N%MOD tmp %= MOD tmp -= A*N%MOD tmp %= MOD tmp += N*(N+1)//2 tmp %= MOD tmp *= 2 tmp %= MOD tmp2 = H*W%MOD*M%MOD tmp2 %= MOD tmp2 -= B*M%MOD tmp2 %= MOD tmp2 += M*(M+1)//2 tmp2 %= MOD tmp2 *= 2 tmp2 %= MOD tmp3 = (H*W-1)%MOD return (ans+tmp+tmp2+tmp3)%MOD import sys input = sys.stdin.readline MOD = 998244353 H, W = map(int, input().split()) if H%2==0 and W%2==0: ans = solve() elif H%2==0: ans = solve2() elif W%2==0: H, W = W, H ans = solve2() else: ans = solve3() print(ans)