#!/usr/bin/env python3
"""Split the integers 1..n into k groups of equal size and equal sum."""

from typing import List, Optional


def solve(n: int, k: int) -> Optional[List[List[int]]]:
    """Partition 1..n into k groups of n // k numbers with identical sums.

    Args:
        n: Largest integer to distribute; must be non-negative and a
            multiple of ``k``.
        k: Number of groups; must be positive.

    Returns:
        A list of ``k`` lists, each holding ``n // k`` integers, that together
        use every integer in 1..n exactly once and all have the same sum; or
        ``None`` when this construction cannot produce such a partition
        (the total is not divisible by ``k``, or every group would hold a
        single number while ``n > 1``).

    Raises:
        ValueError: If ``k`` is not positive, ``n`` is negative, or ``n`` is
            not a multiple of ``k``.
    """
    # Explicit checks instead of ``assert``: asserts vanish under ``python -O``
    # and a non-positive k or negative n would otherwise loop forever below.
    if k <= 0:
        raise ValueError("k must be a positive integer")
    if n < 0:
        raise ValueError("n must be non-negative")
    if n % k != 0:
        raise ValueError("n must be a multiple of k")

    total = n * (n + 1) // 2
    if total % k != 0:
        return None
    if n == 1:
        return [[1]]
    if n // k == 1:
        # One distinct number per group can never give equal sums.
        return None

    target = total // k  # sum every finished group must reach
    groups: List[List[int]] = [[] for _ in range(k)]
    remaining = n  # largest number not yet handed out

    # Numbers are handed out from the largest downwards, k at a time ("rows").
    while remaining // k:
        if remaining // k == 3:
            # Exactly three rows left: first row in order, second row rotated
            # by k // 2, third row is whatever each group still needs.
            for i in range(k):
                groups[i].append(remaining - i)
            remaining -= k
            shift = k // 2
            for i in range(k):
                groups[(shift + i) % k].append(remaining - i)
            remaining -= k
            for group in groups:
                group.append(target - sum(group))
            remaining -= k
        else:
            # Two rows in opposite directions: every group gains the same
            # amount, so the partial sums stay balanced.
            for i in range(k):
                groups[i].append(remaining - i)
            remaining -= k
            for i in range(k):
                groups[k - 1 - i].append(remaining - i)
            remaining -= k

    # Internal sanity check of the construction, not input validation.
    assert len({sum(group) for group in groups}) == 1
    return groups


def main() -> None:
    """Read ``n k`` from stdin and print ``Yes`` plus the groups, or ``No``."""
    n, k = map(int, input().split())
    answer = solve(n, k)
    if answer is not None:
        print('Yes')
        for group in answer:
            print(*group)
    else:
        print('No')


if __name__ == "__main__":
    main()