First, it suffices to count the number of friendly crossings, then subtract this number from the total number of crossings.
Let $r_a[i]$ denote the position of the $i$th cow in the first line and $r_b[i]$ denote the position of the $i$th cow in the second line.
Let $s$ be the sequence satisfying $s[r_a[i]] = r_b[i]$. Then the number of inversions in $s$ is the total number of crossings, which can be computed in $O(n \log n)$ time (e.g. with a binary indexed tree).
A pair of cows $i, j$ crosses if $(r_a[i] < r_a[j]) \neq (r_b[i] < r_b[j])$.
So we want to count the unordered pairs $\{i, j\}$ such that $|i-j| \le k$ and $(r_a[i] < r_a[j]) \neq (r_b[i] < r_b[j])$.
Equivalently, we can count the ordered pairs $(i, j)$ such that $|i-j| \le k$, $r_a[i] < r_a[j]$, and $r_b[i] > r_b[j]$: every crossing unordered pair corresponds to exactly one such ordered pair.
Let's consider $n$ points where the $i$th point has coordinates $(r_a[i], r_b[i])$.
Let's fix a cow $i$. We want to count the other cows $j \neq i$ with $i-k \le j \le i+k$ whose point lies in the top-left corner of the $i$th point, i.e. $r_a[j] < r_a[i]$ and $r_b[j] > r_b[i]$. Rather than querying an arbitrary range of indices, we can reduce this to prefixes. First count the points $j$ in the top-left corner with $j \le i+k$. Then subtract the number of such points with $j \le i-k-1$.
So, to restate our problem, we have $n$ points, and we want to support two operations:
- Insert a point. ($n$ times)
- Given a point, count the number of points in the top-left of the given point. ($2n$ times)
This suggests a 2D segment tree. We can't build the full tree, since that would require $O(n^2)$ memory. Instead, we can use an implicit tree that only creates nodes when they are needed. This requires space proportional to the running time, which is $O(n \log^2 n)$.
For a slightly faster solution, we can notice that the points are fixed, and there is at most one point per x-coordinate. Let's create a merge sort tree on the array $s$. More specifically, each node in this segment tree will store the sorted values in its own range.
We can change "insert a point" to "turn a point on". To answer a count, each node keeps a binary indexed tree over its sorted values that records how many of them have been turned on.
This implementation only requires $O(n \log n)$ space and $O(n \log^2 n)$ time with a better constant factor.
See my Java code below for more details on the implementation of this approach:
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.StringTokenizer;

/**
 * Reads n, k and two orderings of cows 1..n, then prints the number of crossing pairs of cows
 * whose labels differ by more than k.
 *
 * <p>Built using CHelper plug-in; the actual solution is {@code friendcross} below.
 */
public class Main {
    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        InputReader in = new InputReader(inputStream);
        OutputWriter out = new OutputWriter(outputStream);
        friendcross solver = new friendcross();
        solver.solve(1, in, out);
        out.close();
    }

    /** Solver: answer = (all crossing pairs) - (crossing pairs with |i - j| <= k). */
    static class friendcross {
        /** arr[p] is the cow at position p + 1 on the first side, as read from input. */
        public int[] arr;
        /** brr[p] is the cow at position p + 1 on the second side, as read from input. */
        public int[] brr;
        /**
         * seq[x] is the second-side position of the cow at first-side position x (1-indexed).
         * Static because {@link SegmentTree}'s constructor reads it.
         */
        public static int[] seq;

        /**
         * Reads one test case (n, k, then n cows for each side) and prints the number of crossing
         * pairs whose labels differ by more than k.
         *
         * @param testNumber unused
         */
        public void solve(int testNumber, InputReader in, OutputWriter out) {
            int n = in.nextInt();
            int k = in.nextInt();
            arr = in.readIntArray(n);
            brr = in.readIntArray(n);

            // ra[c] / rb[c]: 1-indexed position of cow c on the first / second side.
            int[] ra = new int[n + 1];
            int[] rb = new int[n + 1];
            for (int i = 0; i < n; i++) {
                ra[arr[i]] = i + 1;
                rb[brr[i]] = i + 1;
            }
            seq = new int[n + 1];
            for (int cow = 1; cow <= n; cow++) {
                seq[ra[cow]] = rb[cow];
            }

            long totalCrossings = countInversions(n);
            long friendlyCrossings = countFriendlyCrossings(n, k, ra, rb);
            out.println(totalCrossings - friendlyCrossings);
        }

        /** Counts inversions of seq[1..n], which equals the total number of crossing pairs. */
        private static long countInversions(int n) {
            long inversions = 0;
            BIT seen = new BIT(n);
            for (int i = 1; i <= n; i++) {
                seen.update(seq[i], +1);
                // i values seen so far, minus those <= seq[i] = earlier values greater than seq[i].
                inversions += i - seen.query(seq[i]);
            }
            return inversions;
        }

        /**
         * Counts crossing pairs (i, j) with |i - j| <= k.
         *
         * <p>Cows are turned on in label order. For cow i the window [i-k, i+k] is handled as a
         * prefix count at time min(n, i+k) minus a prefix count at time i-k-1. Each prefix count
         * is the number of turned-on cows j with ra[j] <= ra[i] and rb[j] > rb[i].
         */
        private static long countFriendlyCrossings(int n, int k, int[] ra, int[] rb) {
            if (n <= 0) {
                return 0; // SegmentTree(1, 0) would recurse forever.
            }
            List<List<Event>> events = new ArrayList<List<Event>>(n + 1);
            for (int i = 0; i <= n; i++) {
                events.add(new ArrayList<Event>());
            }
            for (int i = 1; i <= n; i++) {
                // long arithmetic: i + k can overflow int when k is very large.
                int up = (int) Math.min((long) n, (long) i + k);
                long down = (long) i - k - 1;
                events.get(up).add(new Event(i, +1));
                if (down >= 1) { // a prefix of length 0 contributes nothing
                    events.get((int) down).add(new Event(i, -1));
                }
            }

            SegmentTree root = new SegmentTree(1, n);
            long friendly = 0;
            for (int cow = 1; cow <= n; cow++) {
                root.update(ra[cow], +1);
                for (Event e : events.get(cow)) {
                    friendly += (long) e.sign * root.query(1, ra[e.cow], rb[e.cow]);
                }
            }
            return friendly;
        }

        /** A signed prefix query for {@code cow}, scheduled at the time its prefix is complete. */
        static class Event {
            public int cow;
            public int sign;

            public Event(int cow, int sign) {
                this.cow = cow;
                this.sign = sign;
            }
        }

        /**
         * Merge-sort tree over seq[start..end]. Each node stores its values sorted (1-indexed in
         * {@code arr}) plus a BIT marking which of them are turned on.
         */
        static class SegmentTree {
            /** Sorted values of seq[start..end], 1-indexed; arr[0] is unused. */
            public int[] arr;
            /** pl[i] / pr[i]: index in this node's arr of the left / right child's arr[i]. */
            public int[] pl;
            public int[] pr;
            /** Turned-on flags indexed by position in {@code arr}. */
            public BIT bit;
            public int start;
            public int end;
            public SegmentTree lchild;
            public SegmentTree rchild;

            public SegmentTree(int start, int end) {
                this.start = start;
                this.end = end;
                arr = new int[end - start + 2];
                if (start == end) {
                    arr[1] = seq[start];
                } else {
                    int mid = (start + end) >>> 1;
                    lchild = new SegmentTree(start, mid);
                    rchild = new SegmentTree(mid + 1, end);
                    int[] larr = lchild.arr;
                    int[] rarr = rchild.arr;
                    pl = new int[larr.length];
                    pr = new int[rarr.length];
                    // Standard merge, recording where each child element lands in this node.
                    int lidx = 1;
                    int ridx = 1;
                    int idx = 1;
                    while (lidx < larr.length && ridx < rarr.length) {
                        if (larr[lidx] < rarr[ridx]) {
                            pl[lidx] = idx;
                            arr[idx++] = larr[lidx++];
                        } else {
                            pr[ridx] = idx;
                            arr[idx++] = rarr[ridx++];
                        }
                    }
                    while (lidx < larr.length) {
                        pl[lidx] = idx;
                        arr[idx++] = larr[lidx++];
                    }
                    while (ridx < rarr.length) {
                        pr[ridx] = idx;
                        arr[idx++] = rarr[ridx++];
                    }
                }
                bit = new BIT(end - start + 2);
            }

            /** Counts turned-on values at positions [s, e] that are strictly greater than k. */
            public int query(int s, int e, int k) {
                if (start == s && end == e) {
                    if (k < arr[1]) {
                        return bit.count;
                    }
                    // Largest index lo with arr[lo] <= k (arr[1] <= k holds here).
                    int lo = 1;
                    int hi = arr.length - 1;
                    while (lo < hi) {
                        int mid = (lo + hi + 1) / 2;
                        if (arr[mid] > k) {
                            hi = mid - 1;
                        } else {
                            lo = mid;
                        }
                    }
                    return bit.count - bit.query(lo);
                }
                int mid = (start + end) >>> 1;
                if (mid >= e) {
                    return lchild.query(s, e, k);
                } else if (mid < s) {
                    return rchild.query(s, e, k);
                } else {
                    return lchild.query(s, mid, k) + rchild.query(mid + 1, e, k);
                }
            }

            /**
             * Adds {@code val} to the on-count of position p in every node on the root-to-leaf path.
             *
             * @return index of seq[p] within this node's sorted {@code arr}
             */
            public int update(int p, int val) {
                int pos;
                if (start == p && end == p) {
                    pos = 1;
                } else {
                    int mid = (start + end) >>> 1;
                    pos = (mid >= p) ? pl[lchild.update(p, val)] : pr[rchild.update(p, val)];
                }
                bit.update(pos, val); // previously always +1, silently ignoring val
                return pos;
            }
        }
    }

    /** Whitespace-separated token reader. */
    static class InputReader {
        public BufferedReader reader;
        public StringTokenizer tokenizer;

        public InputReader(InputStream stream) {
            reader = new BufferedReader(new InputStreamReader(stream), 32768);
            tokenizer = null;
        }

        /**
         * Returns the next token.
         *
         * @throws NoSuchElementException if input ends before another token is found
         */
        public String next() {
            while (tokenizer == null || !tokenizer.hasMoreTokens()) {
                String line;
                try {
                    line = reader.readLine();
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
                if (line == null) {
                    throw new NoSuchElementException("unexpected end of input");
                }
                tokenizer = new StringTokenizer(line);
            }
            return tokenizer.nextToken();
        }

        /** Reads {@code tokens} integers into a new array. */
        public int[] readIntArray(int tokens) {
            int[] ret = new int[tokens];
            for (int i = 0; i < tokens; i++) {
                ret[i] = nextInt();
            }
            return ret;
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }
    }

    /** Fenwick tree over indices 1..N that also tracks the sum of all updates in {@code count}. */
    static class BIT {
        private int[] tree;
        private int N;
        public int count;

        public BIT(int N) {
            this.N = N;
            this.tree = new int[N + 1];
            this.count = 0;
        }

        /** Returns the sum over indices 1..K. */
        public int query(int K) {
            int sum = 0;
            for (int i = K; i > 0; i -= (i & -i)) {
                sum += tree[i];
            }
            return sum;
        }

        /** Adds val at index K. */
        public void update(int K, int val) {
            this.count += val;
            for (int i = K; i <= N; i += (i & -i)) {
                tree[i] += val;
            }
        }
    }

    /** Buffered writer wrapper. */
    static class OutputWriter {
        private final PrintWriter writer;

        public OutputWriter(OutputStream outputStream) {
            writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(outputStream)));
        }

        public OutputWriter(Writer writer) {
            this.writer = new PrintWriter(writer);
        }

        public void close() {
            writer.close();
        }

        public void println(long i) {
            writer.println(i);
        }
    }
}