# Full solution in Python.  The original comprehension-based version of this
# script took ~8.6 sec on the worst testcase on the testing server.
#
# Input  ("transsis.txt"): whitespace-separated integers  n k  followed by
#        pairs of 1-based vertex ids (the edges) and finally k 1-based ids of
#        the target vertices.
# Output ("transval.txt"): the minimum, over all vertices v, of the sum of
#        squared distances (counted in edges) from v to every target vertex.


def build_tree(n, k, data):
    """Turn the flat token list into an adjacency list and a target set.

    Args:
        n: number of vertices.
        k: number of target ids stored at the end of ``data``.
        data: list of ints that followed ``n`` and ``k`` in the input:
            consecutive pairs of 1-based edge endpoints, then ``k`` 1-based
            target ids.

    Returns:
        ``(adjacency, targets)`` where ``adjacency[v]`` lists the 0-based
        neighbours of ``v`` and ``targets`` is the set of 0-based target ids.
    """
    # Split at an explicit index: ``data[-k:]`` would select the *whole* list
    # (and ``data[:-k]`` nothing) when k == 0.
    split = len(data) - k

    adjacency = [[] for _ in range(n)]
    for a, b in zip(data[:split:2], data[1:split:2]):
        adjacency[a - 1].append(b - 1)
        adjacency[b - 1].append(a - 1)

    targets = {v - 1 for v in data[split:]}
    return adjacency, targets


def min_squared_distance_sum(adjacency, targets):
    """Return min over vertices v of sum(dist(v, t) ** 2 for t in targets).

    Args:
        adjacency: adjacency list of an undirected graph with vertices
            ``0 .. len(adjacency) - 1``.  The traversal below only works for
            a connected tree containing vertex 0; other shapes are not
            detected explicitly.
        targets: iterable of 0-based target vertex ids (duplicates count
            once).

    Returns:
        The smallest sum of squared edge-count distances from a single vertex
        to all targets (0 when there are no targets).
    """
    n = len(adjacency)
    no_parent = -1

    # Breadth-first order from vertex 0: every vertex appears after its
    # parent, so the reversed order visits children before parents.
    order = [0]
    parent = [no_parent] * n
    for i in range(n):
        v = order[i]
        above = parent[v]
        for u in adjacency[v]:
            if u != above:
                parent[u] = v
                order.append(u)

    # n_sub[v]  - number of targets inside the subtree of v
    # d_sub[v]  - sum of distances from v to those targets
    # d2_sub[v] - sum of squared distances from v to those targets
    n_sub = [0] * n
    for t in targets:
        n_sub[t] = 1
    # Count the marked vertices rather than trusting an external k, so the
    # rerooting step below always agrees with n_sub.
    k = sum(n_sub)
    d_sub = [0] * n
    d2_sub = [0] * n

    # Bottom-up pass.  When v is reached all of its children have already
    # pushed their contribution, so v's aggregates are final and can be
    # pushed to its parent.  Moving one edge up turns each distance d into
    # d + 1, and (d + 1)^2 = d^2 + 2d + 1.
    for v in reversed(order):
        p = parent[v]
        if p == no_parent:  # the root has nobody to report to
            continue
        cnt = n_sub[v]
        dist = d_sub[v]
        n_sub[p] += cnt
        d_sub[p] += dist + cnt
        d2_sub[p] += d2_sub[v] + 2 * dist + cnt

    # Top-down (rerooting) pass: d_all / d2_all are the sums over *all*
    # targets.  Stepping from parent p to child v brings the n_sub[v] targets
    # below v one edge closer and pushes the remaining ones one edge away.
    d_all = [0] * n
    d2_all = [0] * n
    d_all[0] = d_sub[0]
    d2_all[0] = d2_sub[0]
    for v in order:
        p = parent[v]
        if p == no_parent:  # root already initialised above
            continue
        d_all[v] = d_all[p] + k - 2 * n_sub[v]
        d2_all[v] = d2_all[p] + k - 4 * (d_sub[v] + n_sub[v]) + 2 * d_all[p]

    return min(d2_all)


def main():
    """Read "transsis.txt", solve the task and write "transval.txt"."""
    # Load data
    with open("transsis.txt") as infile:
        n, k, *data = map(int, infile.read().split())

    adjacency, targets = build_tree(n, k, data)
    del data  # Need to free space here, otherwise we run out of memory

    # Compute first, open afterwards: a failure during the computation then
    # never leaves a truncated output file behind.
    answer = min_squared_distance_sum(adjacency, targets)
    with open('transval.txt', 'w') as outfile:
        print(answer, file=outfile)


if __name__ == "__main__":
    main()