(Analysis by Lewin Gan)

First, it suffices to count the friendly crossings (crossing pairs of cows $i,j$ with $|i-j| \le k$) and subtract this number from the total number of crossings.

Let $r_a[i]$ denote the position of the $i$th cow in the first line and $r_b[i]$ denote the position of the $i$th cow in the second line.

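Let $s$ be the sequence satisfying $s[r_a[i]] = r_b[i]$, so $s[x]$ is the position in the second line of the cow at position $x$ in the first line. Two cows cross exactly when their values in $s$ appear in decreasing order. So the number of inversions in $s$ is the total number of crossings, which can be computed in $O(n \log n)$ time with a binary indexed tree.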
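
A minimal sketch of the total-crossings step follows. It assumes each line is given as an array of cow labels $1..n$ indexed by position. The names `TotalCrossings` and `totalCrossings` are hypothetical and do not come from the original solution. The sketch builds $r_a$, $r_b$ and $s$ as defined above. It then counts inversions with a Fenwick tree: for each position $x$, it adds the number of earlier values greater than $s[x]$.

```java
class TotalCrossings {
    // first[p], second[p]: label (1..n) of the cow at position p (0-indexed) in each line.
    static long totalCrossings(int[] first, int[] second) {
        int n = first.length;
        int[] ra = new int[n + 1];
        int[] rb = new int[n + 1];
        for (int p = 0; p < n; p++) {
            ra[first[p]] = p + 1;   // 1-indexed position in the first line
            rb[second[p]] = p + 1;  // 1-indexed position in the second line
        }
        int[] s = new int[n + 1];
        for (int c = 1; c <= n; c++) s[ra[c]] = rb[c];

        int[] tree = new int[n + 1];  // Fenwick tree over values 1..n
        long inversions = 0;
        for (int x = 1; x <= n; x++) {
            int atMost = 0;  // earlier values <= s[x]
            for (int v = s[x]; v > 0; v -= v & -v) atMost += tree[v];
            inversions += (x - 1) - atMost;  // earlier values > s[x]
            for (int v = s[x]; v <= n; v += v & -v) tree[v]++;
        }
        return inversions;
    }

    public static void main(String[] args) {
        // cows 1 and 3, and cows 2 and 3, swap order between the lines
        System.out.println(totalCrossings(new int[]{1, 2, 3}, new int[]{3, 1, 2})); // prints 2
    }
}
```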

A pair of cows $i,j$ cross if $(r_a[i] < r_a[j])$ is different from $(r_b[i] < r_b[j])$, that is, if their relative order differs between the two lines.

So, we want to count the number of unordered pairs such that $|i-j| \le k$ and $(r_a[i] < r_a[j]) \neq (r_b[i] < r_b[j])$.

Equivalently, we can count the ordered pairs $(i,j)$ such that $|i-j| \le k$, $r_a[j] < r_a[i]$ and $r_b[j] > r_b[i]$. Each crossing pair is counted exactly once, with $j$ being the cow that comes first in the first line.
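
A brute-force reference is handy for testing faster solutions. The sketch below counts these ordered pairs directly in $O(nk)$ time. `FriendlyBrute.friendlyCrossings` is a hypothetical name, not part of the original solution. It assumes 1-indexed arrays `ra` and `rb` holding $r_a$ and $r_b$, with index 0 unused.

```java
class FriendlyBrute {
    // Counts ordered pairs (i, j) with |i - j| <= k, ra[j] < ra[i] and rb[j] > rb[i].
    // ra, rb are 1-indexed; index 0 is unused.
    static long friendlyCrossings(int[] ra, int[] rb, int k) {
        int n = ra.length - 1;
        long count = 0;
        for (int i = 1; i <= n; i++) {
            for (int j = Math.max(1, i - k); j <= Math.min(n, i + k); j++) {
                // j == i never passes, since ra[i] < ra[i] is false
                if (ra[j] < ra[i] && rb[j] > rb[i]) count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        // first line: cows 1 2 3; second line: cows 3 1 2
        int[] ra = {0, 1, 2, 3};
        int[] rb = {0, 2, 3, 1};
        // crossing pairs are {1,3} and {2,3}; only {2,3} has |i - j| <= 1
        System.out.println(friendlyCrossings(ra, rb, 1)); // prints 1
    }
}
```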

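Let's consider $n$ points where the $i$th point has coordinates $(r_a[i], r_b[i])$. The $j$th point lies in the top-left corner of the $i$th point exactly when $r_a[j] < r_a[i]$ and $r_b[j] > r_b[i]$.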

Let's fix a cow $i$. We want to count the other cows $j$ with $i-k \le j \le i+k$ whose point lies in the top-left corner of the $i$th point. The condition $j \neq i$ holds automatically, since no point lies strictly in its own top-left corner. Rather than handle an arbitrary range of $j$, we can turn this into a difference of two prefixes. First count the points $j$ in the top-left corner with $j \le i+k$, then subtract the points $j$ in the top-left corner with $j \le i-k-1$.
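
A naive sketch can check this prefix identity. The hypothetical helper `topLeftPrefix(ra, rb, i, m)` scans cows $1..m$ and counts those in the top-left corner of cow $i$. The answer is then the sum over $i$ of the two prefix counts, clamped to $[0, n]$. The sketch runs in $O(n^2)$ and only illustrates the decomposition. It assumes the same 1-indexed `ra`/`rb` arrays as before.

```java
class PrefixSplit {
    // T(i, m): number of cows j in 1..m whose point is strictly top-left of cow i's point.
    static int topLeftPrefix(int[] ra, int[] rb, int i, int m) {
        int count = 0;
        for (int j = 1; j <= m; j++) {
            if (ra[j] < ra[i] && rb[j] > rb[i]) count++;
        }
        return count;
    }

    // Friendly crossings as a sum of prefix differences (naive O(n^2) illustration).
    static long friendlyCrossings(int[] ra, int[] rb, int k) {
        int n = ra.length - 1;
        long total = 0;
        for (int i = 1; i <= n; i++) {
            int up = Math.min(n, i + k);        // prefix 1..i+k, clamped to n
            int down = Math.max(0, i - k - 1);  // prefix 1..i-k-1, clamped to 0
            total += topLeftPrefix(ra, rb, i, up) - topLeftPrefix(ra, rb, i, down);
        }
        return total;
    }

    public static void main(String[] args) {
        int[] ra = {0, 1, 2, 3};
        int[] rb = {0, 2, 3, 1};
        System.out.println(friendlyCrossings(ra, rb, 1)); // prints 1
    }
}
```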

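So, to restate our problem: we insert the $n$ points in increasing order of cow index, and answer each prefix count right after the last cow of its prefix has been inserted. We then need to support two operations: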

- Insert a point. ($n$ times)

- Given a point, count the number of points in the top-left of the given point. ($2n$ times)
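
The offline schedule can be sketched as follows, assuming the same 1-indexed `ra`/`rb` arrays. Each cow $i$ files a $+1$ query under prefix $\min(n, i+k)$ and a $-1$ query under prefix $\max(0, i-k-1)$. After inserting cow $m$, the sweep answers every query filed under $m$. The counting structure is a deliberately naive list, `NaiveCounter` (a hypothetical name), standing in for a real 2D data structure. Only the event order matters here.

```java
import java.util.ArrayList;
import java.util.List;

class OfflineSweep {
    // Naive stand-in for the 2D structure: just a list of inserted points.
    static final class NaiveCounter {
        private final List<int[]> pts = new ArrayList<>();

        void insert(int x, int y) {
            pts.add(new int[]{x, y});
        }

        // Points with x' < x and y' > y (strict top-left of (x, y)).
        int countTopLeft(int x, int y) {
            int c = 0;
            for (int[] p : pts) if (p[0] < x && p[1] > y) c++;
            return c;
        }
    }

    static long friendlyCrossings(int[] ra, int[] rb, int k) {
        int n = ra.length - 1;
        // queries.get(m): signed queries {cow, sign} answered once cows 1..m are inserted
        List<List<int[]>> queries = new ArrayList<>();
        for (int m = 0; m <= n; m++) queries.add(new ArrayList<>());
        for (int i = 1; i <= n; i++) {
            queries.get(Math.min(n, i + k)).add(new int[]{i, +1});
            queries.get(Math.max(0, i - k - 1)).add(new int[]{i, -1});
        }
        NaiveCounter counter = new NaiveCounter();
        long total = 0;
        // prefix 0 is empty, so queries filed under m = 0 contribute nothing
        for (int m = 1; m <= n; m++) {
            counter.insert(ra[m], rb[m]);
            for (int[] q : queries.get(m)) {
                total += q[1] * (long) counter.countTopLeft(ra[q[0]], rb[q[0]]);
            }
        }
        return total;
    }

    public static void main(String[] args) {
        int[] ra = {0, 1, 2, 3};
        int[] rb = {0, 2, 3, 1};
        System.out.println(friendlyCrossings(ra, rb, 1)); // prints 1
    }
}
```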

This suggests a 2D segment tree, which handles each operation in $O(\log^2 n)$ time. We can't build the full tree, since that would require $O(n^2)$ memory. Instead, we can keep the tree implicit and only create nodes when an insertion first reaches them. Each insertion creates at most $O(\log^2 n)$ nodes, so the space is proportional to the running time, which is $O(n \log^2 n)$.
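
Here is a minimal sketch of such a structure. It makes its own design choices, which are not necessarily the author's. The outer tree over $x$ is an ordinary array-indexed segment tree with $4n$ root slots. Only the inner trees over $y$ are created lazily, and node $0$ serves as a shared empty node. Class and method names are hypothetical. `countTopLeft(x, y)` counts inserted points with $x' < x$ and $y' > y$.

```java
import java.util.Arrays;

class Implicit2DSegTree {
    private final int n;
    private final int[] innerRoot;  // outer node id -> root of its inner y-tree (0 = empty)
    private int[] left = new int[16], right = new int[16], cnt = new int[16];
    private int size = 1;           // node 0 is the shared empty node

    Implicit2DSegTree(int n) {
        this.n = n;
        this.innerRoot = new int[4 * n];
    }

    private int newNode() {
        if (size == cnt.length) {  // grow the node pool
            left = Arrays.copyOf(left, 2 * size);
            right = Arrays.copyOf(right, 2 * size);
            cnt = Arrays.copyOf(cnt, 2 * size);
        }
        return size++;
    }

    // Add one point at y to the inner tree rooted at node; returns the (possibly new) root.
    private int innerAdd(int node, int lo, int hi, int y) {
        if (node == 0) node = newNode();
        cnt[node]++;
        if (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (y <= mid) {
                int child = innerAdd(left[node], lo, mid, y);
                left[node] = child;  // assign after the call: the arrays may have grown
            } else {
                int child = innerAdd(right[node], mid + 1, hi, y);
                right[node] = child;
            }
        }
        return node;
    }

    private int innerCount(int node, int lo, int hi, int qlo, int qhi) {
        if (node == 0 || qhi < lo || hi < qlo) return 0;
        if (qlo <= lo && hi <= qhi) return cnt[node];
        int mid = (lo + hi) >>> 1;
        return innerCount(left[node], lo, mid, qlo, qhi)
                + innerCount(right[node], mid + 1, hi, qlo, qhi);
    }

    // Insert point (x, y) with 1 <= x, y <= n.
    void insert(int x, int y) {
        outerAdd(1, 1, n, x, y);
    }

    private void outerAdd(int id, int lo, int hi, int x, int y) {
        int root = innerAdd(innerRoot[id], 1, n, y);
        innerRoot[id] = root;
        if (lo == hi) return;
        int mid = (lo + hi) >>> 1;
        if (x <= mid) outerAdd(2 * id, lo, mid, x, y);
        else outerAdd(2 * id + 1, mid + 1, hi, x, y);
    }

    // Points with x' < x and y' > y.
    int countTopLeft(int x, int y) {
        if (x <= 1 || y >= n) return 0;
        return outerCount(1, 1, n, x - 1, y + 1);
    }

    // Points with x' in [1, xhi] and y' in [ylo, n].
    private int outerCount(int id, int lo, int hi, int xhi, int ylo) {
        if (lo > xhi) return 0;
        if (hi <= xhi) return innerCount(innerRoot[id], 1, n, ylo, n);
        int mid = (lo + hi) >>> 1;
        return outerCount(2 * id, lo, mid, xhi, ylo)
                + outerCount(2 * id + 1, mid + 1, hi, xhi, ylo);
    }

    public static void main(String[] args) {
        Implicit2DSegTree t = new Implicit2DSegTree(3);
        t.insert(1, 2);
        t.insert(2, 3);
        t.insert(3, 1);
        System.out.println(t.countTopLeft(3, 1)); // prints 2: points (1,2) and (2,3)
    }
}
```

Child indices are written back only after the recursive call returns. During that call, `Arrays.copyOf` may replace the node arrays, and Java evaluates the array reference on the left of an assignment before the right-hand side.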

For a slightly faster solution, notice that all the points are known in advance and that there is exactly one point per x-coordinate, since the values $r_a[i]$ form a permutation. Let's build a merge sort tree on the array $s$. This is a segment tree over x-coordinates in which each node stores the sorted values of $s$ in its own range.
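
A sketch of the merge sort tree construction follows. It assumes $s$ is stored in `s[1..n]` with $n \ge 1$, and that nodes use heap numbering (root 1, children $2v$ and $2v+1$). `MergeSortTree` is a hypothetical name. Each node merges its children's sorted lists. Every value therefore appears once per level, for $O(n \log n)$ memory in total.

```java
import java.util.Arrays;

class MergeSortTree {
    final int n;
    final int[][] sorted;  // sorted[v]: values of s in node v's range, increasing

    MergeSortTree(int[] s) {  // s[1..n] holds the sequence (n >= 1); s[0] is unused
        n = s.length - 1;
        sorted = new int[4 * n][];
        build(1, 1, n, s);
    }

    private void build(int v, int lo, int hi, int[] s) {
        if (lo == hi) {
            sorted[v] = new int[]{s[lo]};
            return;
        }
        int mid = (lo + hi) >>> 1;
        build(2 * v, lo, mid, s);
        build(2 * v + 1, mid + 1, hi, s);
        // merge the two children's sorted lists
        int[] a = sorted[2 * v], b = sorted[2 * v + 1];
        int[] merged = new int[a.length + b.length];
        int i = 0, j = 0, t = 0;
        while (i < a.length && j < b.length) merged[t++] = a[i] < b[j] ? a[i++] : b[j++];
        while (i < a.length) merged[t++] = a[i++];
        while (j < b.length) merged[t++] = b[j++];
        sorted[v] = merged;
    }

    public static void main(String[] args) {
        MergeSortTree t = new MergeSortTree(new int[]{0, 4, 1, 3, 2});
        System.out.println(Arrays.toString(t.sorted[1])); // prints [1, 2, 3, 4]
        System.out.println(Arrays.toString(t.sorted[2])); // prints [1, 4]
    }
}
```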

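We can change "insert a point" to "turn a point on". Every node whose range contains the point's x-coordinate marks that point's position in its sorted list. Each node keeps a binary indexed tree over its sorted list. A query splits the x-range into $O(\log n)$ nodes and binary searches each node's list for the y-threshold. It then asks that node's binary indexed tree how many turned-on values lie above the threshold.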
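
The "turn a point on" idea can be shown on a single node, assuming its values are distinct and sorted. That holds here, since $s$ is a permutation. The author's code passes precomputed child-to-parent index maps (`pl`, `pr`) back up the tree during an update. This hypothetical `NodeWithBit` instead finds a value's position with `Arrays.binarySearch`. `countAbove(k)` returns how many turned-on values exceed $k$.

```java
import java.util.Arrays;

class NodeWithBit {
    private final int[] vals;  // this node's values: distinct, sorted increasingly
    private final int[] tree;  // Fenwick tree over positions 1..vals.length (1 = on)
    private int on = 0;        // number of values currently turned on

    NodeWithBit(int[] sortedVals) {
        vals = sortedVals;
        tree = new int[vals.length + 1];
    }

    // Turn on value v; assumes v is present in vals and currently off.
    void turnOn(int v) {
        int pos = Arrays.binarySearch(vals, v) + 1;  // 1-based position of v
        for (int i = pos; i < tree.length; i += i & -i) tree[i]++;
        on++;
    }

    // Number of turned-on values strictly greater than k.
    int countAbove(int k) {
        int r = Arrays.binarySearch(vals, k);
        int atMostK = r >= 0 ? r + 1 : -(r + 1);  // how many of vals are <= k
        int onAtMostK = 0;
        for (int i = atMostK; i > 0; i -= i & -i) onAtMostK += tree[i];
        return on - onAtMostK;
    }

    public static void main(String[] args) {
        NodeWithBit node = new NodeWithBit(new int[]{1, 2, 3, 4});
        node.turnOn(3);
        node.turnOn(1);
        System.out.println(node.countAbove(2)); // prints 1 (only 3)
    }
}
```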

This implementation only requires $O(n \log n)$ space. It runs in $O(n \log^2 n)$ time with a better constant factor than the implicit 2D segment tree.

See my Java code for more details on the implementation of this approach:

```java
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
import java.util.ArrayList;
import java.util.StringTokenizer;

/**
 * Built using CHelper plug-in
 * Actual solution is at the top
 */
public class Main {
    public static void main(String[] args) {
        InputStream inputStream = System.in;
        OutputStream outputStream = System.out;
        InputReader in = new InputReader(inputStream);
        OutputWriter out = new OutputWriter(outputStream);
        friendcross solver = new friendcross();
        solver.solve(1, in, out);
        out.close(); // flush buffered output
    }

    static class friendcross {
        public int[] arr;
        public int[] brr;
        // seq[x] = position in the second line of the cow at position x in the first line
        public static int[] seq;

        @SuppressWarnings("unchecked")
        public void solve(int testNumber, InputReader in, OutputWriter out) {
            int n = in.nextInt(), k = in.nextInt();
            arr = in.readIntArray(n);
            brr = in.readIntArray(n);
            // ra[c], rb[c]: 1-indexed positions of cow c in the two lines
            int[] ra = new int[n + 1];
            int[] rb = new int[n + 1];
            for (int i = 0; i < n; i++) {
                ra[arr[i]] = i + 1;
                rb[brr[i]] = i + 1;
            }

            seq = new int[n + 1];
            for (int i = 1; i <= n; i++) {
                seq[ra[i]] = rb[i];
            }
            friendcross.SegmentTree root = new friendcross.SegmentTree(1, n);

            // events[m]: signed prefix queries answered once cows 1..m are turned on
            ArrayList<friendcross.Event>[] events = new ArrayList[n + 1];
            for (int i = 0; i <= n; i++) events[i] = new ArrayList<>();
            for (int i = 1; i <= n; i++) {
                int up = Math.min(n, i + k);
                int down = Math.max(0, i - k - 1);
                events[up].add(new friendcross.Event(i, +1));
                events[down].add(new friendcross.Event(i, -1));
            }

            // total crossings = number of inversions in seq
            long tinv = 0;
            BIT x = new BIT(n);
            for (int i = 1; i <= n; i++) {
                x.update(seq[i], +1);
                tinv += i - x.query(seq[i]);
            }

            // friendly crossings: turned-on points in the top-left corner of each cow
            long res = 0;
            for (int cow = 1; cow <= n; cow++) {
                root.update(ra[cow], +1);
                for (friendcross.Event e : events[cow])
                    res += e.sign * root.query(1, ra[e.cow], rb[e.cow]);
            }
            out.println(tinv - res);
        }

        static class Event {
            public int cow;
            public int sign;

            public Event(int cow, int sign) {
                this.cow = cow;
                this.sign = sign;
            }
        }

        // Merge sort tree over x (position in the first line). Each node keeps its
        // y-values sorted in arr[1..] plus a BIT marking which of them are turned on.
        static class SegmentTree {
            public int[] arr;
            public int[] pl; // pl[i]: index in arr of lchild.arr[i]
            public int[] pr; // pr[i]: index in arr of rchild.arr[i]
            public BIT bit;
            public int start;
            public int end;
            public friendcross.SegmentTree lchild;
            public friendcross.SegmentTree rchild;

            public SegmentTree(int start, int end) {
                this.start = start;
                this.end = end;
                arr = new int[end - start + 2];
                if (start == end) {
                    arr[1] = seq[start];
                } else {
                    int mid = (start + end) >> 1;
                    lchild = new friendcross.SegmentTree(start, mid);
                    rchild = new friendcross.SegmentTree(mid + 1, end);
                    pl = new int[lchild.arr.length];
                    pr = new int[rchild.arr.length];
                    int lidx = 1, ridx = 1;

                    int idx = 1;
                    int[] larr = lchild.arr, rarr = rchild.arr;
                    while (lidx < larr.length && ridx < rarr.length) {
                        if (larr[lidx] < rarr[ridx]) {
                            pl[lidx] = idx;
                            arr[idx++] = larr[lidx++];
                        } else {
                            pr[ridx] = idx;
                            arr[idx++] = rarr[ridx++];
                        }
                    }
                    while (lidx < larr.length) {
                        pl[lidx] = idx;
                        arr[idx++] = larr[lidx++];
                    }
                    while (ridx < rarr.length) {
                        pr[ridx] = idx;
                        arr[idx++] = rarr[ridx++];
                    }
                }
                bit = new BIT(end - start + 2);
            }

            // Count turned-on points with x in [s, e] and y > k.
            public int query(int s, int e, int k) {
                if (start == s && end == e) {
                    if (k < arr[1]) return bit.count;
                    // find lo = last index with arr[lo] <= k
                    int lo = 1, hi = arr.length - 1;
                    while (lo < hi) {
                        int mid = (lo + hi + 1) / 2;
                        if (arr[mid] > k) hi = mid - 1;
                        else lo = mid;
                    }
                    return bit.count - bit.query(lo);
                }
                int mid = (start + end) >> 1;
                if (mid >= e) return lchild.query(s, e, k);
                else if (mid < s) return rchild.query(s, e, k);
                else return lchild.query(s, mid, k) + rchild.query(mid + 1, e, k);
            }

            // Turn on the point at x = p; returns its index in this node's arr.
            public int update(int p, int val) {
                if (start == p && end == p) {
                    bit.update(1, +1);
                    return 1;
                }
                int mid = (start + end) >> 1;
                int apos = -1;
                if (mid >= p) apos = pl[lchild.update(p, val)];
                else apos = pr[rchild.update(p, val)];
                bit.update(apos, +1);
                return apos;
            }
        }
    }

    static class InputReader {
        public BufferedReader reader;
        public StringTokenizer tokenizer;

        public InputReader(InputStream stream) {
            reader = new BufferedReader(new InputStreamReader(stream), 32768);
            tokenizer = null;
        }

        public String next() {
            while (tokenizer == null || !tokenizer.hasMoreTokens()) {
                try {
                    tokenizer = new StringTokenizer(reader.readLine());
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }
            return tokenizer.nextToken();
        }

        public int[] readIntArray(int tokens) {
            int[] ret = new int[tokens];
            for (int i = 0; i < tokens; i++) {
                ret[i] = nextInt();
            }
            return ret;
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }
    }

    static class BIT {
        private int[] tree;
        private int N;
        public int count; // total weight added so far

        public BIT(int N) {
            this.N = N;
            this.tree = new int[N + 1];
            this.count = 0;
        }

        // Sum of weights at indices 1..K.
        public int query(int K) {
            int sum = 0;
            for (int i = K; i > 0; i -= (i & -i))
                sum += tree[i];
            return sum;
        }

        public void update(int K, int val) {
            this.count += val;
            for (int i = K; i <= N; i += (i & -i))
                tree[i] += val;
        }
    }

    static class OutputWriter {
        private final PrintWriter writer;

        public OutputWriter(OutputStream outputStream) {
            writer = new PrintWriter(new BufferedWriter(new OutputStreamWriter(outputStream)));
        }

        public OutputWriter(Writer writer) {
            this.writer = new PrintWriter(writer);
        }

        public void close() {
            writer.close();
        }

        public void println(long i) {
            writer.println(i);
        }
    }
}
```
