"""Read integers h and w; print ((h*w)**2 - f(h)*f(w)) mod 998244353."""

MOD = 998244353


def f(n):
    """Return (n // 2) * (n // 2 + 1) + n.

    NOTE(review): the combinatorial meaning of this per-dimension quantity
    is not stated in the source; presumably it counts pairs along one axis
    that are excluded from the total -- confirm against the problem statement.
    """
    half = n // 2
    return half * (half + 1) + n


def solve(h, w, mod=MOD):
    """Return ((h*w)**2 - f(h)*f(w)) reduced modulo *mod*.

    Python ints are arbitrary precision, so the difference is computed
    exactly and reduced once at the end; ``%`` with a positive modulus
    always yields a value in [0, mod).
    """
    cells = h * w
    return (cells * cells - f(h) * f(w)) % mod


def main():
    """Read "h w" from one line of stdin and print the answer."""
    h, w = map(int, input().split())
    print(solve(h, w))


# Guard the I/O so the pure functions above can be imported (e.g. by tests)
# without blocking on stdin; running the file as a script is unchanged.
if __name__ == "__main__":
    main()