Solution Notes (Nathan Pinsker): The first thing we may notice is that we need an efficient way of determining whether a string of parentheses is balanced. Some insight yields the following procedure: let a '(' represent 1 and a ')' represent -1. A string of parentheses is balanced if and only if the sum of its corresponding numbers is 0 and no prefix of the string has a sum less than 0.
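
As a minimal sketch of this check (the function name isBalanced is our own and is not part of the solution further down), the two conditions can be tested in a single pass:

```cpp
#include <cassert>
#include <string>

// Returns true if s is a balanced string of parentheses.
// '(' counts as +1 and ')' as -1: the running sum must never drop
// below 0 and must finish at exactly 0.
bool isBalanced(const std::string& s) {
  int sum = 0;
  for (char c : s) {
    assert(c == '(' || c == ')');  // this sketch assumes parentheses only
    sum += (c == '(') ? 1 : -1;
    if (sum < 0) return false;  // some prefix closes more than it opens
  }
  return sum == 0;
}
```

Calling this on every substring of every string would be far too slow for the actual problem; it is only meant to pin down the definition the rest of the notes rely on.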

Using this insight, we can derive almost the desired solution: we sweep through the strings left to right. At each step, we record the prefix sum of the parentheses we have seen so far in each of our K strings. If we have seen the exact same combination of prefix sums before, then we have a possible match, but we must also make sure that, in each string, none of the prefix sums between these two locations is less than the prefix sum at the two endpoints.
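
To make the "same combination of prefix sums" idea concrete, here is a rough sketch that groups positions by their vector of prefix sums. The name groupBySignature and the choice to store every position are assumptions made for illustration; the solution below keeps only a counter per combination.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Groups prefix positions 0..len by their vector of prefix sums (one sum
// per string). Two positions in the same group are only candidate
// endpoints; the "no smaller prefix sum in between" check is still needed.
std::map<std::vector<int>, std::vector<int>> groupBySignature(
    const std::vector<std::string>& strs) {
  std::map<std::vector<int>, std::vector<int>> groups;
  if (strs.empty()) return groups;
  const std::size_t len = strs[0].size();
  for (const std::string& s : strs) assert(s.size() == len);  // equal lengths assumed

  std::vector<int> sums(strs.size(), 0);
  groups[sums].push_back(0);  // the empty prefix
  for (std::size_t i = 0; i < len; i++) {
    for (std::size_t j = 0; j < strs.size(); j++) {
      sums[j] += (strs[j][i] == '(') ? 1 : -1;
    }
    groups[sums].push_back(static_cast<int>(i) + 1);
  }
  return groups;
}
```

For the pair "()(())()" and "()()()()", positions 0, 2, 6 and 8 share the combination (0, 0), so any two of them are candidate endpoints of a concurrently balanced range.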

This check can be made individually for each string, and rephrased as "given the elements of an array, for each element, find the closest element to its left which is less than it", a fairly well-known problem that can be solved in linear time. Once we have this element (say it's at index i), we know that a balanced parenthesis string ending at the current position cannot extend to the left past index i. Taking the most restrictive such limit over all K strings allows us to determine whether the candidate we have found is actually valid.
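
One common way to solve the "closest smaller element to the left" subproblem is a monotonic stack. The sketch below is a generic version with a hypothetical name (nearestSmallerToLeft); the solution further down gets the same information differently, by indexing the array R by prefix-sum level.

```cpp
#include <vector>

// For each index i, finds the largest index j < i with a[j] < a[i],
// or -1 if there is none. Each index is pushed and popped at most once,
// so the total running time is linear.
std::vector<int> nearestSmallerToLeft(const std::vector<int>& a) {
  std::vector<int> result(a.size(), -1);
  std::vector<int> stack;  // indices whose values are strictly increasing
  for (int i = 0; i < static_cast<int>(a.size()); i++) {
    // Anything >= a[i] can never be the answer for i or for a later index.
    while (!stack.empty() && a[stack.back()] >= a[i]) stack.pop_back();
    if (!stack.empty()) result[i] = stack.back();
    stack.push_back(i);
  }
  return result;
}
```

On the prefix sums 0, 1, 0, 1, 2, 1, 0, 1, 0 of "()(())()" this returns -1, 0, -1, 2, 3, 2, -1, 6, -1. The entry -1 at the final position says that a balanced substring ending there may start as far left as the beginning of the string.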

Some care is still required in handling cases where the correct balanced string may comprise multiple sets of balanced parentheses. For example, consider the strings "()(())()" and "()()()()". To handle this, when we find an end parenthesis that corresponds to some number of balanced sets, we simply mark how many sets it ends (the 6th parenthesis in the above strings would end 2 sets of balanced parentheses: the substring with indices 2-5 and the substring with indices 0-5). Then, when we process the 8th parenthesis, we will remember this information, and mark the 8th parenthesis as representing 3 sets.
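
The counting idea is easiest to see on a single string. The following sketch is a simplification under that assumption (one string, an explicit stack of open parentheses, and a hypothetical name balancedEndingAt); it is not the multi-string bookkeeping used in the solution below.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// ends[p] = number of balanced substrings of s whose last character is
// s[p - 1]; ends[0] is always 0.
std::vector<long long> balancedEndingAt(const std::string& s) {
  std::vector<long long> ends(s.size() + 1, 0);
  std::vector<std::size_t> open;  // positions of currently unmatched '('
  for (std::size_t i = 0; i < s.size(); i++) {
    if (s[i] == '(') {
      open.push_back(i);
    } else if (!open.empty()) {
      std::size_t start = open.back();
      open.pop_back();
      // s[start..i] is balanced, and it also extends every balanced
      // substring that ended immediately before position start.
      ends[i + 1] = ends[start] + 1;
    }
    // An unmatched ')' ends nothing, so ends[i + 1] stays 0.
  }
  return ends;
}
```

For "()(())()" this gives 2 for the 6th parenthesis and 3 for the 8th, matching the description above; summing all entries gives the 7 balanced substrings of that string.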

Mark Gordon's concise solution is below. Note his use of the array R and variable lft to solve the problem discussed above.


#include <algorithm>
#include <cstdio>
#include <iostream>
#include <map>
#include <string>
#include <vector>

using namespace std;

int main() {
  freopen("cbs.in", "r", stdin);
  freopen("cbs.out", "w", stdout);

  // N is the number of strings (K in the notes above); M is their length.
  int N, M; cin >> N >> M;
  vector<string> A(N);
  for(int i = 0; i < N; i++) cin >> A[i];

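  int res = 0;
  // L[j]: current prefix sum of string j, offset by M so it is a valid index.
  vector<int> L(N, M);
  // R[j][v]: when string j's current level is v, the leftmost position since
  // which its prefix sums have stayed >= v (M means "no such position").
  // Levels range over 0..2M, so each row needs 2 * M + 1 entries.
  vector<vector<int> > R(N, vector<int>(2 * M + 1, M));
  // mp[prefix sums] = (left limit, number of positions seen with that limit).
  map<vector<int>, pair<int, int> > mp;
  for(int i = 0; i < N; i++) R[i][M] = 0;
  
  mp[L] = make_pair(0, 1);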
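  for(int i = 0; i < M; i++) {
    // lft: leftmost start that works in every string for a range ending at i.
    int lft = 0;
    for(int j = 0; j < N; j++) {
      if(A[j][i] == '(') {
        // A new, higher level begins right after this character.
        R[j][++L[j]] = i + 1;
      } else {
        --L[j];
        R[j][L[j]] = min(R[j][L[j]], i + 1);
      }
      lft = max(lft, R[j][L[j]]);
    }
    // lft == M: no range ending here can be balanced in every string.
    if(lft == M) continue;

    pair<int, int>& dat = mp[L];
    if(dat.first == lft) {
      // Each earlier position with these prefix sums and this limit
      // starts a valid range ending at i.
      res += dat.second++;
    } else {
      dat = make_pair(lft, 1);
    }
  }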
  cout << res << endl;
}