#include <bits/stdc++.h>
#include <atcoder/all>
using namespace std;
using ll = long long;
#define rep(i,n) for(int i=0;i<(int)(n);i++)
using mint = atcoder::modint998244353;

int main(){
	ll h,w;
	cin>>h>>w;
	ll ch=(h/2),cw=(w/2);
	mint ans=0;
	mint td=h*w-(ch+1)*(cw+1);
	ans+=(td*ch*cw)*4;
	ans+=(mint(ch+1)*(cw+1)*ch*cw)*4;
	ans-=((mint(ch+1)*(ch+1)+ch-1)*(mint(cw+1)*(cw+1)+cw-1));
	if(h&1){
		ans+=mint(cw)*h*w*2-cw*(cw+3);
	}
	if(w&1){
		ans+=mint(ch)*w*h*2-ch*(ch+3);
	}
	if((h&1)&&(w&1)){
		ans+=mint(h)*w-1;
	}
	cout<<ans.val()<<endl;
}