結果

問題 No.1054 Union add query
ユーザー tarattata1tarattata1
提出日時 2020-05-16 05:53:39
言語 C++14
(gcc 12.3.0 + boost 1.83.0)
結果
AC  
実行時間 255 ms / 2,000 ms
コード長 3,651 bytes
コンパイル時間 881 ms
コンパイル使用メモリ 89,736 KB
実行使用メモリ 20,560 KB
最終ジャッジ日時 2024-09-21 15:53:42
合計ジャッジ時間 3,730 ms
ジャッジサーバーID
(参考情報)
judge4 / judge1
このコードへのチャレンジ
(要ログイン)

テストケース

テストケース表示
入力 結果 実行時間
実行使用メモリ
testcase_00 AC 2 ms
6,812 KB
testcase_01 AC 2 ms
6,944 KB
testcase_02 AC 2 ms
6,944 KB
testcase_03 AC 184 ms
11,008 KB
testcase_04 AC 255 ms
20,536 KB
testcase_05 AC 161 ms
9,984 KB
testcase_06 AC 174 ms
13,568 KB
testcase_07 AC 149 ms
13,440 KB
testcase_08 AC 165 ms
13,440 KB
testcase_09 AC 241 ms
20,492 KB
testcase_10 AC 118 ms
20,560 KB
権限があれば一括ダウンロードができます

ソースコード

diff #

#include <cstdio>
#include <string>
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <vector>
#include <set>
#include <map>
#include <queue>
#include <stack>
#include <list>
#include <iterator>
#include <cassert>
#include <numeric>
#include <functional>
//#include <numeric>
#pragma warning(disable:4996) 
 
typedef long long ll;
typedef unsigned long long ull;
#define MIN(a, b) ((a)>(b)? (b): (a))
#define MAX(a, b) ((a)<(b)? (b): (a))
#define LINF  9223300000000000000
#define LINF2 1223300000000000000
#define INF 2140000000
//const long long MOD = 1000000007;
const long long MOD = 998244353;

using namespace std;


class UF
{
private:
    int         num;
    vector<int> par;
    vector<int> siz;
    vector<ll>  wt;   // diff from par
public:
    vector<ll> parval;   // val of root
public:
    UF(int n): num(n) {
        par.resize(n);
        siz.resize(n);
        wt.resize(n);
        parval.resize(n);
        int i;
        for(i=0; i<n; i++) {
            par[i]=i; siz[i]=1;
            wt[i]=0;
            parval[i]=0;
        }
    }
	
    int find(int x) {
        int p=par[x];
        if(x==p) {
            return x;
        }
        int q=find(p);
        ll diff=wt[q]-wt[p];
        wt[x]-=diff;
        return par[x]=q;
    }

    bool unite(int x, int y, ll diff) {     // val[y]=val[x]+diff
        int px=find(x);
        int py=find(y);
        ll diff0=wt[y]-wt[x];
        diff-=diff0;

        x=find(x);
        y=find(y);
        if(x==y) {
            return (wt[y]==wt[x]+diff);
        }
        if(siz[x]<siz[y]) {
            par[x]=y;
            siz[y]=siz[x]+siz[y];
            //val[y]=val[x]+val[y];
            wt[x]=-diff;
        }
        else {
            par[y]=x;
            siz[x]=siz[x]+siz[y];
            //val[x]=val[x]+val[y];
            wt[y]=diff;
        }
        return true;
    }

    bool same(int x, int y) {
        return find(x)==find(y);
    }

    int size(int x) {
        return siz[find(x)];
    }

    ll GetWt(int x) {
        int p=find(x);
        return wt[x];
    }

    int ngroup() {
    //int ngroup( int& ans ) {
        int count=0;
        int i;
        for(i=0; i<num; i++) {
            if(par[i]==i) {
                count++;
                //ans += (val[i]? siz[i]: siz[i]-1);
            }
        }
        return count;
    }

    void print() {
        int i;
        printf("par: "); for(i=0; i<num; i++) printf("%d ", par[i]); printf("\n");
        printf("wt:  "); for(i=0; i<num; i++) printf("%lld ", GetWt(i)); printf("\n");
    }
};


void solve()
{
    int n, Q;
    scanf("%d%d", &n, &Q);
    vector<int> t(Q), a(Q), b(Q);

    UF uf(n);
    int i;
    for (i = 0; i < Q; i++) {
        scanf("%d%d%d", &t[i], &a[i], &b[i]); a[i]--;
        if(t[i]==1) b[i]--;

        if (t[i] == 1) {
            if (uf.same(a[i], b[i])) {
                continue;
            }
            int pa = uf.find(a[i]);
            ll vala = uf.parval[pa] + uf.GetWt(a[i]);
            int pb = uf.find(b[i]);
            ll valb = uf.parval[pb] + uf.GetWt(b[i]);

            uf.unite(a[i], b[i], valb - vala);
        }
        else if (t[i] == 2) {
            int pa = uf.find(a[i]);
            uf.parval[pa] += b[i];
        }
        else {
            int pa = uf.find(a[i]);
            ll vala = uf.parval[pa] + uf.GetWt(a[i]);
            printf("%lld\n", vala);
        }
    }

    return;
}


int main(int argc, char* argv[])
{
#if 1
    solve();
#else
    int T;
    scanf("%d", &T);
    int t;
    for(t=0; t<T; t++) {
        //printf("Case #%d: ", t+1);
        solve();
    }
#endif
    return 0;
}
0