/* -*- coding: utf-8 -*-
 *
 * 866.cc: yukicoder No.866 "Level-K Squares"
 *
 * Reads an H x W grid of lowercase letters and an integer K, and prints the
 * number of axis-aligned square sub-grids that contain exactly K distinct
 * letters.
 */

#include <algorithm>
#include <cstdio>

using namespace std;

/* constant */

const int MAX_H = 2000;
const int MAX_W = 2000;
// Node count per dimension: 2 * e2 - 1 where e2 = 2048 is the smallest
// power of two >= max(MAX_H, MAX_W); 4095 < 4096.
const int MAX_E2 = 1 << 12;  // = 4096

// Letter sets are 26-bit masks; popcount is done with a table over
// half of the bits (13), looked up twice.
const int BN = 26 / 2;
const int BBITS = 1 << BN;

/* typedef */

using ll = long long;

/*
 * Quadtree-style 2D segment tree of bitmasks combined with bitwise OR.
 *
 * Both dimensions use the same heap layout over [0, e2): node k has children
 * 2k+1 and 2k+2, and leaf i lives at index e2 - 1 + i. Because queries descend
 * both dimensions simultaneously, only pairs (k0, k1) at the same depth are
 * ever read; such a pair stores the OR over (row range of k0) x (column range
 * of k1).
 *
 * NOTE: nodes is a fixed MAX_E2 x MAX_E2 array (64 MiB for T = int), so the
 * object is meant to live in static storage, not on the stack.
 */
template <typename T>
struct SegTreeOr2D {
  int n, e2;
  T nodes[MAX_E2][MAX_E2];

  SegTreeOr2D() {}

  // Sizes the tree for an n x n grid: e2 becomes the smallest power of two
  // >= n. Caller must keep n <= MAX_E2 / 2 so that 2 * e2 - 1 fits in nodes.
  void init(int _n) {
    n = _n;
    for (e2 = 1; e2 < n; e2 <<= 1);
  }

  // Reference to the leaf for cell (i0, i1).
  T &get(int i0, int i1) { return nodes[e2 - 1 + i0][e2 - 1 + i1]; }

  // Stores v at leaf (i0, i1); call setall() afterwards to rebuild parents.
  void set(int i0, int i1, T v) { get(i0, i1) = v; }

  // Rebuilds every internal x internal node as the OR of its four children.
  // Iterating indices downward guarantees children (larger indices) are
  // finished before their parents. Mismatched-depth pairs are also filled
  // here but are never read by or_range.
  void setall() {
    for (int j0 = e2 - 2; j0 >= 0; j0--) {
      int k00 = j0 * 2 + 1, k01 = k00 + 1;
      for (int j1 = e2 - 2; j1 >= 0; j1--) {
        int k10 = j1 * 2 + 1, k11 = k10 + 1;
        nodes[j0][j1] =
            nodes[k00][k10] | nodes[k00][k11] | nodes[k01][k10] | nodes[k01][k11];
      }
    }
  }

  // Recursive helper: OR over the query rectangle [r00, r01) x [r10, r11)
  // intersected with node (k0, k1), which covers rows [i00, i01) and columns
  // [i10, i11). Leaves span a single cell, so they are always either disjoint
  // or fully covered and the recursion never goes below them.
  T or_range(int r00, int r01, int r10, int r11, int k0, int k1,
             int i00, int i01, int i10, int i11) {
    // Node lies entirely outside the query rectangle.
    if (r01 <= i00 || i01 <= r00 || r11 <= i10 || i11 <= r10) return 0;
    // Node lies entirely inside the query rectangle.
    if (r00 <= i00 && i01 <= r01 && r10 <= i10 && i11 <= r11)
      return nodes[k0][k1];

    int im0 = (i00 + i01) / 2, im1 = (i10 + i11) / 2;
    int k00 = k0 * 2 + 1, k01 = k00 + 1;
    int k10 = k1 * 2 + 1, k11 = k10 + 1;
    T v00 = or_range(r00, r01, r10, r11, k00, k10, i00, im0, i10, im1);
    T v01 = or_range(r00, r01, r10, r11, k00, k11, i00, im0, im1, i11);
    T v10 = or_range(r00, r01, r10, r11, k01, k10, im0, i01, i10, im1);
    T v11 = or_range(r00, r01, r10, r11, k01, k11, im0, i01, im1, i11);
    return v00 | v01 | v10 | v11;
  }

  // OR of all leaves in rows [r00, r01) x columns [r10, r11) (half-open).
  T or_range(int r00, int r01, int r10, int r11) {
    return or_range(r00, r01, r10, r11, 0, 0, 0, e2, 0, e2);
  }
};

/* global variables */

int bnums[BBITS];      // bnums[b] = number of set bits in the BN-bit value b
SegTreeOr2D<int> st;   // letter-set masks of the grid
char s[MAX_W + 4];     // one input row

/* subroutines */

// Number of distinct letters in a 26-bit letter mask.
inline int bitnum(int bits) {
  return bnums[bits & (BBITS - 1)] + bnums[bits >> BN];
}

// Smallest side length l in [1, maxl] such that the l x l square whose
// top-left cell is (i, j) contains at least k distinct letters, or maxl + 1
// if no such l exists. Binary search is valid because growing the square
// never removes letters, so the distinct count is non-decreasing in l.
inline int stlb(int i, int j, int maxl, int k) {
  int l0 = 0, l1 = maxl + 1;  // invariant: l0 fails (or is 0), l1 succeeds (or is maxl + 1)
  while (l0 + 1 < l1) {
    int l = (l0 + l1) / 2;
    if (bitnum(st.or_range(i, i + l, j, j + l)) >= k)
      l1 = l;
    else
      l0 = l;
  }
  return l1;
}

/* main */

int main() {
  // Popcount table: drop the highest set bit (msb) and add one.
  bnums[0] = 0;
  for (int bits = 1, msb = 1; bits < BBITS; bits++) {
    if ((msb << 1) <= bits) msb <<= 1;
    bnums[bits] = bnums[bits ^ msb] + 1;
  }

  int h, w, k;
  if (scanf("%d%d%d", &h, &w, &k) != 3) return 1;
  // The segment tree and the row buffer have fixed capacity.
  if (h > MAX_H || w > MAX_W) return 1;

  st.init(max(h, w));
  for (int i = 0; i < h; i++) {
    if (scanf("%s", s) != 1) return 1;
    for (int j = 0; j < w; j++) st.set(i, j, 1 << (s[j] - 'a'));
  }
  st.setall();

  // For each top-left cell, the side lengths with exactly k distinct letters
  // form the half-open interval [first l with >= k, first l with >= k + 1).
  ll sum = 0;
  for (int i = 0; i < h; i++)
    for (int j = 0; j < w; j++) {
      int maxl = min(h - i, w - j);
      int l0 = stlb(i, j, maxl, k);
      if (l0 <= maxl) {
        int l1 = stlb(i, j, maxl, k + 1);
        sum += l1 - l0;
      }
    }

  printf("%lld\n", sum);
  return 0;
}