"""Sum of per-point weighted cell counts, reduced modulo a prime.

Input (stdin): ``h w k`` followed by ``k`` triples ``x y v`` (1-indexed
``x`` row, ``y`` column, integer weight ``v``).  Output: the sum over all
triples of ``count_cells(w, x - 1, y - 1) * v`` modulo ``MOD``.
"""
import sys

MOD = 998244353


def _triangle(n):
    """Return 1 + 2 + ... + n, or 0 when n <= 0."""
    return n * (n + 1) // 2 if n > 0 else 0


def count_cells(w, x, y):
    """Return the cell count for the 0-indexed point (x, y) in a width-``w`` grid.

    The value is ``(x + 1)**2`` minus the triangular overflow past the left
    edge (``_triangle(x - y)``) and past the right edge
    (``_triangle(x + y + 1 - w)``).  Equivalently, it counts the cells
    ``(r, c)`` with ``0 <= r <= x``, ``0 <= c < w`` and ``|c - y| <= x - r``:
    a widening wedge from row ``x`` up to row 0, clipped to the grid's
    columns.
    """
    total = (x + 1) ** 2  # sum of (2d + 1) for d = 0..x: the unclipped wedge
    total -= _triangle(x - y)  # row at distance d loses max(0, d - y) cells on the left
    total -= _triangle(x + y + 1 - w)  # ... and max(0, d - (w - 1 - y)) on the right
    return total


def solve(h, w, queries, mod=MOD):
    """Return the sum of ``count_cells(w, x - 1, y - 1) * v`` over ``queries``, mod ``mod``.

    Args:
        h: grid height.  NOTE(review): the original script also read ``h``
            without using it; confirm the wedge never needs clipping at the
            bottom edge.
        w: grid width (number of columns).
        queries: iterable of ``(x, y, v)`` triples, with ``x`` and ``y`` 1-indexed.
        mod: modulus for the result (defaults to ``MOD``).

    Returns:
        The weighted sum reduced into ``[0, mod)``.  Negative weights are
        handled by Python's non-negative ``%``.
    """
    ans = 0
    for x, y, v in queries:
        # Reduce on every step so the running total stays small.
        ans = (ans + count_cells(w, x - 1, y - 1) * v) % mod
    return ans


def main():
    """Read the problem input from stdin and print the answer."""
    data = sys.stdin.buffer.read().split()  # one bulk read instead of k input() calls
    h, w, k = int(data[0]), int(data[1]), int(data[2])
    vals = list(map(int, data[3:3 + 3 * k]))
    if len(vals) != 3 * k:
        raise ValueError(f"expected {3 * k} point values, got {len(vals)}")
    queries = zip(vals[0::3], vals[1::3], vals[2::3])
    print(solve(h, w, queries))


if __name__ == "__main__":
    main()