#include <stdio.h>

/* Prime modulus in which the final answer is reported. */
static const long long MOD = 998244353LL;

/* Reduce x into [0, MOD). C's % keeps the dividend's sign, so fix up negatives. */
static long long reduce_mod(long long x)
{
    long long r = x % MOD;
    return r < 0 ? r + MOD : r;
}

/* Return x*y mod MOD.
 * Both factors are reduced first, so the product is < MOD^2 < 2^63 and
 * cannot overflow, whatever the magnitude of x and y. */
static long long mul_mod(long long x, long long y)
{
    return (reduce_mod(x) * reduce_mod(y)) % MOD;
}

/* Return (acc - x) mod MOD, in [0, MOD), for acc already in [0, MOD). */
static long long sub_mod(long long acc, long long x)
{
    return (acc + MOD - reduce_mod(x)) % MOD;
}

/* Return a*(a+1) mod MOD, where a = n/2 (integer division on the original n).
 * The halving must happen before any modular reduction of n. */
static long long half_pronic_mod(long long n)
{
    long long a = n / 2;
    return mul_mod(a + 1, a);
}

/* Compute, mod MOD, with a = h/2 and b = w/2:
 *   (h*w)^2 - h*w - w*a(a+1) - h*b(b+1) - a(a+1)*b(b+1)
 * This is the same sequence of terms the original program subtracted. */
static long long solve(long long h, long long w)
{
    long long hw = mul_mod(h, w);
    long long ph = half_pronic_mod(h);
    long long pw = half_pronic_mod(w);
    long long ans = mul_mod(hw, hw);

    ans = sub_mod(ans, hw);
    ans = sub_mod(ans, mul_mod(ph, w));
    ans = sub_mod(ans, mul_mod(pw, h));
    ans = sub_mod(ans, mul_mod(ph, pw));
    return ans;
}

/* Read two integers h and w from stdin and print solve(h, w).
 * Returns 1 (with a message on stderr) if the input cannot be parsed. */
int main(void)
{
    long long h = 0;
    long long w = 0;

    if (scanf("%lld %lld", &h, &w) != 2) {
        fprintf(stderr, "expected two integers: h w\n");
        return 1;
    }
    printf("%lld\n", solve(h, w));
    return 0;
}