"""Read two integers ``h w`` from stdin and print a closed-form value mod 998244353."""

MOD = 998244353


def f(n):
    """Return ``n + k * (k + 1)`` where ``k = n // 2``.

    Equivalently ``n + 2 * (1 + 2 + ... + k)``.  It is used as a
    per-dimension factor: ``solve`` multiplies ``f(h)`` by ``f(w)``.
    NOTE(review): the combinatorial meaning of this count is not stated in
    the source -- confirm against the original problem statement.
    """
    k = n // 2
    return k * (k + 1) + n


def solve(h, w, mod=MOD):
    """Return ``((h * w) ** 2 - f(h) * f(w)) % mod``.

    Args:
        h: first dimension read from input.
        w: second dimension read from input.
        mod: modulus applied to the result (defaults to 998244353).

    Returns:
        The value in ``[0, mod)``; Python's ``%`` is non-negative for a
        positive modulus.
    """
    cells = h * w
    # The original formula h*w*(h*w - 1) - f(h)*f(w) + h*w simplifies to
    # cells**2 - f(h)*f(w), since cells*(cells - 1) + cells == cells**2.
    # Python ints are arbitrary precision, so no intermediate overflow.
    return (cells * cells - f(h) * f(w)) % mod


def main():
    """Read ``h w`` from one line of stdin and print ``solve(h, w)``."""
    h, w = map(int, input().split())
    print(solve(h, w))


if __name__ == "__main__":
    main()