Analysis: Crowded Cows, by Brian Dean


There are several ways to solve this problem in O(N log N) time using a "sweep line" approach. As shown in Mark Gordon's code below, one such solution is to sort the cows by position and then scan through this ordering while maintaining the heights of all cows in the range x-d to x, and of all cows in the range x to x+d, where x is the position of the current cow during our scan. To maintain the heights of the cows in these two sliding windows, we can use either a priority queue or a set (i.e., a balanced binary search tree). As we visit each cow, we can test if it is crowded by querying for the maximum height in both windows; if each maximum is at least twice the height of the current cow, the cow is crowded.

Alternatively, we can scan the cows in decreasing order of height, using a pair of sweep lines that move in lock step so that the upper sweep line is always at twice the height of the lower sweep line. Whenever the upper sweep line visits a cow, the position of that cow is inserted into a set data structure (i.e., a balanced binary search tree). When the lower sweep line visits a cow (say, at position x), we query this structure for the positions immediately preceding and following x (in an STL set, for example, the lower_bound method returns the successor, and the element just before it is the predecessor). The data structure contains the positions of all cows at least twice the height of the current cow, so if the predecessor and successor of her position x both exist and lie within the range x-d to x+d, then the current cow is crowded.
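
This second approach is not the one implemented in Mark Gordon's program, which follows separately. A minimal sketch of it might look like the function below. The name countCrowdedByHeight and its signature are hypothetical, and the sketch assumes that all positions are distinct and all heights are positive.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <set>
#include <utility>
#include <vector>

typedef std::pair<long long, long long> Cow;  // (position, height)

// Comparator: taller cows first.
static bool tallerFirst(const Cow& a, const Cow& b) {
  return a.second > b.second;
}

// Hypothetical helper (not from the original analysis): counts crowded cows
// by sweeping in decreasing order of height with two sweep lines.
int countCrowdedByHeight(std::vector<Cow> cows, long long d) {
  std::sort(cows.begin(), cows.end(), tallerFirst);

  // Positions of every cow at least twice as tall as the current cow.
  std::set<long long> tallPositions;
  std::size_t upper = 0;  // upper sweep line: next cow to insert
  int crowded = 0;

  for (std::size_t lower = 0; lower < cows.size(); ++lower) {
    const long long x = cows[lower].first;
    const long long h = cows[lower].second;

    // The threshold 2 * h only decreases, so `upper` only moves forward.
    while (upper < cows.size() && cows[upper].second >= 2 * h) {
      tallPositions.insert(cows[upper].first);
      ++upper;
    }

    // Successor: first stored position >= x. A predecessor exists only
    // if the successor is not the first element of the set.
    std::set<long long>::iterator succ = tallPositions.lower_bound(x);
    if (succ == tallPositions.end() || succ == tallPositions.begin()) {
      continue;
    }
    std::set<long long>::iterator pred = std::prev(succ);

    if (*succ - x <= d && x - *pred <= d) {
      ++crowded;
    }
  }
  return crowded;
}
```

Each cow is inserted into the set at most once and triggers one lower_bound query, so the sweep stays within O(N log N) after the initial sort. Only the nearest tall cow on each side needs to be checked: if the nearest one is farther than d, every other tall cow on that side is farther still.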

// Mark Gordon's solution: sort by position, then sweep with two sliding windows.
#include <iostream>
#include <vector>
#include <algorithm>
#include <set>
#include <cstdio>

using namespace std;

int main() {
  freopen("crowded.in", "r", stdin);
  freopen("crowded.out", "w", stdout);

  int N, D;
  cin >> N >> D;

  vector<pair<int, int> > A(N);
  for(int i = 0; i < N; i++) 
    cin >> A[i].first >> A[i].second;

  sort(A.begin(), A.end());

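  int result = 0;
  // X holds heights of cows in [x - D, x]; Y holds heights of cows in [x, x + D].
  multiset<int> X, Y;
  int j = 0, k = 0;
  for(int i = 0; i < N; i++) {
    // Extend the right window to every cow at position <= x + D.
    while(k < N && A[k].first <= A[i].first + D) {
      Y.insert(A[k++].second);
    }
    // Drop cows at position < x - D from the left window.
    while(A[j].first + D < A[i].first) {
      X.erase(X.find(A[j++].second));
    }
    X.insert(A[i].second);

    // Crowded if both window maxima are at least twice this cow's height.
    if (*X.rbegin() >= 2 * A[i].second &&
        *Y.rbegin() >= 2 * A[i].second) {
      result++;
    }

    // Cow i leaves the right window before the scan moves on.
    Y.erase(Y.find(A[i].second));
  }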

  cout << result << endl;
  return 0;
}
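
The analysis notes that a priority queue can stand in for the set when maintaining the two sliding windows. A standard priority queue cannot remove arbitrary elements, so one possible sketch stores (height, position) pairs and discards out-of-window entries lazily, only when they reach the top. The function name countCrowdedWithHeaps is hypothetical, and positions are assumed to be distinct.

```cpp
#include <algorithm>
#include <cstddef>
#include <queue>
#include <utility>
#include <vector>

typedef std::pair<long long, long long> PosHeight;  // (position, height)
typedef std::pair<long long, long long> HeightPos;  // (height, position)

// Hypothetical helper (not from the original analysis): the same
// position-sorted scan, with each window's maximum kept in a max-heap.
int countCrowdedWithHeaps(std::vector<PosHeight> cows, long long d) {
  std::sort(cows.begin(), cows.end());  // ascending by position

  std::priority_queue<HeightPos> left;   // cows already passed by the scan
  std::priority_queue<HeightPos> right;  // cows at position <= x + d
  std::size_t k = 0;
  int crowded = 0;

  for (std::size_t i = 0; i < cows.size(); ++i) {
    const long long x = cows[i].first;
    const long long h = cows[i].second;

    // Extend the right window to every cow at position <= x + d.
    while (k < cows.size() && cows[k].first <= x + d) {
      right.push(HeightPos(cows[k].second, cows[k].first));
      ++k;
    }

    // Lazy deletion: x only increases, so an entry that has fallen out
    // of a window never re-enters it and can be dropped for good.
    while (!left.empty() && left.top().second < x - d) left.pop();
    while (!right.empty() && right.top().second < x) right.pop();

    const bool tallLeft = !left.empty() && left.top().first >= 2 * h;
    const bool tallRight = !right.empty() && right.top().first >= 2 * h;
    if (tallLeft && tallRight) ++crowded;

    left.push(HeightPos(h, x));  // cow i joins the left window of later cows
  }
  return crowded;
}
```

Every cow is pushed onto each heap once and popped at most once, so this variant also runs in O(N log N). Stale entries buried below the top do no harm, because only the top of each heap is ever inspected.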