We are given N diamonds, each with an integer size, and a parameter K. Our goal is to place diamonds in two display cases so that, within each case, the sizes of the largest and smallest diamond differ by at most K, while maximizing the total number of diamonds displayed. Diamonds that fit in neither case are simply left out.
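To pin the problem statement down precisely, here is a minimal brute-force sketch. The name brute_force is hypothetical and not part of the original solution. It tries every way of putting each diamond in case 1, case 2, or neither, so it takes roughly 3^N steps and is only meant as a reference for very small inputs, for example when testing a faster solution.

```python
from itertools import product


def brute_force(sizes: list[int], k: int) -> int:
    """Exhaustive search: label each diamond 0 (not shown), 1 or 2 (its case).

    Runs in O(3**n * n), so it is only practical for very small n.
    """
    best = 0
    for labels in product((0, 1, 2), repeat=len(sizes)):
        valid = True
        for case in (1, 2):
            chosen = [x for x, lab in zip(sizes, labels) if lab == case]
            # An empty case is trivially valid; otherwise check the spread.
            if chosen and max(chosen) - min(chosen) > k:
                valid = False
                break
        if valid:
            best = max(best, sum(1 for lab in labels if lab != 0))
    return best


print(brute_force([10, 5, 1, 12, 9, 5, 14], 3))  # 5
```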
Idea
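- Sorting is optimal: After sorting, any valid set for a case can be widened to the contiguous segment between its smallest and largest diamond without becoming invalid, so we may assume each case holds a contiguous segment. If the two segments overlap (an interleaved solution), give the first case every diamond that comes before the second case's smallest diamond in sorted order, and give the second case everything from there up to the larger of the two maxima. Both cases still span at most K and no diamond is lost, so the two segments can always be taken as non-overlapping.
- Contiguous segments: With the diamonds sorted, the problem reduces to choosing two non-overlapping contiguous segments, each of which is valid (the difference between its maximum and minimum size is $\le K$), so that the total number of diamonds in the two segments is as large as possible.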
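The reduction can be checked on small inputs with a sketch that only ever considers contiguous segments of the sorted sizes. The function name best_two_segments_naive is hypothetical, and the code is deliberately naive: it enumerates every valid segment and every ordered pair of them. If the argument above holds, it returns the same value as an exhaustive search over all subsets.

```python
def best_two_segments_naive(sizes: list[int], k: int) -> int:
    """Largest total over pairs of non-overlapping valid segments of sorted sizes.

    Deliberately naive (up to O(n**4) pair checks); it only illustrates the
    reduction from arbitrary subsets to contiguous segments.
    """
    xs = sorted(sizes)
    n = len(xs)
    # Every valid segment as a half-open index range [a, b).
    segments = [
        (a, b)
        for a in range(n)
        for b in range(a + 1, n + 1)
        if xs[b - 1] - xs[a] <= k
    ]
    best = 0
    for a1, b1 in segments:
        best = max(best, b1 - a1)  # only one case is used
        for a2, b2 in segments:
            if a2 >= b1:  # second segment starts after the first one ends
                best = max(best, (b1 - a1) + (b2 - a2))
    return best


print(best_two_segments_naive([10, 5, 1, 12, 9, 5, 14], 3))  # 5
```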
Approach
- Sort the Diamonds: Sort the array of diamond sizes.
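- Two pointers to compute valid segments: For each starting index $i$, determine the maximum number of consecutive diamonds, say $s[i]$, that can be grouped together so that $\text{diamonds}[i + s[i] - 1] - \text{diamonds}[i] \le K$.
- Because the right end of the window never moves left as $i$ increases, this can be done with a two-pointer (sliding window) technique in $O(N)$ time overall.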
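- Precompute the best segment starting at or after an index:
- Define an array right_max where $\text{right\_max}[i] = \max_{j \ge i} s[j]$ and $\text{right\_max}[N] = 0$. It is filled from right to left with $\text{right\_max}[i] = \max(s[i], \text{right\_max}[i + 1])$.
- This answers in $O(1)$: "If the first case ends at index $j - 1$, what is the best (largest) valid segment starting at or after index $j$ for the second case?"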
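- Combine the two cases: For every possible starting index $i$ of the first segment:
- The first case takes $s[i]$ diamonds (indices $i$ to $i + s[i] - 1$). Taking the longest window never hurts: any diamonds it takes over from the second case's segment are simply removed from that segment, so the combined count does not drop.
- The second case can then use the best segment starting at index $i + s[i]$ or later, whose length is $\text{right\_max}[i + s[i]]$.
- The candidate total is $s[i] + \text{right\_max}[i + s[i]]$, and the answer is the maximum over all $i$.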
Time Complexity: Sorting $O(N \log N)$ + two pointers $O(N)$ + precomputation and combination $O(N)$ = $O(N \log N)$.
Space Complexity: $O(N)$ for the arrays s and right_max.
Code
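def main():
    # Input: first line holds N and K, then one diamond size per line.
    with open("diamond.in", "r") as fin:
        n, k = map(int, fin.readline().split())
        diamonds = [int(fin.readline()) for _ in range(n)]

    diamonds.sort()

    # s[i] = length of the longest valid window starting at index i.
    # j never moves backwards, so the whole loop is O(N).
    s = []
    j = 0
    for i in range(n):
        curr = diamonds[i]
        while j < n and diamonds[j] <= curr + k:
            j += 1
        s.append(j - i)

    # right_max[i] = max(s[i], ..., s[n-1]); right_max[n] = 0 (no diamonds left).
    right_max = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        right_max[i] = max(s[i], right_max[i + 1])

    # First case: maximal window at i; second case: best window at or after its end.
    max_sum = 0
    for i in range(n):
        end = i + s[i]  # first index not used by the first case (end <= n)
        current_sum = s[i] + right_max[end]
        if current_sum > max_sum:
            max_sum = current_sum

    with open("diamond.out", "w") as fout:
        fout.write(str(max_sum) + "\n")


if __name__ == "__main__":
    main()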