
Sam has a list initially containing a single element $n$. He repeatedly removes any element $x > 1$ and replaces it, in place, with the three elements $\lfloor x/2 \rfloor$, $x \bmod 2$, $\lfloor x/2 \rfloor$, in that order. The operations continue until every element in the list is either 0 or 1.

We need to find the total number of 1s in the range from position $l$ to position $r$ (1-indexed, inclusive) in the final list.

Idea

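Because each element expands independently of the others, the order in which Sam picks elements does not affect the final list. So we can define the final list $f(x)$ produced from a single value $x$ recursively:

  • $f(x) = [x]$ for $x \in \{0, 1\}$
  • $f(x) = f(\lfloor x/2 \rfloor) \,\|\, [x \bmod 2] \,\|\, f(\lfloor x/2 \rfloor)$ for $x > 1$, where $\|$ represents concatenation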
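For small inputs the definition can be executed directly. The sketch below builds the final list for $x$ by following the recursion literally; the name `expand` is ours, not part of the original solution. It is only practical for small $x$, since the list for $x \ge 1$ has between $x$ and $2x - 1$ elements, but it can serve as a reference when testing the faster approach.

```cpp
#include <cassert>
#include <vector>

// Builds the final list f(x) by applying the recursive definition literally.
// Hypothetical helper for illustration: for x >= 1 the result has between
// x and 2x - 1 elements, so use it only for small x.
std::vector<int> expand(long long x) {
  if (x <= 1) return {static_cast<int>(x)};               // 0 and 1 are final
  const std::vector<int> half = expand(x / 2);            // f(floor(x / 2))
  std::vector<int> result(half);                          // left copy
  result.push_back(static_cast<int>(x % 2));              // middle: x mod 2
  result.insert(result.end(), half.begin(), half.end());  // right copy
  return result;
}
```

For instance, `expand(5)` returns `{1, 0, 1, 1, 1, 0, 1}`: the list for 2 is `{1, 0, 1}`, placed on both sides of the middle element `5 % 2 = 1`.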

We need to precompute two properties for each value $x$ reached by this recursion:

  1. The length of the final list for $x$
  2. The number of 1s in the final list for $x$
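Only a few values ever need these two properties: starting from $n$, the recursion visits $\lfloor n/2 \rfloor$, $\lfloor n/4 \rfloor$, and so on, stopping at the first value that is at most 1. The sketch below (the helper name `halving_chain` is hypothetical) lists exactly those values, which is why a small table keyed by value is enough.

```cpp
#include <cassert>
#include <vector>

// Returns n, floor(n/2), floor(n/4), ..., stopping at the first value <= 1.
// These are the only values whose length and 1-count are ever needed.
std::vector<long long> halving_chain(long long n) {
  std::vector<long long> chain{n};
  while (chain.back() > 1) chain.push_back(chain.back() / 2);
  return chain;
}
```

For $n = 10$ this yields `{10, 5, 2, 1}`, so only four table entries are needed.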

For a value $x$, define:

  • $L(x)$ = length of the final list for $x$
  • $C(x)$ = number of 1s in the final list for $x$

Then we have:

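  • $L(x) = 1$ for $x \le 1$
  • $C(x) = x$ for $x \le 1$
  • $L(x) = 2L(\lfloor x/2 \rfloor) + 1$ for $x > 1$
  • $C(x) = 2C(\lfloor x/2 \rfloor) + (x \bmod 2)$ for $x > 1$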
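As a worked check, take $n = 10$ and write $L(x)$ and $C(x)$ for the length and the number of 1s of the final list for $x$. Applying the recurrences along the values $10, 5, 2, 1$ gives:

$$
\begin{aligned}
L(1) &= 1, & C(1) &= 1,\\
L(2) &= 2 \cdot 1 + 1 = 3, & C(2) &= 2 \cdot 1 + 0 = 2,\\
L(5) &= 2 \cdot 3 + 1 = 7, & C(5) &= 2 \cdot 2 + 1 = 5,\\
L(10) &= 2 \cdot 7 + 1 = 15, & C(10) &= 2 \cdot 5 + 0 = 10.
\end{aligned}
$$

Every row has $C(x) = x$, and that holds in general: each operation replaces $x$ with elements summing to $2\lfloor x/2 \rfloor + (x \bmod 2) = x$, so the sum of the list never changes, and a final list of 0s and 1s sums to exactly its number of 1s.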

Range Query

To count the number of 1s in a range $[l, r]$, we use a helper that counts the 1s among the first $idx$ positions of the final list (the prefix-sum idea). If $P(idx)$ denotes that count, with $P(0) = 0$, the answer is $P(r) - P(l - 1)$.
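To see the prefix-count identity on a concrete list, the sketch below hard-codes the final list for $n = 10$, which follows from the recursion as $f(5) \,\|\, [0] \,\|\, f(5)$ with $f(5) = [1, 0, 1, 1, 1, 0, 1]$, and answers a range query with an ordinary prefix-sum array. The names `final10` and `count_ones` are illustrative only; the real solution never materializes the list.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Final list for n = 10, i.e. f(5) || [0] || f(5).
const std::vector<int> final10 = {1, 0, 1, 1, 1, 0, 1, 0,
                                  1, 0, 1, 1, 1, 0, 1};

// Number of 1s in positions l..r (1-indexed, inclusive) via prefix sums:
// prefix[i] holds the count of 1s among the first i elements.
long long count_ones(const std::vector<int>& list, long long l, long long r) {
  std::vector<long long> prefix(list.size() + 1, 0);
  for (std::size_t i = 0; i < list.size(); ++i)
    prefix[i + 1] = prefix[i] + list[i];
  return prefix[r] - prefix[l - 1];
}
```

Here `count_ones(final10, 3, 10)` evaluates $P(10) - P(2) = 6 - 1 = 5$.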

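The helper num_ones_until(n, idx) returns the number of 1s among the first idx positions of the final list for $n$. It works recursively, using the layout of that list: the final list for $\lfloor n/2 \rfloor$, then the middle element $n \bmod 2$, then the final list for $\lfloor n/2 \rfloor$ again: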

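  1. If idx equals the length of the final list for $n$ (the whole list), return the total number of 1s for $n$.
  2. If idx equals the length of the final list for $\lfloor n/2 \rfloor$ plus 1 (the left copy and the middle element), return the number of 1s for $\lfloor n/2 \rfloor$ plus $n \bmod 2$.
  3. If idx is at most the length of the final list for $\lfloor n/2 \rfloor$, the prefix lies entirely in the left copy, so recurse with $\lfloor n/2 \rfloor$ and the same idx.
  4. Otherwise, the prefix covers the left copy, the middle element, and part of the right copy: return the number of 1s for $\lfloor n/2 \rfloor$, plus $n \bmod 2$, plus num_ones_until($\lfloor n/2 \rfloor$, idx minus the length for $\lfloor n/2 \rfloor$ minus 1).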
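The same case analysis can also be written as a loop that needs no map. This is a variant of our own, not the original code, and it relies on two facts. Writing $L(x)$ and $C(x)$ for the length and the number of 1s of the final list for $x$: descending from $n > 1$ to $\lfloor n/2 \rfloor$, the recurrence $L(x) = 2L(\lfloor x/2 \rfloor) + 1$ gives $L(\lfloor n/2 \rfloor) = (L(n) - 1)/2$; and $C(x) = x$, because each operation preserves the sum of the list and a list of 0s and 1s sums to its number of 1s. The function name `prefix_ones` is hypothetical.

```cpp
#include <cassert>

// Number of 1s among the first idx positions (0 <= idx <= length) of the
// final list for n. Map-free sketch: assumes C(x) == x and halves L as it
// descends, instead of looking values up in a precomputed table.
long long prefix_ones(long long n, long long idx) {
  long long len = 1;  // L(n), built from L(x) = 2 L(x / 2) + 1
  for (long long m = n; m > 1; m /= 2) len = 2 * len + 1;

  long long total = 0;
  while (idx > 0) {
    if (n <= 1) {       // single final element left, so idx == 1 here
      total += n;
      break;
    }
    long long half = n / 2;
    long long half_len = (len - 1) / 2;  // L(half)
    if (idx > half_len) {
      // Take the whole left copy (C(half) == half) and the middle element,
      // then continue inside the right copy.
      total += half + n % 2;
      idx -= half_len + 1;
    }
    // Otherwise the prefix lies inside the left copy; either way the
    // remaining work is a prefix of the list for half.
    n = half;
    len = half_len;
  }
  return total;
}
```

For example, `prefix_ones(10, 10) - prefix_ones(10, 2)` counts the 1s in positions 3 through 10 of the list for $n = 10$, which is 5. Compared with the map-based version, this loop uses constant extra space and no recursion, at the cost of relying on the $C(x) = x$ observation.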

Time Complexity

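  • Precomputation: $O(\log n)$ recursive calls, since each call halves $n$; with std::map each access costs an extra $O(\log \log n)$ factor, because the map holds only $O(\log n)$ entries
  • Query: $O(\log n)$ calls for the same reason, since each call makes at most one recursive call, on $\lfloor n/2 \rfloor$
  • Overall: $O(\log n)$ recursive calls with a constant number of map lookups each, i.e. $O(\log n \cdot \log \log n)$ time with std::map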
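To make the logarithmic bounds concrete: for $n \ge 1$, the halving chain $n, \lfloor n/2 \rfloor, \dots, 1$ that both recursions walk has $\lfloor \log_2 n \rfloor + 1$ values, for example:

$$
\begin{aligned}
n = 10 &: \quad 10 \to 5 \to 2 \to 1, \qquad \lfloor \log_2 10 \rfloor + 1 = 3 + 1 = 4,\\
n = 2^{40} &: \quad 2^{40} \to 2^{39} \to \dots \to 1, \qquad 40 + 1 = 41.
\end{aligned}
$$

So for $n = 2^{40}$ the precomputation stores at most 41 entries, and each query descends at most 41 levels.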

Space Complexity

$O(\log n)$ for the precomputed values (one map entry per value on the halving chain) and for the recursion stack.

Code

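#include <iostream>
#include <map>
#include <utility>
using namespace std;

typedef long long ll;
typedef pair<ll, ll> pll;

// ones[x] = {length of the final list for x, number of 1s in it}
map<ll, pll> ones;

// Precomputes the length and number of ones for n, n/2, n/4, ..., down to 1
void precomp(ll n) {
  if (n <= 1) {
    ones[n] = make_pair(1, n);  // 0 and 1 are already final
    return;
  }
  precomp(n / 2);
  ll mid_bit = n % 2;
  ll sz = 1 + 2 * ones[n / 2].first;             // Total length
  ll one_sz = mid_bit + 2 * ones[n / 2].second;  // Total number of ones
  ones[n] = make_pair(sz, one_sz);
}

// Returns the number of ones among the first idx positions
// (1 <= idx <= length) of the final list for n
ll num_ones_until(ll n, ll idx) {
  if (idx == ones[n].first)          // whole list
    return ones[n].second;
  if (idx == ones[n / 2].first + 1)  // left copy plus the middle element
    return n % 2 + ones[n / 2].second;
  if (idx <= ones[n / 2].first)      // prefix inside the left copy
    return num_ones_until(n / 2, idx);
  // left copy + middle element + a prefix of the right copy
  return ones[n / 2].second + n % 2 +
         num_ones_until(n / 2, idx - ones[n / 2].first - 1);
}

void solve() {
  ll n, l, r;
  cin >> n >> l >> r;
  precomp(n);
  ll ans = num_ones_until(n, r);
  if (l > 1)
    ans -= num_ones_until(n, l - 1);  // remove the ones before position l
  cout << ans << '\n';
}

int main() {
  ios::sync_with_stdio(false);
  cin.tie(nullptr);
  solve();
  return 0;
}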