USACO OPEN11 Problem 'solder' Analysis

by Michael Cohen

The starting point here is a dynamic programming algorithm. Arbitrarily root the tree and consider "cutting off" a particular subtree in a soldering. This leaves at most one "cut wire" that extends out of the subtree to its parent (none, if the cut falls exactly at a wire's endpoint), plus a set of wires lying wholly within the subtree. The only information that matters about the subtree is the length of the cut wire inside it and the total cost of all the other wires, because the cut wire is the only wire whose cost depends on the rest of the soldering.
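To make this concrete, write $l$ for the length of the cut wire inside the subtree, $c$ for the total cost of the wires wholly inside it, $L$ for the length of the cut wire outside the subtree, and $C_{\mathrm{out}}$ (a symbol introduced here for illustration) for the cost of the wires lying entirely outside. The total cost of a soldering then splits as follows.

```latex
\text{total cost} \;=\; \underbrace{c + (l + L)^2}_{\text{the subtree enters only through } (l,\,c)} \;+\; C_{\mathrm{out}}
```

For a fixed $l$, the outside sees the subtree only through $(l + L)^2$, so only the smallest $c$ for each $l$ needs to be kept.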

This gives a relatively simple dynamic programming solution: for each vertex (defining a subtree), store, for each possible cut wire length, the minimum cost of the other wires; if there is no cut wire, this can be treated as a cut wire of length 0. These values are computed from the bottom up. If there is a cut wire, it must extend down into one of the children; the cost for a cut wire going through a particular child is the cost for the cut wire through that child's subtree plus the minimum-cost soldering covering each of the other child subtrees. If there is no cut wire, then the edge going to the parent must be soldered onto the middle of another wire; one can then check all pairs of distinct children and all pairs of lengths to find two children's cut wires to merge into one wire. The maximum length of a cut wire in a subtree is at most the number of nodes it contains, so the number of length pairs for two distinct children is at most the number of pairs of nodes with one node in each child's subtree. Summed over all pairs of children, this is at most the number of pairs of nodes whose lowest common ancestor is the root of the subtree. Since every pair of nodes has exactly one lowest common ancestor, the total work over the whole algorithm is bounded by the total number of pairs of nodes, which is $O(N^2)$.

At this point it is convenient to assume, in the discussion of the algorithm, that each vertex has at most two children. This loses nothing: a vertex V with more children can be "split up" by keeping a direct edge to one of its children and attaching the remaining children to a new vertex V', joined to V by an edge of length 0 (a zero-length edge does not break the algorithm, even though every edge in the original problem has length 1). Repeating this until no vertex has more than two children gives the desired tree.

To reduce the runtime further, one must exploit the convexity of the squared length. A length/cost pair $(l, c)$ for a subtree corresponds to the function $(L+l)^2 + c$, where $L$ is the length of the cut wire outside the subtree. Only the functions that are the minimum for some value of $L$ matter. Since $(L+l)^2 + c = L^2 + 2Ll + (l^2 + c)$ and the $L^2$ term is shared by every pair, this is the lower envelope of the lines $2lL + (l^2 + c)$, which is equivalent to a convex hull. All pairs not on the envelope can be deleted. One can then binary search the convex hull to find the optimal pair for any particular length of the wire outside the subtree. To find the optimal pair of lengths from two children to merge into one wire, one can simply take every length in the smaller child subtree and binary search the convex hull of the larger child subtree for the best partner within it. Finally, to compute the convex hulls for all subtrees efficiently, one can represent each hull with a binary search tree (std::set does fine here). To collect the possibilities from both children, offset the values in the larger child's hull (by storing offsets that are added to every pair in the hull, since both length and cost change as subtrees are merged) and then insert each pair from the smaller child's hull, adjusted by the offset, into it. The total number of binary search tree operations is then at most the sum, over all nodes, of the size of the smaller child subtree (it can be less, since a hull can have fewer elements than its subtree has nodes). This sum is $O(N \log N)$: whenever an element is moved, it goes from the smaller set into the larger one, so the size of the set containing it at least doubles, and therefore no element is moved more than $\log_2 N$ times. Each tree operation takes $O(\log N)$ time, so the overall runtime is $O(N \log^2 N)$.

Below is Neal Wu's $O(N^2)$ implementation:

#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;

FILE *in = fopen ("solder.in", "r"), *out = fopen ("solder.out", "w");

const int MAXN = 50005;
const long long LLINF = 1LL << 60;

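int N, down [MAXN];                 // down[v]: vertices on the longest downward path from v
long long *dp [MAXN], mindp [MAXN]; // dp[v][k]: best cost when v's cut wire (incl. edge to parent) has length k
vector <int> adj [MAXN];            // adjacency lists; init_dfs turns them into child lists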

// Computes down[] and removes each vertex's parent from its adjacency list.
void init_dfs (int num, int par)
{
    down [num] = 1;
    int par_ind = -1;

    for (int i = 0; i < (int) adj [num].size (); i++)
    {
        int child = adj [num][i];

        if (child == par)
        {
            par_ind = i;
            continue;
        }

        init_dfs (child, num);
        down [num] = max (down [num], down [child] + 1);
    }

    if (par_ind != -1)
        adj [num].erase (adj [num].begin () + par_ind);
}

// Fills dp[num][*] and mindp[num] bottom-up, then frees the children's tables.
void solve_dfs (int num)
{
    for (int i = 0; i < (int) adj [num].size (); i++)
        solve_dfs (adj [num][i]);

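    // Length 1: the edge to the parent is a wire by itself (cost 1). With two or
    // more children it is soldered onto the middle of a wire joining two children's
    // cut wires; (a + b)^2 = a^2 + b^2 + 2ab gives the extra 2ab below.
    long long dp1 = 1;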

    if (adj [num].size () > 1)
    {
        for (int i = 0; i < (int) adj [num].size (); i++)
            dp1 += mindp [adj [num][i]];

        long long best_two = LLINF;

        for (int i = 0; i < (int) adj [num].size (); i++)
            for (int j = i + 1; j < (int) adj [num].size (); j++)
            {
                int child1 = adj [num][i], child2 = adj [num][j];

                for (int a = 1; a <= down [child1]; a++)
                    for (int b = 1; b <= down [child2]; b++)
                        best_two = min (best_two, dp [child1][a] + dp [child2][b]
                            + 2LL * a * b - mindp [child1] - mindp [child2]);
            }

        dp1 += best_two;
    }

    dp [num] = new long long [down [num] + 1];
    dp [num][1] = dp1;

    // dp[num][0] is only read at the root, where there is no edge to a parent.
    if (adj [num].size () == 1)
    {
        dp [num][1] = LLINF;
        dp [num][0] = mindp [adj [num][0]];
    }
    else
        dp [num][0] = dp [num][1] - 1;

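    // Length k + 1: one child's cut wire of length k extended by the edge to the
    // parent, adding (k + 1)^2 - k^2 = 2k + 1; every other child takes its best.
    for (int k = 1; k < down [num]; k++)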
    {
        long long sum = 0, best_link = LLINF;

        for (int i = 0; i < (int) adj [num].size (); i++)
        {
            int child = adj [num][i];
            sum += mindp [child];

            if (k <= down [child])
                best_link = min (best_link, dp [child][k] - mindp [child]);
        }

        dp [num][k + 1] = sum + best_link + 2 * k + 1;
    }

    mindp [num] = LLINF;

    for (int k = 1; k <= down [num]; k++)
        mindp [num] = min (mindp [num], dp [num][k]);

    for (int i = 0; i < (int) adj [num].size (); i++)
        delete [] dp [adj [num][i]];   // allocated with new[], so release with delete[]
}

int main ()
{
    fscanf (in, "%d", &N);

    for (int i = 1, a, b; i < N; i++)
    {
        fscanf (in, "%d %d", &a, &b); a--; b--;
        adj [a].push_back (b);
        adj [b].push_back (a);
    }

    init_dfs (0, -1);
    solve_dfs (0);
    fprintf (out, "%lld\n", dp [0][0]);
    return 0;
}

And below is Michael Cohen's impressive full implementation of the $O(N \log^2 N)$ algorithm:

#include <fstream>
#include <vector>
#include <set>
#define endl '\n'
using namespace std;

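struct poss {
	long long depth;    // absolute depth of the lower end of the cut wire
	long long cost;     // cost of all other wires, relative to the set's offset
	long long takeover; // smallest query x at which this entry minimizes (x + depth)^2 + cost
	bool tcheck;        // true only for query keys passed to upper_bound
};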

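// Hull entries are ordered by decreasing depth (ties: lower cost first); along the
// hull, takeover then increases, so query keys (tcheck) are compared by takeover.
bool operator<(poss a, poss b) {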
	if (a.tcheck || b.tcheck) return (a.takeover < b.takeover);
	if (a.depth > b.depth) return true;
	if (a.depth < b.depth) return false;
	return (a.cost < b.cost);
}

int N;
vector<int> edges[50000];
bool visited[50000];
long long depth[50000];
long long offset[50000];
set<poss>* best[50000];

void recurse(int node) {
	visited[node] = true;
	long long bestPair = -1;
	long long allSoFar = 0;
	for (int i = 0; i < edges[node].size(); i++) {
		if (visited[edges[node][i]]) continue;
		depth[edges[node][i]] = depth[node]+1;
		recurse(edges[node][i]);
		
		long long tadd;
		{
			poss when = { 0, 0, -depth[node], true };
			set<poss>::iterator which = best[edges[node][i]]->upper_bound(when);
			which--;
			tadd =
				(depth[node]-which->depth)*(depth[node]-which->depth)+which->cost+offset[edges[node][i]];
		}
		if (bestPair != -1) bestPair += tadd;
		
		if (best[node] == NULL) {
			best[node] = best[edges[node][i]];
			offset[node] = offset[edges[node][i]];
		}
		else {
			set<poss>* s = best[node], * t = best[edges[node][i]];
			long long os = offset[node]+tadd, ot = offset[edges[node][i]]+allSoFar;
			if (s->size() < t->size()) {
				set<poss>* tem = s;
				s = t;
				t = tem;
				long long to = os; // must be long long: an int would truncate the offset
				os = ot;
				ot = to;
			}
			
			for (set<poss>::iterator it = t->begin(); it != t->end(); it++) {
				poss when = { 0, 0, it->depth-2*depth[node], true };
				set<poss>::iterator which = s->upper_bound(when);
				which--;
				long long thisPair =
					(it->depth+which->depth-2*depth[node])*(it->depth+which->depth-2*depth[node])+it->cost+which->cost+offset[node]+offset[edges[node][i]];
				if (bestPair == -1 || thisPair < bestPair) bestPair = thisPair;
			}
			
			for (set<poss>::iterator it = t->begin(); it != t->end(); it++) {
				poss p = *it;
				p.cost += ot-os;
				set<poss>::iterator where = s->insert(p).first;
				bool killed = false;
				while (where != s->begin()) {
					set<poss>::iterator prev = where;
					prev--;
					if (prev->depth == where->depth) {
						s->erase(where);
						killed = true;
						break;
					}
					// Round the crossover up: the first query value at which *where
					// is no worse than *prev.
					p.takeover =
						(where->cost-prev->cost+where->depth*where->depth-prev->depth*prev->depth)/(2*prev->depth-2*where->depth);
					while ((2*prev->depth-2*where->depth)*p.takeover <
						where->cost-prev->cost+where->depth*where->depth-prev->depth*prev->depth)
						p.takeover++;
					s->erase(where);
					where = s->insert(p).first;
					
					if (where->takeover <= prev->takeover) s->erase(prev);
					else break;
				}
				if (killed) continue;
				if (where == s->begin()) {
					p.takeover = -1000000000;
					s->erase(where);
					where = s->insert(p).first;
				}
				set<poss>::iterator next = where;
				next++;
				while (next != s->end()) {
					if (next->depth == where->depth) {
						s->erase(next);
						next = where;
						next++;
						continue;
					}
					poss n = *next;
					n.takeover =
						(next->cost-where->cost+next->depth*next->depth-where->depth*where->depth)/(2*where->depth-2*next->depth);
					while ((2*where->depth-2*next->depth)*n.takeover <
						next->cost-where->cost+next->depth*next->depth-where->depth*where->depth)
						n.takeover++;
					if (n.takeover <= where->takeover) {
						s->erase(where);
						break;
					}
					s->erase(next);
					next = s->insert(n).first;
					set<poss>::iterator nnext = next;
					nnext++;
					if (nnext != s->end() && nnext->takeover <= next->takeover) {
						s->erase(next);
						next = nnext;
					}
					else break;
				}
			}
			
			best[node] = s;
			offset[node] = os;
			delete t;
		}
		allSoFar += tadd;
	}
	
	if (best[node] == NULL) {
		best[node] = new set<poss>();
		poss p = { depth[node], 0, -1000000000, false };
		best[node]->insert(p);
	}
	else if (bestPair != -1) {
		poss p = { depth[node], bestPair-offset[node], 0, false };
		while (!best[node]->empty()) {
			p.takeover =
				(p.cost-best[node]->rbegin()->cost+p.depth*p.depth-best[node]->rbegin()->depth*best[node]->rbegin()->depth)/(2*best[node]->rbegin()->depth-2*p.depth);
			while ((2*best[node]->rbegin()->depth-2*p.depth)*p.takeover <
				p.cost-best[node]->rbegin()->cost+p.depth*p.depth-best[node]->rbegin()->depth*best[node]->rbegin()->depth)
				p.takeover++;
			if (p.takeover > best[node]->rbegin()->takeover) break;
			best[node]->erase(*(best[node]->rbegin()));
		}
		if (best[node]->empty()) p.takeover = -1000000000;
		best[node]->insert(p);
	}
}

int main()
{
	ifstream inp("solder.in");
	ofstream outp("solder.out");
	
	inp >> N;
	for (int i = 0; i < N-1; i++) {
		int a, b;
		inp >> a >> b;
		a--, b--;
		edges[a].push_back(b);
		edges[b].push_back(a);
	}
	
	recurse(0);
	if (edges[0].size() == 1) {
		poss when = { 0, 0, 0, true };
		set<poss>::iterator which = best[0]->upper_bound(when);
		which--;
		outp << which->depth*which->depth+which->cost+offset[0] << endl;
	}
	else {
		poss p = *(best[0]->rbegin());
		outp << p.cost+offset[0] << endl;
	}
	return 0;
}