6.S078 Lecture 13 (4/2): Linear Decision Trees for k-SUM
=========================================================

Recall the model: LINEAR DECISION TREES (LDT)

Definition: An LDT of width L for a function f : R^m -> {0,1}^{m'} is a binary tree where each inner node is associated with an inequality

  [alpha_{i_1} x_{i_1} + ... + alpha_{i_L} x_{i_L} \geq t]

where the x_{i_j}'s are variables from the set {x_1,...,x_m}, each alpha_{i_j} is in {-1,0,1}, and t is in R. Each leaf of the LDT contains a string from {0,1}^{m'}.

***

3-SUM

Theorem [GP'14]: 3-SUM has an LDT of width 4 and depth O~(n^{1.5}).

The most amazing thing about this theorem is that it took almost 40 years after Fredman to prove... once you see it, you'll know what I mean. I believe the only reason it stayed open was that people believed it was not possible, thanks to:

Theorem [Erickson'99]: Every LDT of width 3 for 3-SUM has depth Omega(n^2).

Since 3-SUM compares triples of numbers anyway, people didn't expect that width 4 would help...

Our starting point is the 3-SUM algorithm we gave that runs in O(n^2) time. We will use the 3-SUM* variant of the 3-SUM problem: Given a set A, are there a_i, a_j, a_k in A such that a_i = a_j + a_k? WLOG we assume the numbers are distinct. (Non-distinct numbers can be checked for a 3-SUM solution quickly.)

Here's the O(n^2) time algorithm:

1. Sort A in O(n log n) time.
2. For all a_i in A, make two pointers: p1 at the beginning of the sorted A, p2 at the end.
3. Repeat until the pointers pass each other:
   Let a_j be the current number at p1, and a_k the current number at p2.
   If a_i = a_j + a_k then return the triple.
   If a_i < a_j + a_k then move p2 to the left, else move p1 to the right.
4. Return "no solution".

For each a_i, this procedure takes O(n) time to find the other two. Note that this algorithm can be implemented as a width-3 LDT of depth O(n^2).

We want to speed up the loop in part 3. Idea: partition the sorted elements into groups A_1,...,A_{n/d} with O(d) elements in each group. For two groups A_j and A_k, we want to *quickly* check whether there are a_j in A_j and a_k in A_k such that a_i = a_j + a_k. Then we'd only have to run the loop in part 3 for n/d iterations...

Here's our algorithm (a runnable sketch of it appears after the key insight below):

1. Sort A. Partition the sorted A into n/d contiguous groups A_1,...,A_{n/d} of O(d) elements each.
2. For all i in [n/d], make the list D_i = {a_j - a_k | a_j, a_k in A_i}. Sort each D_i in O~(d^2) time. Merge all these lists into one sorted list L.
3. For all j,k in [n/d], define S_{j,k} = {a_j + a_k | a_j in A_j, a_k in A_k}.
4. For all a_i in A, make two pointers on *groups*: p1 at A_1, p2 at A_{n/d}.
5. Repeat until the pointers cross:
   Suppose p1 points at A_j and p2 points at A_k.
   If (a_i is in S_{j,k}) then return "yes". <----- have to investigate this line further!
   [Otherwise, a_i != a_j + a_k for all a_j in A_j and a_k in A_k.]
   Let a_{min} = the smallest element of A_k, and a_{max} = the largest element of A_j.
   If a_i < a_{min} + a_{max} then move p2 to the left, else move p1 to the right.
6. Return "no solution".

Note that the sorting in step 2 takes O~(n/d * d^2) = O~(n d) depth with a width-4 LDT: each comparison of two differences involves four input numbers. And for each a_i, the number of iterations of the loop in step 5 is <= n/d.

But step 3 looks suspect... there are n^2/d^2 such lists, and each is Theta(d^2) long in the worst case. And how do we implement the check "a_i in S_{j,k}" efficiently?

KEY INSIGHT: Any two elements of S_{j,k} can be compared using the sorted order on L. Take two numbers a_j + a_k and a'_j + a'_k from S_{j,k}. We have

  a_j + a_k <= a'_j + a'_k  <=>  a_j - a'_j <= a'_k - a_k.

But a_j - a'_j is in D_j and a'_k - a_k is in D_k, so both appear in L, which we have already sorted.
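To make this concrete, here is a minimal Python sketch of the grouped search in the plain RAM model (the name three_sum_star, the helper sorted_sums, and the choice d = isqrt(n) are mine, purely illustrative). It mirrors steps 1-6 above, but it is not fast on a real machine, since it re-sorts each S_{j,k} from scratch; the LDT savings come from the sorted order of each S_{j,k} being *inferred* from L rather than recomputed.

from bisect import bisect_left
from math import isqrt

def three_sum_star(A):
    """Return (a_i, a_j, a_k) with a_i = a_j + a_k, or None if no such triple exists.
    Assumes the numbers in A are distinct, as assumed WLOG in the notes."""
    A = sorted(A)                            # step 1
    n = len(A)
    d = max(1, isqrt(n))                     # group size ~ n^{1/2}
    groups = [A[i:i + d] for i in range(0, n, d)]
    g = len(groups)

    def sorted_sums(j, k):                   # sorted S_{j,k} from step 3
        return sorted(x + y for x in groups[j] for y in groups[k])

    for a_i in A:                            # step 4
        p1, p2 = 0, g - 1
        while p1 <= p2:                      # step 5
            S = sorted_sums(p1, p2)
            pos = bisect_left(S, a_i)        # binary search: O(log d) comparisons against a_i
            if pos < len(S) and S[pos] == a_i:
                for a_j in groups[p1]:       # recover an explicit witness pair
                    for a_k in groups[p2]:
                        if a_j + a_k == a_i:
                            return a_i, a_j, a_k
            a_min = groups[p2][0]            # smallest element of the right group
            a_max = groups[p1][-1]           # largest element of the left group
            if a_i < a_min + a_max:
                p2 -= 1
            else:
                p1 += 1
    return None                              # step 6

# Example: three_sum_star([1, 3, 5, 8, 12, 20]) returns (8, 3, 5).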
Returning to the depth analysis: thanks to the key insight, after steps 2 and 3 we can already infer the sorted order of S_{j,k}, for all j,k. Note |S_{j,k}| = O(d^2). We can therefore determine whether a_i is in S_{j,k} with only O(log d) depth of an LDT, by binary search on S_{j,k}!

Since we can infer the sorted order of S_{j,k} from L, we can infer a binary search tree (with the median of the sorted S_{j,k} at the root, the medians of the two sublists at its left and right children, etc.). Binary search to determine whether (a_i is in S_{j,k}) can then be done by a width-3 LDT of depth O(log d), since each node makes a comparison of the form a_i <= a_j + a_k, which involves only three input numbers.

Thus, for each a_i, the loop in step 5 takes only O~((n/d) log d) extra depth: n/d binary searches, plus one extra comparison per iteration to move the pointers. Over all n choices of a_i, this is O~((n^2/d) log d) depth, on top of the O~(n d) depth of step 2. Setting d = n^{1/2} balances the two terms and yields depth O~(n^{1.5}).

***

More developments

Similar tricks give LDTs of depth O~(n^{k/2}) and width O(k) for k-SUM. A vast improvement was recently found:

Theorem [KLM'17]: k-SUM has a width-(2k+O(1)) LDT of depth O~(kn).

This also implies that Subset-Sum has an LDT of depth O~(n^2).

The theorem is somewhat technical, but it has some interesting ideas, which I want to discuss briefly.

IDEA 1. Think of k-SUM as a special case of a geometric "hyperplane problem".

For an input x in R^n and a finite set S of vectors in R^n, define the vector A_{x,S} in {0,1,-1}^{S} (with components indexed by the vectors in S) such that for all v in S,

  A_{x,S}[v] = sign(sum_i v_i x_i) in {0,-1,1}.

That is, A_{x,S} has dimension |S|, and it contains the sign of the inner product of x with every vector in S (positive, negative, or zero).

Let W_k subset of {0,1}^n be the set of all Boolean vectors with exactly k ones (Hamming weight k).

Claim: x in R^n has a k-SUM solution <=> the vector A_{x,W_k} contains a zero entry.

We will shoot for an LDT that computes the entire vector A_{x,W_k}. So in fact, our LDT will compute the sign of every possible k-SUM solution in O~(kn) depth! (A tiny code sketch of A_{x,W_k} appears after the example below.)

IDEA 2. Relate the depth of an LDT for A_{x,S} to a recently developed concept in learning theory called "inference dimension".

Let S-S = {a-b | a, b in S}.

Def. For S subset of R^n and v,x in R^n, say that "S can infer v at x" if the sign of <v,x> is determined by A_{x, S union (S-S)}.

This is a little hard to unpack at first, but what we're really saying is: the sign of <v,x> is determined by some sequence of queries of the form (<s,x> >= 0)? or (<s,x> >= <s',x>)?, where s,s' are in S. That is, to determine the sign of <v,x>, it suffices to know the signs of <s,x> for all s in S, and the sorted order of the values <s,x> over all s in S.

Def. The inference dimension of S is the minimal d such that: for every subset T of S of size >= d, and for every x in R^n, there is a v in T such that T-{v} can infer v at x.

Intuition: inference dimension <= d ==> for every possible sample T of >= d vectors from S, there is always some vector v in the sample T such that sign(<v,x>) can be determined by querying the sign of <t,x> for all t in T-{v}, and sorting the numbers <t,x> for all t in T-{v}.

Example: Suppose n=1.

Claim: The inference dimension of R is 3.

Let x in R, and consider any set of three reals r_1 < r_2 < r_3 in R. We want to show that there is some r_i such that sign(<r_i,x>) can be determined from the signs of the other two inner products and comparisons between the other two inner products.

- If <r_1,x> = x*r_1 and <r_3,x> = x*r_3 have the same sign, then <r_2,x> has that sign as well: we can infer r_2 from {r_1, r_3}.
- If <r_1,x> and <r_3,x> have different signs, there are several cases:
  * sign(<r_1,x>) = -1 and sign(<r_3,x>) = 0. Then r_3 = 0, so r_1 < r_2 < 0 and x > 0, and thus sign(<r_2,x>) = -1 can be inferred.
  * sign(<r_1,x>) = 0 and sign(<r_3,x>) = 1. Then r_1 = 0, so r_3 > r_2 > 0 and x > 0, and thus sign(<r_2,x>) = 1 can be inferred.
  * sign(<r_1,x>) = -1 and sign(<r_3,x>) = 1. If sign(<r_2,x>) = -1, then we can infer sign(<r_1,x>) from the signs for r_2 and r_3. If sign(<r_2,x>) = 1, then we can infer sign(<r_3,x>) from the signs for r_1 and r_2. (If sign(<r_2,x>) = 0, then r_2 = 0 and sign(<r_1,x>) = -1 can again be inferred from the signs for r_2 and r_3.)
  The remaining sign patterns are symmetric, after replacing x by -x.
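To make IDEA 1 concrete, here is a tiny brute-force Python sketch (the names A_x_Wk and has_kSUM are mine, purely illustrative). It enumerates all weight-k supports, records the sign of each inner product, and checks the claim that x has a k-SUM solution exactly when some entry of A_{x,W_k} is zero.

from itertools import combinations

def sign(value):
    return (value > 0) - (value < 0)

def A_x_Wk(x, k):
    """Map each weight-k support (a tuple of indices) to sign(<v,x>),
    where v is the 0/1 vector with that support."""
    return {support: sign(sum(x[i] for i in support))
            for support in combinations(range(len(x)), k)}

def has_kSUM(x, k):
    # x has a k-SUM solution (k distinct entries summing to 0) iff some sign is 0
    return 0 in A_x_Wk(x, k).values()

# Example: has_kSUM([5, -2, -3, 7], 3) is True, witnessed by 5 + (-2) + (-3) = 0.
# This brute force makes C(n,k) "queries"; the point of [KLM'17] is that an LDT
# can determine all of these signs with only O~(kn) linear queries.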
Thm 1: Let S be any subset of {0,-1,1}^n. The inference dimension of S is O(n log n).

[So, for any subset T of S of size >= c*n log n and every input x, there is at least one vector s in T such that sign(<s,x>) is determined by the signs and sorted order of the inner products of x with the vectors of T-{s}.]

The proof uses some combinatorics and linear algebra, counting the number of hyperplanes needed to determine the sign of another.

Thm 2: For every S subset of R^n with inference dimension d, letting k be the maximum Hamming weight of a vector in S, there is a *randomized* LDT with expected depth O~((d + n)*log |S|) and width 2k such that, given any x, the LDT outputs A_{x,S} whp.

[That is, there are nodes where coins can be tossed, going to the left or right child with equal probability, and the average depth of a leaf over all random coin tosses is this bound.]

The width is 2k because each node of our LDT compares the inner products of two vectors s, s' in S with x; if both s and s' have Hamming weight at most k, the comparison involves at most 2k variables. (A small sketch of this accounting appears at the end of these notes.)

Idea of Theorem 2: Maintain a subset S' of S, which is initially all of S. Repeat O(log |S|) times:

- Randomly sample 2d vectors from S' [note: the sample size is twice the inference dimension of S], query sign(<s,x>) for all s in the sample, and sort the values <s,x> for s in the sample. This takes O~(d) depth.
- Use this information to infer the signs of <s,x> for at least *half* of the vectors s in S', in expectation. Remove those vectors s from S'.

After O(log |S|) rounds, we expect that |S'| <= O(1).

Cor: k-SUM has a randomized LDT of width 2k and expected depth O((n log n + n)*(k log n)) <= O(k n log^2 n). (Here d = O(n log n) by Thm 1, and log |S| = log |W_k| <= k log n.)

This can be derandomized. Well, LDTs are a non-uniform model, so this should not be too surprising! In the non-uniform setting, BPP = P (i.e., BPP/poly = P/poly).

***

Real RAM speed-up for 3-SUM?

Yes... the best known is a (log^2 n)/polyloglog factor speed-up, by Chan [SODA'18]. Roughly speaking,

1. one log factor comes from packing the outcomes of comparisons (O(log log n)-bit numbers) into words of O(log n) bits, and sorting them with a log-factor speedup;
2. another log factor comes from geometry / hyperplane-arrangement ideas, similar to (but older than) those of Kane-Lovett-Moran.
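To illustrate the width-2k accounting from Theorem 2 (referenced above), here is a small Python sketch (the function comparison_query and the dictionary representation are mine, purely illustrative). The comparison query "<s,x> >= <s',x>?" is the single linear query "<s - s', x> >= 0?"; when s and s' are 0/1 vectors of Hamming weight k, the vector s - s' has coefficients in {-1,0,1} and at most 2k nonzero entries, so it fits in one node of a width-2k LDT.

def comparison_query(s_support, t_support):
    """Coefficients of <s - t, x> as {index: coefficient}, given the supports of
    two 0/1 vectors s and t; indices where the coefficients cancel are dropped."""
    coeffs = {}
    for i in s_support:
        coeffs[i] = coeffs.get(i, 0) + 1
    for i in t_support:
        coeffs[i] = coeffs.get(i, 0) - 1
    return {i: c for i, c in coeffs.items() if c != 0}

# Example: comparison_query({0, 1, 4}, {1, 2, 7}) == {0: 1, 4: 1, 2: -1, 7: -1},
# a single +-1 inequality on 4 <= 2k variables (k = 3 here); shared indices cancel,
# which is why the width is at most 2k.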