#include <stdio.h>

/* Modulus applied to the printed result. */
static const long long MOD_NUM = 998244353LL;

/*
 * Return x reduced into [0, MOD_NUM).
 * C's % keeps the sign of the dividend, so a negative x needs one correction step.
 */
static long long mod_reduce(long long x)
{
    long long r = x % MOD_NUM;
    return r < 0 ? r + MOD_NUM : r;
}

/*
 * Return 4 * (a^2 * b^2 - a * b) mod MOD_NUM, where a = h / 2 and b = w / 2
 * (C truncating division), as a value in [0, MOD_NUM).
 *
 * Both factors are reduced before any multiplication, so every product is
 * below MOD_NUM^2 (about 1e18 < LLONG_MAX). That prevents the signed overflow
 * the unreduced (h/2)*(h/2) form could hit for large h or w.
 */
static long long compute_answer(long long h, long long w)
{
    const long long a = mod_reduce(h / 2LL);
    const long long b = mod_reduce(w / 2LL);
    const long long a_sq = a * a % MOD_NUM;
    const long long b_sq = b * b % MOD_NUM;
    const long long ab = a * b % MOD_NUM;

    /* Add MOD_NUM before subtracting so the difference never goes negative. */
    const long long diff = (a_sq * b_sq % MOD_NUM + MOD_NUM - ab) % MOD_NUM;
    return diff * 4LL % MOD_NUM;
}

/*
 * Read two integers h and w from stdin and print compute_answer(h, w).
 * Returns 1 (after a message on stderr) if the two integers cannot be read.
 */
int main(void)
{
    long long h = 0LL;
    long long w = 0LL;

    if (scanf("%lld %lld", &h, &w) != 2) {
        fprintf(stderr, "expected two integers h and w\n");
        return 1;
    }

    /*
     * NOTE(review): the previous version had an empty placeholder branch for
     * odd h that did nothing. The formula may still need an adjustment for odd
     * h (and possibly odd w) — TODO confirm against the problem statement.
     */
    printf("%lld\n", compute_answer(h, w));
    return 0;
}