import sys import math sys.setrecursionlimit(10 ** 7) def input() : return sys.stdin.readline().strip() def INT() : return int(input()) def MAP() : return map(int,input().split()) def LIST() : return list(MAP()) def NIJIGEN(H): return [list(input()) for i in range(H)] N,M=map(int,input().split()) if N%M==0: a=N//M else: a=(N//M)+1 n=pow(2,a,998244353)-1 if n==-1: print(n+998244353) else: print(n)