mod = 998244353


def _descending_sum(n, k):
    """Return ``(n-1) + (n-2) + ... + (n-1-k)`` reduced modulo ``mod``.

    This is the closed form of ``sum(n - 1 - i for i in range(k + 1))``.
    For ``k == -1`` the sum is empty and the result is 0.
    """
    # k * (k + 1) is always even, so the floor division is exact.
    return ((n - 1) * (k + 1) - k * (k + 1) // 2) % mod


def calch(h, height=None):
    """Return ``sum(height - 1 - i for i in range(h + 1))`` modulo ``mod``.

    Args:
        h: Last index included in the sum (``-1`` yields 0).
        height: Grid height to use; defaults to the module-level ``H``
            (set by ``main``), which preserves the original one-argument call.
    """
    if height is None:
        height = H
    return _descending_sum(height, h)


def calcw(w, width=None):
    """Return ``sum(width - 1 - i for i in range(w + 1))`` modulo ``mod``.

    Args:
        w: Last index included in the sum (``-1`` yields 0).
        width: Grid width to use; defaults to the module-level ``W``
            (set by ``main``), which preserves the original one-argument call.
    """
    if width is None:
        width = W
    return _descending_sum(width, w)


def calc(h, w, height=None, width=None):
    """Return ``calch(h) * calcw(w)`` modulo ``mod``.

    ``height`` / ``width`` are forwarded to ``calch`` / ``calcw``; when omitted
    the module-level ``H`` / ``W`` are used, as in the original script.
    """
    return calch(h, height) * calcw(w, width) % mod


def solve(height, width):
    """Compute the script's answer for the given dimensions, modulo ``mod``.

    The total is the sum of three contributions, which the original author
    labelled "parallel to the H axis", "parallel to the W axis" and "neither".
    NOTE(review): the underlying problem statement is not part of this file,
    so these labels are carried over as-is -- confirm against the task text.
    """
    half_h = (height - 1) // 2 - 1
    half_w = (width - 1) // 2 - 1

    # Parallel to the H axis.
    total = width * calch(half_h, height) * 2

    # Parallel to the W axis.
    total += height * calcw(half_w, width) * 2

    # Neither: two products are added and their common term subtracted once.
    mixed = (
        calc(height - 1, half_w, height, width)
        + calc(half_h, width - 1, height, width)
        - calc(half_h, half_w, height, width)
    )
    total += mixed * 4

    # Python's % always yields a value in [0, mod) for a positive modulus,
    # so a single final reduction matches the original step-by-step one.
    return total % mod


def main():
    """Read ``H W`` from standard input and print the answer."""
    # Keep H and W as module globals so the one-argument forms of
    # calch / calcw / calc keep working exactly as before.
    global H, W
    H, W = map(int, input().split())
    print(solve(H, W))


if __name__ == "__main__":
    main()