Solution Notes (Nathan Pinsker): The first thing to notice is that we need an efficient way to determine
whether a string of parentheses is balanced. A little insight yields the following
procedure: let each '(' represent +1 and each ')' represent -1. A string of
parentheses is balanced if and only if its numbers sum to 0 and, in addition,
no prefix of it has a sum less than 0.
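As a minimal sketch of this test, a single pass with a running sum is enough. The function name isBalanced is hypothetical, and the input is assumed to contain only '(' and ')':

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: '(' counts +1 and ')' counts -1. The string is
// balanced iff no prefix sum goes below 0 and the total is exactly 0.
bool isBalanced(const std::string& s) {
    int sum = 0;
    for (char c : s) {
        assert(c == '(' || c == ')');  // only parentheses are assumed
        sum += (c == '(') ? 1 : -1;
        if (sum < 0) return false;     // this prefix closes more than it opens
    }
    return sum == 0;                   // every '(' has been matched
}
```

For example, "(())()" passes, "())(" fails at its third character, and "(()" fails because its total is 1.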
Using this insight, we can derive most of the solution: we sweep through the
strings from left to right. At each position, we record the prefix sum of the
parentheses seen so far in each of our K strings. If we have seen exactly the
same combination of prefix sums before, then we have a possible match, but we
also have to make sure that, in every string, none of the prefix sums between
these two positions is less than the value being matched.
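To make the target of the sweep concrete, here is a hypothetical brute-force reference, countCommonBalancedBrute, that counts the ranges balanced in all strings at once. It extends each start position and stops as soon as some string's running sum goes negative. It assumes all strings have the same length and runs in O(K * M^2) time, so it is only suitable for checking a faster solution on small inputs:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical brute-force reference: counts ranges [s, e] that are balanced
// in every string simultaneously. O(K * M^2) time, small inputs only.
long long countCommonBalancedBrute(const std::vector<std::string>& A) {
    if (A.empty()) return 0;
    const int K = static_cast<int>(A.size());
    const int M = static_cast<int>(A[0].size());
    for (const std::string& a : A)
        assert(static_cast<int>(a.size()) == M);  // equal lengths assumed
    long long count = 0;
    for (int s = 0; s < M; ++s) {
        std::vector<int> sum(K, 0);  // running sum of A[j][s..e] per string
        for (int e = s; e < M; ++e) {
            bool allZero = true, negative = false;
            for (int j = 0; j < K; ++j) {
                sum[j] += (A[j][e] == '(') ? 1 : -1;
                if (sum[j] < 0) negative = true;
                if (sum[j] != 0) allZero = false;
            }
            // Once some string has a negative prefix from s, every longer
            // range starting at s contains that prefix, so stop extending.
            if (negative) break;
            if (allZero) ++count;
        }
    }
    return count;
}
```

On the two example strings "()(())()" and "()()()()" it finds 6 common balanced substrings, at indices 0-1, 0-5, 0-7, 2-5, 2-7 and 6-7.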
This check can be done separately for each string, and rephrased as
"given an array, for each element, find the closest element to its left that
is strictly less than it", a fairly well-known problem that can be solved in
linear time with a stack. Once we have this element for the current prefix sum
(say it is at index i), we know that our balanced parenthesis string cannot
extend back past index i: it must start after i. This lets us determine whether
the match we have found is actually valid.
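A standard way to answer this query is a monotonic stack. The sketch below assumes it is applied to one string's prefix-sum array P, with P[0] = 0 for the empty prefix; the name prevLess is hypothetical. For a substring ending at prefix position t, a start position s with P[s] == P[t] is valid exactly when s > prevLess(P)[t]:

```cpp
#include <vector>

// For each index t, returns the largest index i < t with P[i] < P[t],
// or -1 if there is none (the "previous strictly smaller element").
// Each index is pushed and popped at most once, so the total time is O(n).
std::vector<int> prevLess(const std::vector<int>& P) {
    std::vector<int> result(P.size(), -1);
    std::vector<int> stk;  // indices whose values are strictly increasing
    for (int t = 0; t < static_cast<int>(P.size()); ++t) {
        // An index whose value is >= P[t] can never be the answer for t or
        // any later index, because t is closer and its value is no larger.
        while (!stk.empty() && P[stk.back()] >= P[t]) stk.pop_back();
        if (!stk.empty()) result[t] = stk.back();
        stk.push_back(t);
    }
    return result;
}
```

For "()(())()", P = {0,1,0,1,2,1,0,1,0} and the result is {-1,0,-1,2,3,2,-1,6,-1}. At t = 5, for instance, the closest smaller value to the left is P[2] = 0, so a match ending there must start after position 2 (the only candidate is position 3, giving "()").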
Some care is still required when the correct balanced string consists of
several balanced pieces placed side by side. For example, consider the strings
"()(())()" and "()()()()". To handle this, whenever we find a closing
parenthesis that ends some number of common balanced sets, we simply record how
many sets it ends. Using 0-indexed positions, the 6th parenthesis (index 5)
ends 2 sets of balanced parentheses in both strings: the substrings at indices
2-5 and 0-5. When we then process the 8th parenthesis, we reuse this
information: it ends the set at indices 6-7, and each of the 2 sets ending at
index 5 extends to it, so we mark the 8th parenthesis as representing 3 sets.
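The per-string bounds and this chaining idea can be combined into one sketch. It is an illustration of the approach described above, not Mark Gordon's code: it precomputes, for every prefix position t (after t characters), the earliest allowed start lo[t] from the previous-smaller query on each string, then keeps, for each vector of prefix sums, the last position where it occurred and how many sets end there. The function name countCommonBalanced is hypothetical, all strings are assumed to have equal length, and each map lookup costs O(K log M):

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical sketch: counts ranges balanced in every string at once.
long long countCommonBalanced(const std::vector<std::string>& A) {
    if (A.empty()) return 0;
    const int K = static_cast<int>(A.size());
    const int M = static_cast<int>(A[0].size());
    // P[j][t]: prefix sum of the first t characters of A[j] ('(' = +1, ')' = -1).
    std::vector<std::vector<int>> P(K, std::vector<int>(M + 1, 0));
    // lo[t]: earliest start position allowed for a match ending at position t.
    std::vector<int> lo(M + 1, 0);
    for (int j = 0; j < K; ++j) {
        assert(static_cast<int>(A[j].size()) == M);  // equal lengths assumed
        for (int t = 0; t < M; ++t)
            P[j][t + 1] = P[j][t] + (A[j][t] == '(' ? 1 : -1);
        // Previous strictly smaller prefix sum via a monotonic stack; a match
        // ending at t must start after it.
        std::vector<int> stk;
        for (int t = 0; t <= M; ++t) {
            while (!stk.empty() && P[j][stk.back()] >= P[j][t]) stk.pop_back();
            if (!stk.empty()) lo[t] = std::max(lo[t], stk.back() + 1);
            stk.push_back(t);
        }
    }
    // For each vector of prefix sums: (last position seen, sets ending there).
    std::map<std::vector<int>, std::pair<int, long long>> last;
    long long total = 0;
    for (int t = 0; t <= M; ++t) {
        std::vector<int> key(K);
        for (int j = 0; j < K; ++j) key[j] = P[j][t];
        long long sets = 0;
        auto it = last.find(key);
        // If the most recent equal position p is a valid start, every set
        // ending at p extends to t, plus the new set [p, t) itself.
        if (it != last.end() && it->second.first >= lo[t])
            sets = it->second.second + 1;
        total += sets;
        last[key] = std::make_pair(t, sets);
    }
    return total;
}
```

On "()(())()" and "()()()()" it records 1 set after the 2nd parenthesis, 2 after the 6th and 3 after the 8th, for a total of 1 + 2 + 3 = 6. Only the most recent equal position needs checking: if some string dips below its level after that position, it also does so after every earlier one.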
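Mark Gordon's concise solution is below. Note how the array R and the variable
lft track the earliest position at which a common balanced substring may start,
and how the count stored in the map handles the multiple-sets issue discussed
above.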
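```cpp
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <map>
#include <string>
#include <vector>
using namespace std;
int main() {
    freopen("cbs.in", "r", stdin);
    freopen("cbs.out", "w", stdout);
    // N = number of strings (K in the notes above), M = length of each string.
    int N, M; cin >> N >> M;
    vector<string> A(N);
    for (int i = 0; i < N; i++) cin >> A[i];
    long long res = 0;  // widened from int to guard against overflow
    // L[j] = current prefix sum of string j, offset by M so it is never negative.
    vector<int> L(N, M);
    // R[j][h] = earliest allowed start for a balanced substring of string j that
    // ends at a position with level h (M means "not visited yet"). Levels range
    // over 0..2M, so 2M + 1 entries are needed; 2M overflows on all-'(' input.
    vector<vector<int> > R(N, vector<int>(2 * M + 1, M));
    // mp[levels] = (lft when this level vector was last seen, number of
    //               occurrences recorded under that bound).
    map<vector<int>, pair<int, int> > mp;
    for (int i = 0; i < N; i++) R[i][M] = 0;  // empty prefix: level M at position 0
    mp[L] = make_pair(0, 1);
    for (int i = 0; i < M; i++) {
        int lft = 0;  // earliest start valid in every string for position i + 1
        for (int j = 0; j < N; j++) {
            if (A[j][i] == '(') {
                // Rising to a level: earlier visits to it are cut off by the
                // lower level just left, so only position i + 1 onward is valid.
                R[j][++L[j]] = i + 1;
            } else {
                // Falling back to a level: the bound from the last visit still
                // holds (on a first visit, the current position becomes it).
                --L[j];
                R[j][L[j]] = min(R[j][L[j]], i + 1);
            }
            lft = max(lft, R[j][L[j]]);
        }
        if (lft == M) continue;  // no earlier start is admissible
        pair<int, int>& dat = mp[L];
        if (dat.first == lft) {
            // Every occurrence recorded under this bound is a valid start.
            res += dat.second++;
        } else {
            // Some string dipped below its level since the last occurrence.
            dat = make_pair(lft, 1);
        }
    }
    cout << res << endl;
}
```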