题目大意
题目描述
给定两个以二进制字符串形式表示的非负整数 s 和 t。
请计算有多少组整数对 (a,b) 满足:
-
取值范围:a 和 b 均在区间 [s,t] 内。
-
满足等式:a×b=(a or b)×(a and b)
其中 or 表示按位或运算,and 表示按位与运算。
最终答案需要对 998244353 取模。
输入格式
第一行包含一个整数 Q(1≤Q≤20),表示测试用例的数量。
接下来 Q 行,每行包含两个 01 字符串 S 和 T,分别表示非负整数 s 和 t 的二进制形式。
数据保证 s 和 t 的二进制形式无前导零,且满足 s≤t,1≤∣S∣,∣T∣≤5×105。
输出格式
对于每个测试用例,输出一行一个非负整数,表示满足条件的解的数量对 998244353 取模后的值。
样例数据
输入
1 2 3 4 5 6 7
| 6 11 1000 1000 1001 0 100 11 111 10 111 0 11111
|
输出
样例解释
原题未提供具体的样例解释内容,此处按要求对样例数据的完整输入输出进行展示。
思路讲解
这道题目的数的规律我们很容易从这个打表看出来。
我们下面只打印了有序对:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
| [-] FAILURE: RUNTIME ERROR i = [0] j = [1] bitset<10>(i) = [0000000000] bitset<10>(j) = [0000000001]
----------- i = [0] j = [2] bitset<10>(i) = [0000000000] bitset<10>(j) = [0000000010]
----------- i = [0] j = [3] bitset<10>(i) = [0000000000] bitset<10>(j) = [0000000011]
----------- i = [0] j = [4] bitset<10>(i) = [0000000000] bitset<10>(j) = [0000000100]
----------- i = [1] j = [3] bitset<10>(i) = [0000000001] bitset<10>(j) = [0000000011]
----------- i = [2] j = [3] bitset<10>(i) = [0000000010] bitset<10>(j) = [0000000011]
----------- cnt = [6]
|
不难看出,我们所要求的是:

当然,如果你要推出这个东西的话,你需要知道一个定理:

对于任意两个非负整数 a 和 b,我们设 x=a or b,y=a and b。
在位运算中,有一个非常经典的定律,即它们的和永远保持不变:
x+y=a+b
为什么? 我们可以单独看某一个二进制位:
-
如果 a 和 b 这一位都是 0,那么 x 是 0,y 也是 0。(0+0=0+0)
-
如果 a 和 b 这一位都是 1,那么 x 是 1,y 也是 1。(1+1=1+1)
-
如果 a 和 b 这一位是一个 1 一个 0,那么 x (OR) 会拿到这个 1,y (AND) 是 0。(1+0=1+0)
可以看出,无论哪种情况,进位和本质的值都不会丢失,所以 a or b 加上 a and b 必定等于 a+b。
接着,我们把这两个式子联立,可以得到:

那么不难得到:

不过更难的应该是更进一步的这个求解。
HDU - 2089- 不要62
这道题目和普通的数位 dp 最大的不同的就是,你一定要同时考虑这个低位和高位。

具体而言,为什么相减会出现问题呢?

那么我们上面讲了不能这么做,然后下面我们来讲讲应该怎么做。

然后下面是一种比较 shi 的实现方式,因为就是分类讨论。
if else 的这个实现方式需要人工判断转移,非常麻烦,而且容易写错
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
| void make_s_t_same_len(string &s, const string &t) { ll lenS = SZ(s), lenT = SZ(t); s = string(lenT - lenS, '0') + s; }
auto execute_dp(const string &l, const string &r) { vector<vector<ll> > dp(2, vector<ll>(2)); for (int i = 0; i < SZ(l); ++i) { ll r_bit = r[i] - '0'; ll l_bit = l[i] - '0'; if (i == 0) { if (r_bit && !l_bit) { dp[0][1] = 1; dp[0][0] = 1; dp[1][0] = 1; } else if (l_bit && r_bit) { dp[0][0] = 1; } else { dp[0][0] = 1; } continue; } vector<vector<ll> > ndp(2, vector<ll>(2)); if (r_bit == 0 && l_bit == 0) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] ; ndp[1][0] = dp[1][0] * 2; ndp[1][1] = dp[1][1] * 3 + dp[1][0]; } else if (r_bit == 0 && l_bit == 1) { ndp[0][0] = 0; ndp[0][1] = dp[0][1]; ndp[1][0] = dp[1][0]; ndp[1][1] = dp[1][1] * 3; } else if (r_bit == 1 && l_bit == 1) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] * 2; ndp[1][0] = dp[1][0] ; ndp[1][1] = dp[1][1] * 3 + dp[0][1]; } else if (r_bit == 1 && l_bit == 0) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] * 2 + dp[0][0]; ndp[1][0] = dp[1][0] * 2 + dp[0][0]; ndp[1][1] = dp[1][1] * 3 + dp[0][1] + dp[1][0]; } ndp[0][0] %= mod; ndp[1][1] %= mod; ndp[1][0] %= mod; ndp[0][1] %= mod; swap(dp, ndp); } return dp; }
ll get_val_from_str(const string &s) { ll res = 0; ll pow2 = 1; for (int i = SZ(s) - 1; i >= 0; --i) { ll val = s[i] - '0'; res += pow2 * val; res %= mod; pow2 *= 2; pow2 %= mod; } return res; }
ll gen_ans(const auto &dp, ll l, ll r) { ll ans = (dp[0][0] + dp[1][1] + dp[0][1] + dp[1][0]) % mod; ans = 2 * ans - (r - l + 1); ans %= mod; if (ans < 0) { ans += mod; } return ans; }
void Solve() { string s, t; cin >> s >> t; make_s_t_same_len(s, t); auto dp = execute_dp(s, t); cout << gen_ans(dp, get_val_from_str(s), get_val_from_str(t)) << "\n"; }
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| constexpr int moves[][2] = {{0, 0}, {1, 0}, {1, 1}}; auto execute_dp(const string &l, const string &r) { array<array<ll, 2>, 2> dp = {}; dp[0][0] = 1; for (int i = 0; i < (int) l.size(); ++i) { int r_bit = r[i] - '0', l_bit = l[i] - '0'; array<array<ll, 2>, 2> ndp = {}; for (auto &[a_r,b_l]: moves) { for (int status_r = 0; status_r <= 1; ++status_r) { for (int status_l = 0; status_l <= 1; ++status_l) { if (status_r == 0 && a_r > r_bit) { continue; } if (status_l == 0 && b_l < l_bit) { continue; } ll to_r = status_r, to_l = status_l; if (a_r != r_bit) { to_r = 1; } if (b_l != l_bit) { to_l = 1; } ndp[to_r][to_l] += dp[status_r][status_l]; ndp[to_r][to_l] %= mod; } } } swap(dp, ndp); } return dp; }
|
在知道 dp 数组以后,答案像下面这样子求:
1 2 3 4 5 6 7 8 9
| ll gen_ans(const auto &dp, ll l, ll r) { ll ans = (dp[0][0] + dp[1][1] + dp[0][1] + dp[1][0]) % mod; ans = 2 * ans - (r - l + 1); ans %= mod; if (ans < 0) { ans += mod; } return ans; }
|
AC代码
AC
https://acm.hdu.edu.cn/contest/view-code?cid=1198&rid=17914
源代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
#include <bits/stdc++.h> #define all(vec) vec.begin(),vec.end() #define lson(o) (o<<1) #define rson(o) (o<<1|1) #define SZ(a) ((long long) a.size()) #define debug(var) cerr << #var <<" = ["<<var<<"]"<<"\n"; #define debug1d(a) \ cerr << #a << " = ["; \ for (int i = 0; i < (int)(a).size(); i++) \ cerr << (i ? ", " : "") << a[i]; \ cerr << "]\n"; #define debug2d(a) \ cerr << #a << " = [\n"; \ for (int i = 0; i < (int)(a).size(); i++) \ { \ cerr << " ["; \ for (int j = 0; j < (int)(a[i]).size(); j++) \ cerr << (j ? ", " : "") << a[i][j]; \ cerr << "]\n"; \ } \ cerr << "]\n"; #define cend cerr<<"\n-----------\n" #define fsp(x) fixed<<setprecision(x)
using namespace std;
using ll = long long; using ull = unsigned long long; using DB = double; using i128 = __int128; using CD = complex<double>;
static constexpr ll MAXN = (ll) 1e6 + 10, INF = (1ll << 61) - 1; static constexpr ll mod = 998244353; static constexpr double eps = 1e-8; const double PI = acos(-1.0);
ll lT, testcase;
void make_s_t_same_len(string &s, const string &t) { ll lenS = SZ(s), lenT = SZ(t); s = string(lenT - lenS, '0') + s; }
auto execute_dp(const string &l, const string &r) { vector<vector<ll> > dp(2, vector<ll>(2)); for (int i = 0; i < SZ(l); ++i) { ll r_bit = r[i] - '0'; ll l_bit = l[i] - '0'; if (i == 0) { if (r_bit && !l_bit) { dp[0][1] = 1; dp[0][0] = 1; dp[1][0] = 1; } else if (l_bit && r_bit) { dp[0][0] = 1; } else { dp[0][0] = 1; } continue; } vector<vector<ll> > ndp(2, vector<ll>(2)); if (r_bit == 0 && l_bit == 0) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] ; ndp[1][0] = dp[1][0] * 2; ndp[1][1] = dp[1][1] * 3 + dp[1][0]; } else if (r_bit == 0 && l_bit == 1) { ndp[0][0] = 0; ndp[0][1] = dp[0][1]; ndp[1][0] = dp[1][0]; ndp[1][1] = dp[1][1] * 3; } else if (r_bit == 1 && l_bit == 1) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] * 2; ndp[1][0] = dp[1][0] ; ndp[1][1] = dp[1][1] * 3 + dp[0][1]; } else if (r_bit == 1 && l_bit == 0) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] * 2 + dp[0][0]; ndp[1][0] = dp[1][0] * 2 + dp[0][0]; ndp[1][1] = dp[1][1] * 3 + dp[0][1] + dp[1][0]; } ndp[0][0] %= mod; ndp[1][1] %= mod; ndp[1][0] %= mod; ndp[0][1] %= mod; swap(dp, ndp); } return dp; }
ll get_val_from_str(const string &s) { ll res = 0; ll pow2 = 1; for (int i = SZ(s) - 1; i >= 0; --i) { ll val = s[i] - '0'; res += pow2 * val; res %= mod; pow2 *= 2; pow2 %= mod; } return res; }
ll gen_ans(const auto &dp, ll l, ll r) { ll ans = (dp[0][0] + dp[1][1] + dp[0][1] + dp[1][0]) % mod; #ifdef LOCAL debug(l); debug(r); debug2d(dp); debug(ans); #endif ans = 2 * ans - (r - l + 1); ans %= mod; if (ans < 0) { ans += mod; } return ans; }
void Solve() { string s, t; cin >> s >> t; make_s_t_same_len(s, t); auto dp = execute_dp(s, t); cout << gen_ans(dp, get_val_from_str(s), get_val_from_str(t)) << "\n"; }
signed main() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); #ifdef LOCAL cout.setf(ios::unitbuf); #endif
cin >> lT; for (testcase = 1; testcase <= lT; ++testcase) Solve(); return 0; }
|
AC
https://acm.hdu.edu.cn/contest/view-code?cid=1198&rid=17962
采用了更好实现方式的代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
|
#include <bits/stdc++.h> #define all(vec) vec.begin(),vec.end() #define lson(o) (o<<1) #define rson(o) (o<<1|1) #define SZ(a) ((long long) a.size()) #define debug(var) cerr << #var <<" = ["<<var<<"]"<<"\n"; #define debug1d(a) \ cerr << #a << " = ["; \ for (int i = 0; i < (int)(a).size(); i++) \ cerr << (i ? ", " : "") << a[i]; \ cerr << "]\n"; #define debug2d(a) \ cerr << #a << " = [\n"; \ for (int i = 0; i < (int)(a).size(); i++) \ { \ cerr << " ["; \ for (int j = 0; j < (int)(a[i]).size(); j++) \ cerr << (j ? ", " : "") << a[i][j]; \ cerr << "]\n"; \ } \ cerr << "]\n"; #define cend cerr<<"\n-----------\n" #define fsp(x) fixed<<setprecision(x)
using namespace std;
using ll = long long; using ull = unsigned long long; using DB = double; using i128 = __int128; using CD = complex<double>;
static constexpr ll MAXN = (ll) 1e6 + 10, INF = (1ll << 61) - 1; static constexpr ll mod = 998244353; static constexpr double eps = 1e-8; const double PI = acos(-1.0);
ll lT, testcase;
void make_s_t_same_len(string &s, const string &t) { ll lenS = SZ(s), lenT = SZ(t); s = string(lenT - lenS, '0') + s; }
auto execute_dp(const string &l, const string &r) { vector<vector<ll> > dp(2, vector<ll>(2)); for (int i = 0; i < SZ(l); ++i) { ll r_bit = r[i] - '0'; ll l_bit = l[i] - '0'; if (i == 0) { if (r_bit && !l_bit) { dp[0][1] = 1; dp[0][0] = 1; dp[1][0] = 1; } else if (l_bit && r_bit) { dp[0][0] = 1; } else { dp[0][0] = 1; } continue; } vector<vector<ll> > ndp(2, vector<ll>(2)); if (r_bit == 0 && l_bit == 0) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] ; ndp[1][0] = dp[1][0] * 2; ndp[1][1] = dp[1][1] * 3 + dp[1][0]; } else if (r_bit == 0 && l_bit == 1) { ndp[0][0] = 0; ndp[0][1] = dp[0][1]; ndp[1][0] = dp[1][0]; ndp[1][1] = dp[1][1] * 3; } else if (r_bit == 1 && l_bit == 1) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] * 2; ndp[1][0] = dp[1][0] ; ndp[1][1] = dp[1][1] * 3 + dp[0][1]; } else if (r_bit == 1 && l_bit == 0) { ndp[0][0] = dp[0][0]; ndp[0][1] = dp[0][1] * 2 + dp[0][0]; ndp[1][0] = dp[1][0] * 2 + dp[0][0]; ndp[1][1] = dp[1][1] * 3 + dp[0][1] + dp[1][0]; } ndp[0][0] %= mod; ndp[1][1] %= mod; ndp[1][0] %= mod; ndp[0][1] %= mod; swap(dp, ndp); } return dp; }
ll get_val_from_str(const string &s) { ll res = 0; ll pow2 = 1; for (int i = SZ(s) - 1; i >= 0; --i) { ll val = s[i] - '0'; res += pow2 * val; res %= mod; pow2 *= 2; pow2 %= mod; } return res; }
ll gen_ans(const auto &dp, ll l, ll r) { ll ans = (dp[0][0] + dp[1][1] + dp[0][1] + dp[1][0]) % mod; #ifdef LOCAL debug(l); debug(r); debug2d(dp); debug(ans); #endif ans = 2 * ans - (r - l + 1); ans %= mod; if (ans < 0) { ans += mod; } return ans; }
void Solve() { string s, t; cin >> s >> t; make_s_t_same_len(s, t); auto dp = execute_dp(s, t); cout << gen_ans(dp, get_val_from_str(s), get_val_from_str(t)) << "\n"; }
signed main() { ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr); #ifdef LOCAL cout.setf(ios::unitbuf); #endif
cin >> lT; for (testcase = 1; testcase <= lT; ++testcase) Solve(); return 0; }
|
心路历程(WA,TLE,MLE……)