Sam has a list initially containing a single element $n$. He repeatedly performs operations where he removes any element $x > 1$ and replaces it with the elements $\lfloor x/2 \rfloor$, $x \bmod 2$, $\lfloor x/2 \rfloor$, in that order. The operations continue until all elements in the list are either 0 or 1.
We need to find the total number of 1s in the range from position $l$ to position $r$ (1-indexed) in the final list.
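To make the process concrete, here is a naive reference sketch. The names simulate and count_ones_naive are ours, not part of the original solution. It starts from the single element n, expands every element greater than 1 into $\lfloor x/2 \rfloor$, $x \bmod 2$, $\lfloor x/2 \rfloor$ until only 0s and 1s remain, and then counts the 1s at positions l through r directly. It materializes the whole final list, which has fewer than $2n$ elements for $n \ge 1$, so it is only practical for small n, but it is handy for checking a faster solution.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Naive reference (small n only): rewrites the list, replacing each element
// x > 1 by x / 2, x % 2, x / 2, until only 0s and 1s remain.
// Expanding all elements in one pass gives the same final list as expanding
// them one at a time, because an element's expansion depends only on its value.
std::vector<long long> simulate(long long n) {
    std::vector<long long> cur{n};
    bool changed = true;
    while (changed) {
        changed = false;
        std::vector<long long> next;
        for (long long x : cur) {
            if (x > 1) {
                next.push_back(x / 2);
                next.push_back(x % 2);
                next.push_back(x / 2);
                changed = true;
            } else {
                next.push_back(x);  // 0 and 1 are final
            }
        }
        cur = std::move(next);
    }
    return cur;
}

// Counts the 1s at 1-indexed positions l..r (assumes 1 <= l <= r <= list length).
long long count_ones_naive(long long n, long long l, long long r) {
    std::vector<long long> fin = simulate(n);
    long long cnt = 0;
    for (long long i = l; i <= r; ++i) cnt += fin[static_cast<std::size_t>(i - 1)];
    return cnt;
}
```

For example, count_ones_naive(7, 2, 5) returns 4 (the final list for 7 is seven 1s), and count_ones_naive(10, 3, 10) returns 5.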
Idea
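We can define the final list $f(x)$ for a value $x$ recursively:
- $f(x) = [x]$ for $x \le 1$
- $f(x) = f(\lfloor x/2 \rfloor) \mathbin{\Vert} [x \bmod 2] \mathbin{\Vert} f(\lfloor x/2 \rfloor)$ for $x > 1$, where $\Vert$ represents concatenation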
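As a worked example, writing $f(x)$ for the final list of $x$ and $\Vert$ for concatenation, the recursion for $n = 10$ unrolls as follows (each line splits the value into $\lfloor x/2 \rfloor$ and $x \bmod 2$):

$$
\begin{aligned}
f(2) &= f(1) \mathbin{\Vert} [0] \mathbin{\Vert} f(1) = [1, 0, 1] \\
f(5) &= f(2) \mathbin{\Vert} [1] \mathbin{\Vert} f(2) = [1, 0, 1, 1, 1, 0, 1] \\
f(10) &= f(5) \mathbin{\Vert} [0] \mathbin{\Vert} f(5) = [1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1]
\end{aligned}
$$

So the final list for 10 has 15 elements, ten of which are 1.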
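We precompute two properties for each value $x$ that appears when repeatedly halving $n$ (that is, $n, \lfloor n/2 \rfloor, \lfloor n/4 \rfloor, \dots$):
- $L(x)$ = the length of the final list for $x$
- $C(x)$ = the number of 1s in the final list for $x$
From the definition of $f$, these satisfy:
- $L(x) = 1$ for $x \le 1$
- $C(x) = x$ for $x \le 1$
- $L(x) = 2L(\lfloor x/2 \rfloor) + 1$, for $x > 1$
- $C(x) = 2C(\lfloor x/2 \rfloor) + (x \bmod 2)$, for $x > 1$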
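The recurrences can be evaluated without a map by walking the halving chain once. In this sketch, $L(x)$ is the final-list length and $C(x)$ its count of 1s; the function name length_and_ones and the vector-based approach are our own illustration, not the original code. It collects $n, \lfloor n/2 \rfloor, \dots$ down to the first value $\le 1$, starts from the base case, and applies the recurrences back up the chain.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical helper: returns {L(n), C(n)} using
//   L(x) = 1,            C(x) = x                 for x <= 1
//   L(x) = 2 L(x/2) + 1, C(x) = 2 C(x/2) + x % 2  for x > 1
std::pair<long long, long long> length_and_ones(long long n) {
    std::vector<long long> chain;  // n, n/2, n/4, ..., ending at a value <= 1
    for (long long x = n;; x /= 2) {
        chain.push_back(x);
        if (x <= 1) break;
    }
    long long len = 1, cnt = chain.back();  // base case: final list is [x]
    // Fold the recurrences from the second-to-last chain value back up to n.
    for (std::size_t i = chain.size() - 1; i-- > 0;) {
        len = 2 * len + 1;
        cnt = 2 * cnt + chain[i] % 2;
    }
    return {len, cnt};
}
```

For n = 10 this returns {15, 10}. Unrolling the recurrences also shows that $C(x) = x$ (the recurrence rebuilds $x$ from its binary digits) and that $L(x) = 2^k - 1$ for $x \ge 1$, where $k$ is the bit length of $x$; both make quick sanity checks for the precomputed values.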
Range Query
To count the 1s in a range $[l, r]$, we use a helper that counts the 1s among the first idx positions of the final list (the prefix-sum idea). If $P(\text{idx})$ denotes that prefix count, with $P(0) = 0$, the answer is $P(r) - P(l - 1)$.
Our helper function, num_ones_until(n, idx), computes $P(\text{idx})$ recursively. Write $h = \lfloor n/2 \rfloor$; the final list for $n$ is the list for $h$, then the single element $n \bmod 2$, then the list for $h$ again.
- If idx equals $L(n)$, return $C(n)$ (the whole list).
- If idx equals $L(h) + 1$, return $C(h) + (n \bmod 2)$ (the left half plus the middle element).
- If idx $\le L(h)$, the prefix lies inside the left half: recurse with $h$ and the same idx.
- Otherwise, return $C(h) + (n \bmod 2)$ plus the number of 1s in the first idx $- L(h) - 1$ positions of the right half, i.e. num_ones_until($h$, idx $- L(h) - 1$).
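The same descent can also be written without recursion. The sketch below is a hypothetical iterative variant of num_ones_until; the name prefix_ones and the vector-based storage are our choices, not the original code's. It stores the length and count of 1s for the halving chain $n, \lfloor n/2 \rfloor, \dots$ in vectors instead of a std::map, then moves down one level per step. It also accepts idx = 0 and returns 0, so a range $[l, r]$ is simply prefix_ones(n, r) - prefix_ones(n, l - 1) with no special case for l = 1.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical iterative variant of num_ones_until: number of 1s among the
// first idx positions (0 <= idx <= length) of the final list for n.
long long prefix_ones(long long n, long long idx) {
    // chain[k] = n halved k times, down to the first value <= 1;
    // lens[k] / cnts[k] = final-list length / count of 1s for chain[k].
    std::vector<long long> chain;
    for (long long x = n;; x /= 2) {
        chain.push_back(x);
        if (x <= 1) break;
    }
    const std::size_t last = chain.size() - 1;
    std::vector<long long> lens(chain.size(), 1), cnts(chain.size(), 0);
    cnts[last] = chain[last];  // base case: final list is [x] for x <= 1
    for (std::size_t i = last; i-- > 0;) {
        lens[i] = 2 * lens[i + 1] + 1;
        cnts[i] = 2 * cnts[i + 1] + chain[i] % 2;
    }

    long long total = 0;
    std::size_t k = 0;
    // Invariant: the remaining prefix (length idx) lies in the list for chain[k].
    while (idx > 0) {
        if (idx == lens[k]) {  // the whole list for chain[k]
            total += cnts[k];
            break;
        }
        // idx < lens[k], so chain[k] > 1 and k + 1 is a valid index.
        const long long half_len = lens[k + 1];
        if (idx > half_len) {  // left half and middle element fully covered
            total += cnts[k + 1] + chain[k] % 2;
            idx -= half_len + 1;  // continue with a prefix of the right half
        }
        ++k;  // both halves are copies of the list for chain[k + 1]
    }
    return total;
}
```

For n = 10, prefix_ones(10, 10) is 6 and prefix_ones(10, 2) is 1, so the range [3, 10] contains 5 ones.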
Time Complexity
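- Precomputation: $O(\log n)$, since precomp makes one recursive call on $\lfloor n/2 \rfloor$, so $n$ is halved at each level
- Query: $O(\log n)$ for the same reason, since each call of num_ones_until recurses at most once, on $\lfloor n/2 \rfloor$
- Overall: $O(\log n)$ recursive steps; each step does std::map lookups over only $O(\log n)$ keys, so the total is $O(\log n \cdot \log \log n)$ if the map cost is counted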
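Both bounds come from the same recurrence: each call of precomp or num_ones_until does a constant number of map operations plus at most one recursive call on $\lfloor n/2 \rfloor$. Writing $T(n)$ for the number of calls:

$$
T(n) = T\left(\left\lfloor \tfrac{n}{2} \right\rfloor\right) + O(1), \qquad T(1) = O(1) \;\Longrightarrow\; T(n) = O(\log n)
$$

The values visited are $n, \lfloor n/2 \rfloor, \lfloor n/4 \rfloor, \dots, 1$, which is $\lfloor \log_2 n \rfloor + 1$ values for $n \ge 1$.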
Space Complexity
$O(\log n)$ for storing the precomputed values and the recursion stack.
Code
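#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
typedef pair<ll, ll> pll;

// ones[x] = {length of the final list for x, number of 1s in it}
map<ll, pll> ones;

// Precomputes the length and number of ones for n and every value reached by halving it
void precomp(ll n) {
    if (n <= 1) {
        ones[n] = make_pair(1, n); // the final list is just [n]
        return;
    }
    precomp(n / 2);
    int mid_bit = n % 2;
    ll sz = 1 + 2 * ones[n / 2].first;            // Total length
    ll one_sz = mid_bit + 2 * ones[n / 2].second; // Total number of ones
    ones[n] = make_pair(sz, one_sz);
}

// Returns the number of ones among positions 1..idx (1 <= idx <= length) of the final list for n
ll num_ones_until(ll n, ll idx) {
    if (idx == ones[n].first)         // the whole list
        return ones[n].second;
    if (idx == ones[n / 2].first + 1) // left half plus the middle element
        return n % 2 + ones[n / 2].second;
    if (idx <= ones[n / 2].first)     // prefix lies inside the left half
        return num_ones_until(n / 2, idx);
    // left half + middle element + a prefix of the right half (a copy of the left half)
    return ones[n / 2].second + n % 2 +
           num_ones_until(n / 2, idx - ones[n / 2].first - 1);
}

void solve() {
    ll n, l, r;
    cin >> n >> l >> r;
    precomp(n);
    ll ans = num_ones_until(n, r);
    if (l > 1) // prefix up to l - 1 is empty when l == 1
        ans -= num_ones_until(n, l - 1);
    cout << ans << '\n';
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    solve();
    return 0;
}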