(Analysis by Spencer Compton)

First, we make an observation about what the shortest distance to some $K$-tuple is. The distance to the $K$-tuple $(j_1, j_2, \dots, j_K)$ is the smallest $d$ such that in each graph $G_i$ there is a walk of exactly $d$ steps from vertex $1$ that ends at $j_i$. To use this, we exploit more structure about what lengths of walks can end at a vertex. If there is a walk of length $x$ in $G_i$ that ends at $j_i$, then we know there are also walks of length $x+2,x+4,\dots$ because one can repeatedly take one step away from $j_i$ and one step back towards $j_i$.

Moreover, if $even_i[v]$ denotes the length of the shortest path of even length in $G_i$ that ends at $v$, and $odd_i[v]$ for odd shortest path respectively, then valid distances for $v$ are the union of $\{even_i[v],even_i[v]+2,even_i[v]+4,\dots\}$ and $\{odd_i[v],odd_i[v]+2,odd_i[v]+4,\dots\}$. (If there is no even path or odd path, ignore that corresponding set.) We can compute the quantities $even_i[v]$ and $odd_i[v]$ by creating a duplicate graph (one representing odd and one representing even with edges going between the copies) and using a BFS.

Now, we want to use this structure to compute the sum of all distances. Core to this question is knowing, for a given $K$-tuple, what the minimum possible distance is. If we decide this distance will be even (and there exists an even path for each node in the tuple), then the distance is simply the maximum $even_i[v]$ in the tuple. More concretely, we denote the sum of the distances of these tuples as $compute\_sum(L_{even})$, where $L_{even}$ is a list of pairs, each holding a node's $even_i[v]$ and the index of its graph (a node is included only if it has an even path). Then, $compute\_sum$ calculates the sum, over all valid tuples, of the maximum value. (The analogous statement holds if we decide this distance will be odd.)

But making such a decision for a tuple to be even or odd is difficult. Consider, instead, that we immediately calculate $compute\_sum(L_{even}) + compute\_sum(L_{odd})$. If an entire tuple could use even paths or odd paths, then we overcounted the answer for that tuple by exactly the larger quantity (e.g. if odd was the worse decision for that tuple, then the maximum $odd_i[v]$ in that tuple). Finally, we can correct this by subtracting $compute\_sum(L_{max})$, where $L_{max}$ denotes a list with $\max(even_i[v], odd_i[v])$ for corresponding nodes that have even paths and odd paths.

All that remains is how to calculate $compute\_sum(L)$ for some list $L$. One way is to compute $less[x]$, the number of tuples whose maximum is $\le x$; the answer is then $\sum_x (less[x]-less[x-1])\times x$. To calculate $less[x]$, we note that it is equal to the product of $cnt_i[x]$ over all $i$, where $cnt_i[x]$ represents the number of nodes from graph $i$ whose corresponding value is $\le x$. Directly computing this would be too slow, but we can optimize.

For each graph $G_i$, we compute $cnt_i[x]$ for all $x < 2 N_i$. For all larger $x$, $cnt_i[x]=cnt_i[2 N_i -1]$. To hold this, we maintain a suffix product array (i.e. similar to a prefix sum array, but for suffixes and multiplication instead) $suffix\_prod$ and modify it such that all elements in the suffix starting with $2 N_i$ will be multiplied by $cnt_i[2 N_i-1]$. For the $2 N_i$ values of $cnt_i$, we can similarly have an array $prefix\_prod$ and multiply $prefix\_prod[x]$ by $cnt_i[x]$. Then, we can finally compute $less[x]$ as $prefix\_prod[x] \times suffix\_prod[x]$. This runs in linear time, so in total our algorithm runs in $O(\sum N_i + \sum M_i)$ time.

It is also possible to calculate $compute\_sum(L)$ using a segment tree or modular inverses.

Spencer's code:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Modulus applied to every product and to the final answer (1e9 + 7).
// Written as an exact integer literal and made const so it cannot be
// reassigned by accident.
const long long mod = 1000000007LL;

// Number of graphs; set once by main() from the first input token.
int k;
// n[i] is the vertex count of graph i, filled in by main(). The fixed
// capacity of 50000 is the largest number of graphs this file can hold.
int n[50000];
// Sentinel stored in BFS distance arrays for states that were never reached.
const int inf = 100000000;
// Returns, modulo `mod`, the sum of max(tuple) over every tuple that picks
// exactly one entry of `li` from each of the k graphs.
//
// li: (value, graph index) pairs. Each graph index must lie in [0, k) and
//     each value in [0, 2*n[graph index]); neither is checked here.
// Returns a value in [0, mod).
//
// Method: let less[x] be the number of tuples whose maximum is <= x. It is
// the product over graphs g of cnt_g[x], the number of entries of graph g
// with value <= x. For x < 2*n[g] that factor is folded into prefix_prod[x];
// for x >= 2*n[g] the factor is constant (the graph's total entry count) and
// is folded into suffix_prod, which is turned into a running product below.
// The answer is then sum over x of (less[x] - less[x-1]) * x.
ll compute_sum(const vector<pair<int, int> >& li){
    // Bucket the values by graph on the heap. A stack array of vectors whose
    // length is only known at run time is a compiler extension, not standard
    // C++, and can be large when k is large.
    vector<vector<int> > values_by_graph(k);
    for(size_t idx = 0; idx < li.size(); idx++){
        values_by_graph[li[idx].second].push_back(li[idx].first);
    }

    // Largest value range over all graphs; sizes both product tables.
    int maxn = 0;
    for(int g = 0; g < k; g++){
        maxn = max(maxn, 2*n[g]);
    }
    vector<ll> prefix_prod(maxn, 1LL);
    vector<ll> suffix_prod(maxn + 1, 1LL);

    vector<int> cnt;  // reused across graphs to avoid reallocating
    for(int g = 0; g < k; g++){
        const int width = 2*n[g];
        cnt.assign(width, 0);
        const vector<int>& values = values_by_graph[g];
        for(size_t idx = 0; idx < values.size(); idx++){
            cnt[values[idx]]++;
        }
        // at_most is cnt_g[x]: the number of entries of graph g with value <= x.
        ll at_most = 0;
        for(int x = 0; x < width; x++){
            at_most += cnt[x];
            prefix_prod[x] = prefix_prod[x] * at_most % mod;
        }
        // From index `width` onward this graph contributes its total count.
        // Using the running total (rather than cnt[width-1]) also keeps the
        // width == 0 case in bounds: such a graph zeroes every product.
        suffix_prod[width] = suffix_prod[width] * at_most % mod;
    }

    // After this pass suffix_prod[x] is the product of the totals of every
    // graph whose value range ends at or before x.
    for(size_t x = 1; x < suffix_prod.size(); x++){
        suffix_prod[x] = suffix_prod[x] * suffix_prod[x-1] % mod;
    }

    ll ans = 0;
    ll prev_less = maxn > 0 ? prefix_prod[0] * suffix_prod[0] % mod : 0;
    for(int x = 1; x < maxn; x++){
        const ll cur_less = prefix_prod[x] * suffix_prod[x] % mod;
        // Number of tuples whose maximum is exactly x, kept non-negative.
        const ll exactly = (cur_less - prev_less + mod) % mod;
        ans = (ans + exactly * x) % mod;
        prev_less = cur_less;
    }
    return ans;
}
int main(){
ios_base::sync_with_stdio(false);
cin.tie(0);
vector<pair<int, int> > evens;
vector<pair<int, int> > odds;
vector<pair<int, int> > both;
cin >> k;
for(int i = 0; i<k; i++){
int m;
cin >> n[i] >> m;
vector<int> adj[2*n[i]];
for(int j = 0; j<m; j++){
int a, b;
cin >> a >> b;
a--;
b--;
adj[a].push_back(n[i]+b);
adj[b].push_back(n[i]+a);
adj[n[i]+a].push_back(b);
adj[n[i]+b].push_back(a);
}
vector<int> dist(2*n[i], inf);
vector<int> li;
dist[0] = 0;
li.push_back(0);
for(int j = 0; j<li.size(); j++){
int now = li[j];
for(int a = 0; a<adj[now].size(); a++){
int to = adj[now][a];
if(dist[to]==inf){
dist[to] = dist[now]+1;
li.push_back(to);
}
}
}
for(int j = 0; j<n[i]; j++){
if(dist[j]<inf){
evens.push_back(make_pair(dist[j],i));
}
if(dist[j+n[i]]<inf){
odds.push_back(make_pair(dist[j+n[i]],i));
}
if(max(dist[j],dist[j+n[i]])<inf){
both.push_back(make_pair(max(dist[j],dist[j+n[i]]),i));
}
}
}
ll ans = compute_sum(evens)+compute_sum(odds)-compute_sum(both);
ans %= mod;
if(ans<0LL){
ans += mod;
}
cout << ans << "\n";
}


Danny Mittal's code (with modular inverse):

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class SingleSourceShortestPath {
    public static final long MOD = 1000000007L;

    /** Length of the global per-distance tables; distances 0 .. LIMIT-1 are tracked. */
    private static final int LIMIT = 200000;

    /**
     * Raises {@code base} to the power {@code MOD - 2} modulo {@code MOD} by
     * repeated squaring, which is the modular inverse of {@code base} when
     * {@code base} is not a multiple of the prime {@code MOD}.
     */
    public static long inverse(long base) {
        long result = 1;
        for (int e = (int) MOD - 2; e != 0; e >>= 1) {
            if ((e & 1) == 1) {
                result *= base;
                result %= MOD;
            }
            base *= base;
            base %= MOD;
        }
        return result;
    }

    /**
     * Breadth-first search from state 1 over the parity-doubled graph.
     *
     * @param n   number of original vertices; states are numbered 1 .. 2n
     * @param adj adjacency lists indexed by state (index 0 is never read)
     * @return array of size 2n + 1 holding the step count to each state,
     *         or -1 for states that were not reached
     */
    private static int[] parityDistances(int n, List<List<Integer>> adj) {
        int[] dist = new int[(2 * n) + 1];
        Arrays.fill(dist, -1);
        dist[1] = 0;
        // A state is enqueued only when its distance is first set, so 2n slots suffice.
        int[] queue = new int[2 * n];
        int head = 0;
        int tail = 0;
        queue[tail++] = 1;
        while (head < tail) {
            int cur = queue[head++];
            for (int next : adj.get(cur)) {
                if (dist[next] == -1) {
                    dist[next] = dist[cur] + 1;
                    queue[tail++] = next;
                }
            }
        }
        return dist;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));

        // parityProd[d]: running product, over graphs, of the count of states
        // whose distance is <= d and has the same parity as d.
        // bothProd[d]: the same for vertices reachable with both parities,
        // keyed by the larger of their two distances.
        // The *Zero flags record that some graph had a zero count at d, since
        // a zero factor cannot be undone with a modular inverse.
        long[] parityProd = new long[LIMIT];
        long[] bothProd = new long[LIMIT];
        boolean[] parityZero = new boolean[LIMIT];
        boolean[] bothZero = new boolean[LIMIT];
        Arrays.fill(parityProd, 1);
        Arrays.fill(bothProd, 1);

        int k = Integer.parseInt(reader.readLine());
        boolean anyBipartite = false;
        for (int g = 0; g < k; g++) {
            // One line is read and discarded before each graph's header line.
            reader.readLine();
            StringTokenizer tokens = new StringTokenizer(reader.readLine());
            int n = Integer.parseInt(tokens.nextToken());
            int m = Integer.parseInt(tokens.nextToken());

            // States 1..n and n+1..2n are the two copies of the vertices;
            // every edge only connects one copy to the other.
            List<List<Integer>> adj = new ArrayList<>();
            adj.add(null);
            for (int v = 1; v <= 2 * n; v++) {
                adj.add(new ArrayList<>());
            }
            for (int j = 1; j <= m; j++) {
                tokens = new StringTokenizer(reader.readLine());
                int a = Integer.parseInt(tokens.nextToken());
                int b = Integer.parseInt(tokens.nextToken());
                adj.get(a).add(n + b);
                adj.get(n + b).add(a);
                adj.get(n + a).add(b);
                adj.get(b).add(n + a);
            }

            int[] dist = parityDistances(n, adj);
            if (dist[n + 1] == -1) {
                // The second copy of vertex 1 was never reached.
                anyBipartite = true;
            }

            long[] parityCount = new long[2 * n];
            long[] bothCount = new long[2 * n];
            for (int v = 1; v <= n; v++) {
                int first = dist[v];
                int second = dist[n + v];
                if (first != -1) {
                    parityCount[first]++;
                }
                if (second != -1) {
                    parityCount[second]++;
                }
                if (first != -1 && second != -1) {
                    bothCount[Math.max(first, second)]++;
                }
            }
            // parityCount accumulates in steps of two so the two parities stay
            // separate; bothCount is an ordinary prefix sum.
            for (int d = 2; d < 2 * n; d++) {
                parityCount[d] += parityCount[d - 2];
                bothCount[d] += bothCount[d - 1];
            }

            // Store ratios count[d] / count[previous] so that the running
            // product taken after all graphs are read telescopes to count[d].
            for (int d = 0; d < 2 * n; d++) {
                if (parityCount[d] == 0L) {
                    parityZero[d] = true;
                } else {
                    parityProd[d] *= parityCount[d];
                    parityProd[d] %= MOD;
                    if (d >= 2 && parityCount[d - 2] != 0L) {
                        parityProd[d] *= inverse(parityCount[d - 2]);
                        parityProd[d] %= MOD;
                    }
                }
                if (bothCount[d] == 0L) {
                    bothZero[d] = true;
                } else {
                    bothProd[d] *= bothCount[d];
                    bothProd[d] %= MOD;
                    if (d >= 1 && bothCount[d - 1] != 0L) {
                        bothProd[d] *= inverse(bothCount[d - 1]);
                        bothProd[d] %= MOD;
                    }
                }
            }
        }

        // Turn the stored ratios into the actual per-distance products.
        for (int d = 2; d < LIMIT; d++) {
            parityProd[d] *= parityProd[d - 2];
            parityProd[d] %= MOD;
            bothProd[d] *= bothProd[d - 1];
            bothProd[d] %= MOD;
        }
        // Apply the zero factors that were skipped above.
        for (int d = 0; d < LIMIT; d++) {
            if (parityZero[d]) {
                parityProd[d] = 0;
            }
            if (bothZero[d]) {
                bothProd[d] = 0;
            }
        }
        if (anyBipartite) {
            Arrays.fill(bothProd, 0);
        }

        // Weight each distance d by the number of tuples gained at d:
        // the parity table differences minus the both-parity differences.
        long answer = 0;
        for (int d = 0; d < LIMIT; d++) {
            long weight = d;
            answer += weight * parityProd[d];
            answer -= weight * bothProd[d];
            if (d >= 2) {
                answer -= weight * parityProd[d - 2];
                answer += weight * bothProd[d - 1];
            }
            answer %= MOD;
        }
        answer += MOD;
        answer %= MOD;
        System.out.println(answer);
    }
}