CF368B Sereja and Suffixes

看到好像没人发树状数组的题解,于是就发一篇

题目大意

给出一个长度为$n$的序列$a$,给出$m$个查询$l$,对于每个查询输出$[l,n]$的区间内不同数的个数。

分析:

将查询按照$l$的大小排序,从大到小的遍历,每次将$>=$当前$l$的位置的$a[i]$全部加入树状数组,让树状数组记录每个数是否出现过,每次的答案就是查询树状数组的总和。

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include <cstdio>
#include <algorithm>
const int MAX = 100010;
#define lowbit(x) x & -x
using namespace std;
typedef pair<int, int> pa;
int n, m, ans[MAX], a[MAX], c[MAX];
pa l[MAX];
const int lim = 1e5 + 1;
//template
template<class type>type _max(type _, type __) { return _ > __ ? _ : __; }
template<class type>type _min(type _, type __) { return _ < __ ? _ : __; }
template<class type>type _abs(type __) { return __ < 0 ? -__ : __; }
template<class type>type _gcd(type __, type ___) { return (!___) ? __ : gcd(___, __ % ___); }
template<class type>type _mod(type __, type ____) { if (__ >= 0 && __ < ____)return __; __ %= ____; if (__ < 0)__ += ____; return __; }
template<class type>type _qpow(type __, type ___, type ____) { type ans = 1; for (; ___; ___ >>= 1, __ = _mod(__ * __, ____))if (___ & 1)ans = _mod(ans * __, ____); return ans; }
//char buf[1 << 21], *p1 = buf, *p2 = buf;
//inline int getc() {
// return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++;
//}
inline int r() {
register int x = 0; register char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar());
for (; ch >= '0' && ch <= '9'; ch = getchar()) x = (x << 1) + (x << 3) + (ch ^ 48);
return x;
}
inline double r_db() {
double s = 0.0; int f = 1; char c = getchar();
while (c < '0' || c > '9') { if (c == '-') f = -1; c = getchar(); }
while (c >= '0' && c <= '9') { s = s * 10 + (c ^ 48); c = getchar(); }
if (c == '.') { int k = 10; c = getchar(); while (c >= '0' && c <= '9') { s += (double)(c ^ 48) / k; k *= 10; c = getchar(); } }
return s;
}
//以上的template可以忽略
void add(int x) {
while (x <= lim) {
c[x]++;
x += lowbit(x);
}
}
int sum(int x) {
int ret = 0;
while (x) {
ret += c[x];
x -= lowbit(x);
}
return ret;
}
int main() {
while (~scanf("%d%d", &n, &m)) {
for (int i = 0; i < n; i++)
a[i] = r();
for (int i = 0; i < m; i++) {
l[i].first = r();
l[i].first--;
l[i].second = i;
}
sort(l, l + m);
int j = n - 1;
for (int i = m - 1; i >= 0; i--) {
while (j >= l[i].first) {
if (sum(a[j]) - sum(a[j] - 1) == 0)
add(a[j]);
j--;
}
ans[l[i].second] = sum(1e5);
}
for (int i = 0; i < m; i++)
printf("%d\n", ans[i]);
}
}

希望你们不要$ctrl+c$ $ctrl+v$