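For each pair $(i,k)$ satisfying $i<k$, let $num[i][k]$ equal the number of $j$ such that $i<j<k$ and $A_i+A_j+A_k=0$. Then if $ans[i][k]$ denotes the answer for $(A_i,A_{i+1},\ldots,A_k)$, we can write

$$ans[i][k]=num[i][k]+ans[i+1][k]+ans[i][k-1]-ans[i+1][k-1],$$

where $ans$ is taken to be $0$ for any range containing fewer than three elements. A triple inside $[i,k]$ either uses both endpoints (counted by $num[i][k]$), avoids $i$, or avoids $k$; triples that avoid both endpoints are counted twice, so $ans[i+1][k-1]$ is subtracted once.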
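As a sanity check on the inclusion-exclusion relation between $ans$ and $num$, here is a minimal sketch that computes $num$ by brute force and then fills $ans$ from it. The names `Grid`, `bruteNum`, and `fillAnswers` are hypothetical helpers introduced only for this illustration, indices are 0-based, and the $O(N^3)$ brute force is meant for small test arrays rather than the real limits.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Grid = std::vector<std::vector<long long>>;

// num[i][k] = number of j with i < j < k and A[i] + A[j] + A[k] == 0.
// Deliberately brute force, O(N^3): only for checking small inputs.
Grid bruteNum(const std::vector<int>& A) {
    int n = static_cast<int>(A.size());
    Grid num(n, std::vector<long long>(n, 0));
    for (int i = 0; i < n; ++i)
        for (int k = i + 2; k < n; ++k)
            for (int j = i + 1; j < k; ++j)
                if (A[i] + A[j] + A[k] == 0) ++num[i][k];
    return num;
}

// Applies ans[i][k] = num[i][k] + ans[i+1][k] + ans[i][k-1] - ans[i+1][k-1].
// Rows are processed from the bottom up so that every term on the right-hand
// side is already final when it is read.
Grid fillAnswers(const Grid& num) {
    int n = static_cast<int>(num.size());
    for (const auto& row : num) assert(row.size() == num.size());  // must be square
    Grid ans = num;
    for (int i = n - 1; i >= 0; --i)
        for (int k = i + 1; k < n; ++k)
            ans[i][k] += ans[i + 1][k] + ans[i][k - 1] - ans[i + 1][k - 1];
    return ans;
}
```

For the array $(2,0,-1,1,-2,3,3)$, `fillAnswers(bruteNum(A))[0][6]` counts all zero-sum triples in the whole array, and `[1][3]` counts those within the second through fourth elements.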
Now I'll describe a way to compute $num[i][i+1],\ldots, num[i][N]$ in $O(N)$ time for a fixed $i$. Process $k$ from $i+1$ to $N$ in increasing order, maintaining a hashmap (unordered_map in C++) that allows you to query the number of occurrences of any integer among $A_{i+1},\ldots,A_{k-1}$. Then $num[i][k]$ is equal to the number of occurrences of $-A_i-A_k$ in this map. When $k$ is incremented by one, the only change to the map is a single insertion. As hashmap operations run in expected $O(1)$ time, repeating this for every $i$ gives a solution that runs in $O(N^2)$ time overall.
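The sweep over $k$ for one fixed $i$ can be sketched as follows. This is only an illustration of the idea: `numRow` is a hypothetical helper name, indices are 0-based, and the function returns a whole row of $num$ rather than writing into a shared table.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Returns row[k] = num[i][k] for a fixed left endpoint i; row[k] = 0 for k <= i.
std::vector<long long> numRow(const std::vector<int>& A, int i) {
    int n = static_cast<int>(A.size());
    assert(0 <= i && i < n);
    std::vector<long long> row(n, 0);
    std::unordered_map<int, int> seen;  // value -> occurrences among A[i+1..k-1]
    for (int k = i + 1; k < n; ++k) {
        // The middle element must equal -A[i] - A[k] for the triple to sum to 0.
        auto it = seen.find(-A[i] - A[k]);
        if (it != seen.end()) row[k] = it->second;
        ++seen[A[k]];  // A[k] becomes a candidate middle element for larger k
    }
    return row;
}
```

Using `find` rather than `operator[]` for the lookup avoids inserting a zero entry for every value that is queried but never present, which keeps the map no larger than the number of distinct values seen.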
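However, due to the high constant factor of hashmaps, this solution does not earn full points. Because all entries of $A$ are in the range $[-10^6,10^6]$, we can replace the hashmap with an array of size $2\cdot 10^6+1$ indexed by value plus $10^6$, greatly improving the runtime. Rather than clearing the whole array for each $i$, the code below undoes the insertions it made.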
```java
import java.io.*;
import java.util.*;

public class threesum {
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader("threesum.in"));
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("threesum.out")));
        String[] line = in.readLine().split(" ");
        int N = Integer.parseInt(line[0]);
        int Q = Integer.parseInt(line[1]);
        line = in.readLine().split(" ");
        int[] A = new int[N];
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; ++i) A[i] = Integer.parseInt(line[i]);
        int[] z = new int[2000001];
        for (int i = N-1; i >= 0; --i) {
            for (int j = i+1; j < N; ++j) {
                int ind = 1000000-A[i]-A[j];
                if (ind >= 0 && ind <= 2000000) ans[i][j] = z[ind];
                z[1000000+A[j]] ++;
            }
            for (int j = i+1; j < N; ++j) {
                z[1000000+A[j]] --;
            }
        }
        for (int i = N-1; i >= 0; --i)
            for (int j = i+1; j < N; ++j)
                ans[i][j] += ans[i+1][j]+ans[i][j-1]-ans[i+1][j-1];
        for (int i = 0; i < Q; ++i) {
            line = in.readLine().split(" ");
            int a = Integer.parseInt(line[0]);
            int b = Integer.parseInt(line[1]);
            out.println(ans[a-1][b-1]);
        }
        out.close();
    }
}
```
Of course, it was still possible to pass without replacing the hashmap by an array. Although this wasn't intended, I'll include two additional solutions for the sake of completeness.
In C++, gp_hash_table is somewhat faster than unordered_map, especially if you specify an initial capacity. See here for more information.
```cpp
#include <bits/stdc++.h>
using namespace std;

void setIO(string name) {
    ios_base::sync_with_stdio(0); cin.tie(0);
    freopen((name+".in").c_str(),"r",stdin);
    freopen((name+".out").c_str(),"w",stdout);
}

#include <ext/pb_ds/assoc_container.hpp> // for gp_hash_table
using namespace __gnu_pbds;

int N,Q;
long long ans[5000][5000];
vector<int> A;

int main() {
    setIO("threesum");
    cin >> N >> Q;
    A.resize(N);
    for (int i = 0; i < N; ++i) cin >> A[i];
    for (int i = 0; i < N; ++i) {
        gp_hash_table<int,int> g({},{},{},{},{1<<13});
        // initialize with capacity that is a power of 2
        for (int j = i+1; j < N; ++j) {
            int res = -A[i]-A[j];
            auto it = g.find(res);
            if (it != end(g)) ans[i][j] = it->second;
            g[A[j]] ++;
        }
    }
    for (int i = N-1; i >= 0; --i)
        for (int j = i+1; j < N; ++j)
            ans[i][j] += ans[i+1][j]+ans[i][j-1]-ans[i+1][j-1];
    for (int i = 0; i < Q; ++i) {
        int a,b; cin >> a >> b;
        cout << ans[a-1][b-1] << "\n";
    }
}
```
In Java, a hashmap solution passes if StreamTokenizer is used to take care of input, although it uses much more memory than I would expect. (If anyone knows how to reduce the memory usage, could you let me know?)
```java
import java.io.*;
import java.util.*;

public class threesum {
    static StreamTokenizer in;

    static int nextInt() throws IOException {
        in.nextToken();
        return (int)in.nval;
    }

    public static void main(String[] args) throws IOException {
        in = new StreamTokenizer(new BufferedReader(new FileReader("threesum.in")));
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("threesum.out")));
        int N = nextInt();
        int Q = nextInt();
        int[] A = new int[N];
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; ++i) A[i] = nextInt();
        Map<Integer,Integer> z = new HashMap<>();
        for (int i = N-1; i >= 0; --i) {
            z.clear();
            for (int j = i+1; j < N; ++j) {
                int ind = -A[i]-A[j];
                ans[i][j] = z.getOrDefault(ind,0);
                z.put(A[j],z.getOrDefault(A[j],0)+1);
            }
        }
        for (int i = N-1; i >= 0; --i)
            for (int j = i+1; j < N; ++j)
                ans[i][j] += ans[i+1][j]+ans[i][j-1]-ans[i+1][j-1];
        for (int i = 0; i < Q; ++i)
            out.println(ans[nextInt()-1][nextInt()-1]);
        out.close();
    }
}
```