#include <stdio.h>

/* Modulus the answer is reported under. MOD^2 < 2^63, so a product of two
 * reduced residues always fits in a signed long long. */
static const long long MOD = 998244353LL;

/* (x * y) mod MOD for non-negative x, y of any size: each factor is reduced
 * first so the multiplication cannot overflow. */
static long long mul_mod(long long x, long long y)
{
    return (x % MOD) * (y % MOD) % MOD;
}

/* (x - y) mod MOD for residues x, y in [0, MOD); result is in [0, MOD). */
static long long sub_mod(long long x, long long y)
{
    return (x - y + MOD) % MOD;
}

/*
 * Closed-form value for an h x w input, reduced mod MOD.
 *
 * With a = h/2 and b = w/2 (integer division), the value is
 *     4*((ab)^2 - ab)
 *   + [w odd] * 2*(a^2 * w - a)
 *   + [h odd] * 2*(b^2 * h - b)
 *   + [h odd and w odd] * (h*w - 1)
 * NOTE(review): the problem statement is not visible here; these terms are
 * transcribed from the original arithmetic, not derived independently.
 *
 * Expects h >= 0 and w >= 0. Every factor is reduced mod MOD before it is
 * multiplied, so any non-negative long long input is safe. (The previous
 * version multiplied raw h/2, w and h*w, which is signed overflow, i.e.
 * undefined behavior, for inputs around 1e10.)
 */
static long long answer_mod(long long h, long long w)
{
    const long long a = (h / 2) % MOD;
    const long long b = (w / 2) % MOD;
    const long long ab = mul_mod(a, b);

    /* 4 * ((ab)^2 - ab) */
    long long ans = mul_mod(4, sub_mod(mul_mod(ab, ab), ab));

    if (w % 2 == 1) {
        /* 2 * (a^2 * w - a) */
        const long long t = sub_mod(mul_mod(mul_mod(a, a), w), a);
        ans = (ans + mul_mod(2, t)) % MOD;
    }
    if (h % 2 == 1) {
        /* 2 * (b^2 * h - b) */
        const long long t = sub_mod(mul_mod(mul_mod(b, b), h), b);
        ans = (ans + mul_mod(2, t)) % MOD;
    }
    if (h % 2 == 1 && w % 2 == 1) {
        /* h*w - 1 */
        ans = (ans + sub_mod(mul_mod(h, w), 1)) % MOD;
    }
    return ans;
}

/* Reads "H W" from stdin and prints answer_mod(H, W) followed by a newline. */
int main(void)
{
    long long h = 0;
    long long w = 0;

    /* Previously the scanf results were stored but never checked, so bad
     * input silently printed 0. Reject it explicitly instead. */
    if (scanf("%lld %lld", &h, &w) != 2 || h < 0 || w < 0) {
        fputs("expected two non-negative integers H W\n", stderr);
        return 1;
    }

    printf("%lld\n", answer_mod(h, w));
    return 0;
}