(Analysis by David Hu)

Note that it is optimal for Farmer John to milk his cows such that the cow with the $i$th smallest milk production value spends $i$ minutes on the milking machine. Indeed, if there are two cows $i$ and $j$ such that $a_i > a_j$ but cow $i$ spends fewer minutes than cow $j$ on the milking machine, the total amount of milk Farmer John produces could be increased by swapping the amounts of time cows $i$ and $j$ spend on the milking machine.

So the maximum amount of milk Farmer John can produce is $G(a) = \sum_{i=1}^{n} i \cdot a'_i$, where $a'$ is the array that results upon sorting $a$.

Let's first suppose that $a$ is originally sorted, and let the value of $G(a)$ initially be $S$.

Now let's see what happens when we replace some $a_i$ with some other value $v$. First suppose $v \geq a_i$. Then, in the sorted version of $a$, $v$ will belong in some position $p \geq i$, which we can find by binary search. Furthermore, all numbers originally in positions $i+1, i+2, \dots, p$ will each shift down by one position, to positions $i, i+1, \dots, p-1$. As a result, $G(a)$ will now become $S - i \cdot a_i - \sum_{j=i+1}^{p} a_j + p \cdot v$. We can use prefix sums to compute $\sum_{j=i+1}^{p} a_j$ in $O(1)$ per query.

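The case when $v < a_i$ is similar: now $v$ belongs in some position $p \leq i$, and all numbers originally in positions $p, p+1, \dots, i-1$ shift up by one position, to positions $p+1, p+2, \dots, i$. Hence $G(a)$ becomes $S - i \cdot a_i + \sum_{j=p}^{i-1} a_j + p \cdot v$.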

Now we must handle what happens when $a$ is not originally sorted. If we figure out, for all $i$, the position $p_i$ that $a_i$ would occupy in the sorted version of $a$, then we can simply sort $a$ (getting an array $a'$) and view every query changing $a_i$ to $v$ as a query changing $a'_{p_i}$ to $v$. There are a number of ways to find $p$: one way is to sort a list $c$ of the numbers from $1$ to $N$ by the value of $a_i$; then, if $c_j$ is the $j$th number in the list, $p_{c_j} = j$.

We must also remember to compute $S$ and the prefix sums using $a'$.

Overall time complexity is $O((N + Q) \log N)$ due to sorting and binary search.

My C++ code is below. Using the built-in C++ `lower_bound` function greatly simplifies the implementation.

#include <bits/stdc++.h>

using namespace std;

using ll = long long;

// Precomputed data for answering "set the value at index idx to val; what is
// the new maximum milk total?" queries in O(log N) each.
//
// The optimal total for an array a is G(a) = sum_{k=1..N} k * a'_k, where a'
// is a sorted ascending. This struct stores a' together with its prefix sums,
// and never modifies them, so queries are independent of each other.
struct MilkQueries
{
    std::vector<ll> sorted; // a', the input values sorted ascending
    std::vector<int> pos;   // pos[i] = index of the original a_i inside `sorted`
    std::vector<ll> pref;   // pref[k] = sorted[0] + ... + sorted[k - 1]
    ll tot = 0;             // G(a) for the unmodified array

    explicit MilkQueries(const std::vector<ll>& a)
        : sorted(a.size()), pos(a.size()), pref(a.size() + 1, 0)
    {
        const int n = static_cast<int>(a.size());
        // order[k] = original index of the k-th smallest value.
        std::vector<int> order(n);
        std::iota(order.begin(), order.end(), 0);
        std::sort(order.begin(), order.end(),
                  [&a](int i, int j) { return a[i] < a[j]; });
        for (int k = 0; k < n; k++)
        {
            pos[order[k]] = k;
            sorted[k] = a[order[k]];
        }
        for (int k = 0; k < n; k++)
        {
            pref[k + 1] = pref[k] + sorted[k];
            tot += static_cast<ll>(k + 1) * sorted[k];
        }
    }

    // Returns G of the array obtained by replacing the value at original
    // 0-based index `idx` with `val`. The stored data is left unchanged.
    ll query(int idx, ll val) const
    {
        const int cur = pos[idx];
        const ll old = sorted[cur];
        // Sorted position of `val` among the other N - 1 values: the number of
        // stored values < val, minus one if `old` itself was counted there.
        const int lb = static_cast<int>(
            std::lower_bound(sorted.begin(), sorted.end(), val) - sorted.begin());
        const int dest = lb - (val > old ? 1 : 0);

        ll ans = tot - static_cast<ll>(cur + 1) * old;
        if (dest >= cur)
        {
            // Values in positions cur+1..dest each move one slot down.
            ans -= pref[dest + 1] - pref[cur + 1];
        }
        else
        {
            // Values in positions dest..cur-1 each move one slot up.
            ans += pref[cur] - pref[dest];
        }
        ans += static_cast<ll>(dest + 1) * val;
        return ans;
    }
};

// Reads N values, then Q queries "idx val" (idx is 1-based), and prints the
// optimal milk total after each hypothetical single-value replacement.
int main()
{
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    std::cin >> n;
    // Sized from the input, so there is no fixed capacity to overflow.
    std::vector<ll> a(n);
    for (ll& x : a)
    {
        std::cin >> x;
    }
    const MilkQueries solver(a);

    int q = 0;
    std::cin >> q;
    while (q--)
    {
        int idx = 0;
        ll val = 0;
        std::cin >> idx >> val;
        std::cout << solver.query(idx - 1, val) << '\n';
    }
    return 0;
}


My Python Code:

# Read the cows' milk production values and record, for every cow, the
# index its value occupies once the list is sorted ascending.
N = int(input())
arr = [int(tok) for tok in input().split()]
by_value = sorted(range(N), key=lambda cow: arr[cow])
pos = [0] * N
for rank, cow in enumerate(by_value):
    pos[cow] = rank
arr.sort()

def binary_search(x):
    """Return how many of the first ``N`` entries of the module-level ``arr``
    are strictly less than ``x``; equivalently, the smallest index ``i`` with
    ``arr[i] >= x``, or ``N`` if there is none.

    ``arr`` must already be sorted ascending.
    """
    left, right = 0, N
    while left < right:
        middle = (left + right) // 2
        if arr[middle] < x:
            left = middle + 1
        else:
            right = middle
    return left

# Prefix sums over the sorted values, and the optimal total S = sum (k+1)*arr[k].
pref = [0] * (N + 1)
tot = 0
for k in range(N):
    pref[k + 1] = pref[k] + arr[k]
    tot += (k + 1) * arr[k]

# Each query replaces one cow's value and prints the new optimal total.
# arr, pref and tot are never modified here, so queries are independent.
Q = int(input())
for _ in range(Q):
    cow, val = map(int, input().split())
    here = pos[cow - 1]  # sorted index of the cow being changed
    old = arr[here]
    # Sorted index of val among the remaining values.
    there = binary_search(val) - (1 if val > old else 0)
    ans = tot - (here + 1) * old
    if there < here:
        ans += pref[here] - pref[there]
    else:
        ans -= pref[there + 1] - pref[here + 1]
    print(ans + (there + 1) * val)


Slightly shorter if Python's `bisect` module is used:

import bisect

# Read the input, then compute pos[i] = index at which arr[i] lands once
# arr is sorted ascending.
N = int(input())
arr = list(map(int, input().split()))
order = list(range(N))
order.sort(key=lambda i: arr[i])
pos = [0] * N
for rank in range(N):
    pos[order[rank]] = rank
arr.sort()

# pref[k] = arr[0] + ... + arr[k-1]; tot is the optimal total for the sorted arr.
pref = [0] * (N + 1)
for k in range(N):
    pref[k + 1] = pref[k] + arr[k]
tot = sum((k + 1) * arr[k] for k in range(N))

# Answer each "replace value of cow with val" query independently.
Q = int(input())
for _ in range(Q):
    cow, val = map(int, input().split())
    at = pos[cow - 1]
    old = arr[at]
    dest = bisect.bisect_left(arr, val)
    if val > old:
        dest -= 1  # old itself was counted among the values < val
    if dest >= at:
        shifted = -(pref[dest + 1] - pref[at + 1])
    else:
        shifted = pref[at] - pref[dest]
    print(tot - (at + 1) * old + shifted + (dest + 1) * val)


Danny Mittal's Java code:

import java.io.BufferedReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.StringTokenizer;
import java.util.TreeMap;

public class ArrayQueriesSilver {

public static void main(String[] args) throws IOException {
Long[] sorted = xs.clone();
Arrays.sort(sorted);
long base = 0;
long[] sums = new long[n + 1];
TreeMap<Long, Integer> treeMap = new TreeMap<>();
for (int j = 0; j < n; j++) {
sums[j + 1] = sums[j] + sorted[j];
base += ((long) (j + 1)) * sorted[j];
treeMap.put(sorted[j], j);
}
treeMap.put(Long.MIN_VALUE, -1);
StringBuilder out = new StringBuilder();
for (int q =  Integer.parseInt(in.readLine()); q > 0; q--) {
int j = Integer.parseInt(tokenizer.nextToken()) - 1;
long prev = xs[j];
long next = Long.parseLong(tokenizer.nextToken());
int prevIndex = treeMap.get(prev);
int nextIndex = treeMap.lowerEntry(next).getValue() + 1;