Princeton Algorithms Course, Part 1, Week 3: Mergesort and Quicksort

    xiaoxiao2022-07-04  132

    Author: seriouszyx. Originally published at https://seriouszyx.top/. All code can be found on GitHub (stars welcome).

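    Mergesort

    Mergesort splits the array in half, recursively splits each half in the same way, sorts the pieces, and then merges the sorted halves back together.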

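        // Merge the sorted halves a[lo..mid] and a[mid+1..hi], using aux as scratch space.
        private static void merge(Comparable[] a, Comparable[] aux, int lo, int mid, int hi) {
            for (int k = lo; k <= hi; k++)
                aux[k] = a[k];
            int i = lo, j = mid + 1;
            for (int k = lo; k <= hi; k++) {
                if      (i > mid)              a[k] = aux[j++];  // left half exhausted
                else if (j > hi)               a[k] = aux[i++];  // right half exhausted
                else if (less(aux[j], aux[i])) a[k] = aux[j++];  // right key is smaller
                else                           a[k] = aux[i++];  // ties are taken from the left
            }
        }

        private static void sort(Comparable[] a, Comparable[] aux, int lo, int hi) {
            if (hi <= lo) return;
            int mid = lo + (hi - lo) / 2;
            sort(a, aux, lo, mid);
            sort(a, aux, mid + 1, hi);
            if (!less(a[mid + 1], a[mid])) return;  // halves already in order: skip the merge
            merge(a, aux, lo, mid, hi);
        }

        public static void sort(Comparable[] a) {
            Comparable[] aux = new Comparable[a.length];
            sort(a, aux, 0, a.length - 1);
        }

        // Helper used above: is v strictly less than w?
        private static boolean less(Comparable v, Comparable w) {
            return v.compareTo(w) < 0;
        }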

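    Mergesort scales to large inputs: for a million or a billion items, a job that is impractical for insertion sort may take mergesort only a few minutes.

    For N items, mergesort uses at most N lg N compares and 6 N lg N array accesses, and it needs an auxiliary array of size N.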
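The compare bound can be checked empirically. The following is a minimal sketch, not the course code: it sorts an `int[]` (an assumption made to keep the example self-contained), omits the "already in order" shortcut, and counts key compares. The class name `MergeCompareCount` and the method `sortAndCount` are made up for this illustration.

```java
import java.util.Random;

public class MergeCompareCount {
    private static long compares;  // key compares performed by the last sort

    private static void merge(int[] a, int[] aux, int lo, int mid, int hi) {
        for (int k = lo; k <= hi; k++) aux[k] = a[k];
        int i = lo, j = mid + 1;
        for (int k = lo; k <= hi; k++) {
            if (i > mid) a[k] = aux[j++];        // left half exhausted, no compare needed
            else if (j > hi) a[k] = aux[i++];    // right half exhausted, no compare needed
            else {
                compares++;                      // exactly one key compare per remaining step
                if (aux[j] < aux[i]) a[k] = aux[j++];
                else a[k] = aux[i++];
            }
        }
    }

    private static void sort(int[] a, int[] aux, int lo, int hi) {
        if (hi <= lo) return;
        int mid = lo + (hi - lo) / 2;
        sort(a, aux, lo, mid);
        sort(a, aux, mid + 1, hi);
        merge(a, aux, lo, mid, hi);
    }

    /** Sorts a in place and returns the number of key compares used. */
    public static long sortAndCount(int[] a) {
        compares = 0;
        sort(a, new int[a.length], 0, a.length - 1);
        return compares;
    }

    public static void main(String[] args) {
        int n = 1 << 16;
        int[] a = new Random(42).ints(n, 0, 1_000_000).toArray();
        long used = sortAndCount(a);
        long bound = (long) (n * (Math.log(n) / Math.log(2)));
        System.out.println("compares = " + used + ", N lg N = " + bound);
    }
}
```

Each merge of a subarray of length n uses at most n - 1 compares, which is where the N lg N total comes from when N is a power of two.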

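    Bottom-up mergesort

    Mergesort can also be run in the opposite direction, without recursion: first merge subarrays of size 1 into sorted pairs, then merge the pairs into sorted groups of 4, and so on, until the whole array is sorted.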

        public static void sort(Comparable[] a) {
            int N = a.length;
            Comparable[] aux = new Comparable[N];
            for (int sz = 1; sz < N; sz = sz + sz)            // sz: size of the subarrays being merged
                for (int lo = 0; lo < N - sz; lo += sz + sz)  // lo: start of each pair of subarrays
                    merge(a, aux, lo, lo + sz - 1, Math.min(lo + sz + sz - 1, N - 1));
        }

    Apart from needing extra storage, this is complete, industrial-strength code. Its running time is O(N log N).
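To see which subarrays the two nested loops actually merge, the loop structure can be run on its own. This is only an illustrative sketch; `BottomUpTrace` and `mergeCalls` are hypothetical names, and no sorting is performed.

```java
import java.util.ArrayList;
import java.util.List;

public class BottomUpTrace {
    /** Returns the (lo, mid, hi) triple of every merge that bottom-up mergesort performs on n items. */
    public static List<int[]> mergeCalls(int n) {
        List<int[]> calls = new ArrayList<>();
        for (int sz = 1; sz < n; sz += sz)
            for (int lo = 0; lo < n - sz; lo += sz + sz)
                calls.add(new int[] { lo, lo + sz - 1, Math.min(lo + sz + sz - 1, n - 1) });
        return calls;
    }

    public static void main(String[] args) {
        for (int[] c : mergeCalls(5))
            System.out.println("merge(a, " + c[0] + ", " + c[1] + ", " + c[2] + ")");
    }
}
```

For 5 items this lists four merges: (0, 0, 1), (2, 2, 3), (0, 1, 3) and (0, 3, 4). The last item is left alone until the final pass, which is why the `Math.min` is needed when N is not a power of two.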

    Sort orders

    We can implement the Comparator interface to give a sorting algorithm different sort orders. Taking insertion sort as an example:

        public static void sort(Object[] a, Comparator comparator) {
            int N = a.length;
            for (int i = 0; i < N; i++)
                for (int j = i; j > 0 && less(comparator, a[j], a[j - 1]); j--)
                    exch(a, j, j - 1);
        }

        private static boolean less(Comparator c, Object v, Object w) {
            return c.compare(v, w) < 0;
        }

        private static void exch(Object[] a, int i, int j) {
            Object swap = a[i];
            a[i] = a[j];
            a[j] = swap;
        }

        public class Student {
            public static final Comparator<Student> BY_NAME = new ByName();
            ...
            private static class ByName implements Comparator<Student> {
                public int compare(Student v, Student w) {
                    return v.name.compareTo(w.name);
                }
            }
        }

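    The sort can then be called like this:

        Arrays.sort(a, Student.BY_NAME);

    The advantage of using the Comparator interface instead of Comparable is that it supports several different sort orders for the same type of item.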
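As a small illustration of "several sort orders for one type", the sketch below sorts the same records by name and by section. The `Student` type here is a hypothetical stand-in with two fields, not the course's class, and the comparators are built with `Comparator.comparing` rather than hand-written classes.

```java
import java.util.Arrays;
import java.util.Comparator;

public class StudentOrders {
    // Hypothetical record type for illustration only.
    static final class Student {
        final String name;
        final int section;
        Student(String name, int section) { this.name = name; this.section = section; }
    }

    // Two independent sort orders for the same type.
    static final Comparator<Student> BY_NAME = Comparator.comparing((Student s) -> s.name);
    static final Comparator<Student> BY_SECTION = Comparator.comparingInt((Student s) -> s.section);

    /** Sorts a copy of the array with the given order and returns the names in that order. */
    static String[] namesSortedBy(Student[] students, Comparator<Student> order) {
        Student[] copy = students.clone();
        Arrays.sort(copy, order);
        return Arrays.stream(copy).map(s -> s.name).toArray(String[]::new);
    }

    public static void main(String[] args) {
        Student[] students = {
            new Student("Chen", 2), new Student("Andrews", 3), new Student("Battle", 1)
        };
        System.out.println(Arrays.toString(namesSortedBy(students, BY_NAME)));
        System.out.println(Arrays.toString(namesSortedBy(students, BY_SECTION)));
    }
}
```

With Comparable alone, a type has a single natural order; each additional order here is just another Comparator object passed to the same sort.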

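    Quicksort

    Quicksort is widely used in system sorts and other applications. It is also recursive, but unlike mergesort it does its work (partitioning) before the recursive calls, whereas mergesort makes the recursive calls first and merges afterwards.

    The idea is to shuffle the array first and then partition: move the first item into its final position, so that no item to its left is larger and no item to its right is smaller, and then sort the two sides recursively.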

        // Partition a[lo..hi] around a[lo] and return the final index of the partitioning item.
        private static int partition(Comparable[] a, int lo, int hi) {
            int i = lo, j = hi + 1;
            while (true) {
                while (less(a[++i], a[lo]))   // scan right for an item that is not smaller
                    if (i == hi) break;
                while (less(a[lo], a[--j]))   // scan left for an item that is not larger
                    if (j == lo) break;
                if (i >= j) break;            // pointers have crossed
                exch(a, i, j);
            }
            exch(a, lo, j);                   // put the partitioning item into place
            return j;
        }

        public static void sort(Comparable[] a) {
            StdRandom.shuffle(a);
            sort(a, 0, a.length - 1);
        }

        private static void sort(Comparable[] a, int lo, int hi) {
            if (hi <= lo) return;
            int j = partition(a, lo, hi);
            sort(a, lo, j - 1);
            sort(a, j + 1, hi);
        }

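    In practice quicksort turns out to be faster than mergesort. It uses about N lg N compares in the best case and about 1/2 N^2 in the worst case. For N items the average is about 1.39 N lg N compares, which is more than mergesort uses, but quicksort moves data much less, so it is faster in practice. The random shuffle is there to guard against the worst case.

    In terms of space, quicksort sorts in place: partitioning needs only constant extra space, and with the shuffle the recursion depth is logarithmic with high probability.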
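The effect of the shuffle can be seen by counting compares on an already sorted array. This is a rough sketch under a few assumptions: it works on `int[]` instead of `Comparable[]`, uses a hand-written Fisher-Yates shuffle in place of `StdRandom.shuffle`, and the names `QuickCompareCount` and `sortAndCount` are invented for the example.

```java
import java.util.Random;

public class QuickCompareCount {
    private static long compares;  // key compares performed by the last sort

    private static boolean less(int v, int w) { compares++; return v < w; }

    private static void exch(int[] a, int i, int j) { int t = a[i]; a[i] = a[j]; a[j] = t; }

    private static int partition(int[] a, int lo, int hi) {
        int i = lo, j = hi + 1;
        while (true) {
            while (less(a[++i], a[lo])) if (i == hi) break;
            while (less(a[lo], a[--j])) if (j == lo) break;
            if (i >= j) break;
            exch(a, i, j);
        }
        exch(a, lo, j);
        return j;
    }

    private static void sort(int[] a, int lo, int hi) {
        if (hi <= lo) return;
        int j = partition(a, lo, hi);
        sort(a, lo, j - 1);
        sort(a, j + 1, hi);
    }

    // Fisher-Yates shuffle, standing in for StdRandom.shuffle.
    private static void shuffle(int[] a, Random rnd) {
        for (int i = a.length - 1; i > 0; i--) exch(a, i, rnd.nextInt(i + 1));
    }

    /** Sorts a in place, optionally shuffling first, and returns the number of key compares. */
    public static long sortAndCount(int[] a, boolean shuffleFirst) {
        if (shuffleFirst) shuffle(a, new Random(7));
        compares = 0;
        sort(a, 0, a.length - 1);
        return compares;
    }

    public static void main(String[] args) {
        int n = 1000;
        int[] sorted = new int[n];
        for (int i = 0; i < n; i++) sorted[i] = i;
        System.out.println("no shuffle:   " + sortAndCount(sorted.clone(), false));
        System.out.println("with shuffle: " + sortAndCount(sorted.clone(), true));
    }
}
```

On sorted input without the shuffle, every partition peels off a single item, so 1,000 keys cost roughly half a million compares and the recursion is 1,000 levels deep. With the shuffle the count should land near 1.39 N lg N.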

    Application: selection

    One application of quicksort's partitioning is finding the k-th smallest item in an array.

        public static Comparable select(Comparable[] a, int k) {
            StdRandom.shuffle(a);
            int lo = 0, hi = a.length - 1;
            while (hi > lo) {
                int j = partition(a, lo, hi);
                if      (j < k) lo = j + 1;   // the item of rank k is in the right part
                else if (j > k) hi = j - 1;   // the item of rank k is in the left part
                else            return a[k];
            }
            return a[k];
        }

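    Quick-select takes linear time on average. There is also a selection algorithm that is linear in the worst case (Blum, Floyd, Pratt, Rivest, and Tarjan, 1973), but its constants are too high for it to be useful in practice.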
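For a concrete feel of the rank convention, here is a self-contained sketch of quick-select on an `int[]`. It is an illustration under assumptions, not the course code: ranks are 0-based (k = 0 is the minimum), a seeded Fisher-Yates shuffle replaces `StdRandom.shuffle`, and `QuickSelectDemo` is a made-up class name.

```java
import java.util.Random;

public class QuickSelectDemo {
    private static void exch(int[] a, int i, int j) { int t = a[i]; a[i] = a[j]; a[j] = t; }

    private static int partition(int[] a, int lo, int hi) {
        int i = lo, j = hi + 1;
        while (true) {
            while (a[++i] < a[lo]) if (i == hi) break;
            while (a[lo] < a[--j]) if (j == lo) break;
            if (i >= j) break;
            exch(a, i, j);
        }
        exch(a, lo, j);
        return j;
    }

    /** Returns the item of rank k (0-based, so k = 0 is the minimum). Rearranges a. */
    public static int select(int[] a, int k) {
        if (k < 0 || k >= a.length) throw new IllegalArgumentException("rank out of range: " + k);
        Random rnd = new Random(11);
        for (int i = a.length - 1; i > 0; i--) exch(a, i, rnd.nextInt(i + 1));  // shuffle
        int lo = 0, hi = a.length - 1;
        while (hi > lo) {
            int j = partition(a, lo, hi);
            if (j < k) lo = j + 1;
            else if (j > k) hi = j - 1;
            else return a[k];
        }
        return a[k];
    }

    public static void main(String[] args) {
        int[] a = { 7, 2, 9, 4, 1, 8 };
        System.out.println("minimum: " + select(a.clone(), 0));
        System.out.println("rank 2:  " + select(a.clone(), 2));
    }
}
```

Each partition discards one side of the array, so on average the work is N + N/2 + N/4 + ..., which is where the linear bound comes from.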

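    Duplicate keys

    Often the purpose of sorting is to bring items with equal keys together, and sorting methods differ in how efficiently they handle this kind of input.

    Mergesort uses between 1/2 N lg N and N lg N compares.

    Quicksort goes quadratic unless partitioning stops on keys equal to the partitioning item, so a better implementation is needed.

    A good choice is Dijkstra's 3-way partitioning, which splits the array into three parts: keys less than, equal to, and greater than the partitioning item. It grew out of Dijkstra's Dutch national flag problem, which asks to sort an array whose keys take only three distinct values.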

        private static void sort(Comparable[] a, int lo, int hi) {
            if (hi <= lo) return;
            int lt = lo, gt = hi;
            Comparable v = a[lo];
            int i = lo;
            // Invariant: a[lo..lt-1] < v, a[lt..i-1] == v, a[gt+1..hi] > v
            while (i <= gt) {
                int cmp = a[i].compareTo(v);
                if      (cmp < 0) exch(a, lt++, i++);
                else if (cmp > 0) exch(a, i, gt--);
                else              i++;
            }
            sort(a, lo, lt - 1);
            sort(a, gt + 1, hi);
        }

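    For arrays with many duplicate keys (for example, when the number of distinct keys is constant), it reduces the running time from linearithmic to linear.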
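A single 3-way partitioning pass is easiest to follow on the Dutch national flag case itself. The sketch below assumes an `int[]` and partitions the whole array around its first element; `ThreeWayPartitionDemo` and `partition3` are hypothetical names used only here.

```java
import java.util.Arrays;

public class ThreeWayPartitionDemo {
    private static void exch(int[] a, int i, int j) { int t = a[i]; a[i] = a[j]; a[j] = t; }

    /**
     * Partitions a around v = a[0] in one pass.
     * Returns {lt, gt} such that a[0..lt-1] < v, a[lt..gt] == v, a[gt+1..] > v.
     */
    public static int[] partition3(int[] a) {
        if (a.length == 0) return new int[] { 0, -1 };
        int lt = 0, gt = a.length - 1, i = 0;
        int v = a[0];
        while (i <= gt) {
            if (a[i] < v) exch(a, lt++, i++);     // smaller key: move into the left part
            else if (a[i] > v) exch(a, i, gt--);  // larger key: move into the right part, recheck a[i]
            else i++;                             // equal key: leave it in the middle
        }
        return new int[] { lt, gt };
    }

    public static void main(String[] args) {
        int[] flag = { 1, 0, 2, 1, 0, 2, 1 };
        int[] bounds = partition3(flag);
        System.out.println(Arrays.toString(flag) + " lt=" + bounds[0] + " gt=" + bounds[1]);
    }
}
```

With only three key values and the middle value as the partitioning item, one pass already leaves the array sorted as 0 0 1 1 1 2 2, and the recursive calls in the full sort only touch the parts outside `lt..gt`.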

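    System sorts

    Java has a built-in sort, Arrays.sort(), which is implemented with two different algorithms. As described in the lectures, it uses quicksort for primitive types and mergesort for objects.

    The reason is that for primitive types quicksort uses less memory and is faster, while for objects mergesort guarantees N log N performance and is stable.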
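Stability is what makes "sort by one key, then by another" work. The sketch below relies on the documented guarantee that `Arrays.sort` on object arrays is stable; the `Entry` type and the class name `StableSortDemo` are assumptions made for the example.

```java
import java.util.Arrays;
import java.util.Comparator;

public class StableSortDemo {
    // Hypothetical record for illustration: a name and a section number.
    static final class Entry {
        final String name;
        final int section;
        Entry(String name, int section) { this.name = name; this.section = section; }
        @Override public String toString() { return section + ":" + name; }
    }

    /** Sorts a copy by name, then by section, and returns "section:name" strings in final order. */
    static String[] sortByNameThenSection(Entry[] entries) {
        Entry[] copy = entries.clone();
        Arrays.sort(copy, Comparator.comparing((Entry e) -> e.name));
        // The second sort is stable, so entries in the same section keep their name order.
        Arrays.sort(copy, Comparator.comparingInt((Entry e) -> e.section));
        return Arrays.stream(copy).map(Entry::toString).toArray(String[]::new);
    }

    public static void main(String[] args) {
        Entry[] entries = {
            new Entry("Fox", 2), new Entry("Chen", 1), new Entry("Battle", 2),
            new Entry("Andrews", 1), new Entry("Gazsi", 2)
        };
        System.out.println(Arrays.toString(sortByNameThenSection(entries)));
    }
}
```

An unstable sort would be free to scramble the names within each section during the second pass, which is one reason a stable algorithm is preferred for objects.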

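    At the end of the video, the professor stresses that different applications raise many different concerns, such as parallelism and stability, so nearly every important system sort has its own tuned algorithm, and many algorithms still have room for improvement.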

    Finally, here is a summary of the sorting methods covered above:

    [Figure: summary table of the sorting methods]

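    Programming assignment: Pattern Recognition

    Given n distinct points, find every line segment that connects four or more of them.

    First complete the Point class. This part is mainly practice in defining sort orders with Comparable and Comparator; the exact ordering rules are described in detail in the assignment specification.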

        import java.util.Comparator;
        import edu.princeton.cs.algs4.StdDraw;

        public class Point implements Comparable<Point> {

            private final int x;     // x-coordinate of this point
            private final int y;     // y-coordinate of this point

            /**
             * Initializes a new point.
             *
             * @param  x the <em>x</em>-coordinate of the point
             * @param  y the <em>y</em>-coordinate of the point
             */
            public Point(int x, int y) {
                /* DO NOT MODIFY */
                this.x = x;
                this.y = y;
            }

            /**
             * Draws this point to standard draw.
             */
            public void draw() {
                /* DO NOT MODIFY */
                StdDraw.point(x, y);
            }

            /**
             * Draws the line segment between this point and the specified point
             * to standard draw.
             *
             * @param that the other point
             */
            public void drawTo(Point that) {
                /* DO NOT MODIFY */
                StdDraw.line(this.x, this.y, that.x, that.y);
            }

            /**
             * Returns the slope between this point and the specified point.
             * Formally, if the two points are (x0, y0) and (x1, y1), then the slope
             * is (y1 - y0) / (x1 - x0). For completeness, the slope is defined to be
             * +0.0 if the line segment connecting the two points is horizontal;
             * Double.POSITIVE_INFINITY if the line segment is vertical;
             * and Double.NEGATIVE_INFINITY if (x0, y0) and (x1, y1) are equal.
             *
             * @param  that the other point
             * @return the slope between this point and the specified point
             */
            public double slopeTo(Point that) {
                if (that == null) throw new NullPointerException();
                if (this.x == that.x && this.y == that.y) return Double.NEGATIVE_INFINITY;
                else if (this.x == that.x) return Double.POSITIVE_INFINITY;
                else if (this.y == that.y) return +0.0;
                else return (this.y - that.y) * 1.0 / (this.x - that.x);
            }

            /**
             * Compares two points by y-coordinate, breaking ties by x-coordinate.
             * Formally, the invoking point (x0, y0) is less than the argument point
             * (x1, y1) if and only if either y0 < y1 or if y0 = y1 and x0 < x1.
             *
             * @param  that the other point
             * @return the value <tt>0</tt> if this point is equal to the argument
             *         point (x0 = x1 and y0 = y1);
             *         a negative integer if this point is less than the argument
             *         point; and a positive integer if this point is greater than the
             *         argument point
             */
            public int compareTo(Point that) {
                if (that == null) throw new NullPointerException();
                if (this.x == that.x && this.y == that.y) return 0;
                if (this.y < that.y || (this.y == that.y && this.x < that.x)) return -1;
                return 1;
            }

            /**
             * Compares two points by the slope they make with this point.
             * The slope is defined as in the slopeTo() method.
             *
             * @return the Comparator that defines this ordering on points
             */
            public Comparator<Point> slopeOrder() {
                return new SlopeCompare();
            }

            private class SlopeCompare implements Comparator<Point> {
                @Override
                public int compare(Point o1, Point o2) {
                    if (o1 == null || o2 == null) throw new NullPointerException();
                    // Double.compare orders the infinities correctly, which a plain
                    // subtraction of the two slopes would not (it can produce NaN).
                    return Double.compare(slopeTo(o1), slopeTo(o2));
                }
            }

            /**
             * Returns a string representation of this point.
             * This method is provided for debugging;
             * your program should not rely on the format of the string representation.
             *
             * @return a string representation of this point
             */
            public String toString() {
                /* DO NOT MODIFY */
                return "(" + x + ", " + y + ")";
            }

            /**
             * Unit tests the Point data type.
             */
            public static void main(String[] args) {
                Point p1 = new Point(0, 10);
                Point p2 = new Point(10, 0);
                System.out.println(p1.slopeTo(p2) == p2.slopeTo(p1));
                Point p3 = new Point(0, 20);
                System.out.println(p3.slopeOrder().compare(p1, p2));
            }
        }

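    Next, find the line segments formed by the given points. In this version each segment contains exactly four points and is represented by its two endpoints. This is the brute-force method, which takes time proportional to n^4.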

        import java.util.ArrayList;

        public class BruteCollinearPoints {

            /** Records the line segments found. */
            private ArrayList<LineSegment> list;

            /**
             * Finds all line segments containing 4 points.
             * @param points the input points
             */
            public BruteCollinearPoints(Point[] points) {
                // Validate the input: no null array, no null point, no repeated point.
                if (points == null) throw new IllegalArgumentException();
                for (Point p : points) {
                    if (p == null) throw new IllegalArgumentException();
                }
                for (int i = 0; i < points.length - 1; i++) {
                    for (int j = i + 1; j < points.length; j++) {
                        if (points[i].compareTo(points[j]) == 0)
                            throw new IllegalArgumentException();
                    }
                }

                list = new ArrayList<>();
                int N = points.length;
                // Examine every 4-tuple; it is collinear if all slopes from points[i] agree.
                for (int i = 0; i < N; i++) {
                    for (int j = i + 1; j < N; j++) {
                        for (int k = j + 1; k < N; k++) {
                            for (int t = k + 1; t < N; t++) {
                                if (points[i].slopeTo(points[j]) == points[i].slopeTo(points[k])
                                        && points[i].slopeTo(points[k]) == points[i].slopeTo(points[t]))
                                    addLineSegment(points, i, j, k, t);
                            }
                        }
                    }
                }
            }

            /**
             * Adds the segment spanned by the four given points, from the smallest to the largest.
             */
            private void addLineSegment(Point[] points, int i, int j, int k, int t) {
                Point[] ps = new Point[]{points[i], points[j], points[k], points[t]};
                Point min = ps[0], max = ps[0];
                for (int index = 1; index < ps.length; index++) {
                    if (min.compareTo(ps[index]) > 0) min = ps[index];
                    if (max.compareTo(ps[index]) < 0) max = ps[index];
                }
                list.add(new LineSegment(min, max));
            }

            /**
             * @return the number of line segments
             */
            public int numberOfSegments() {
                return list.size();
            }

            /**
             * @return the line segments
             */
            public LineSegment[] segments() {
                return list.toArray(new LineSegment[0]);
            }

            public static void main(String[] args) {
            }
        }

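    Then implement the faster algorithm, which uses the slope order defined above: sort the points, take the smallest remaining point as the origin, sort the other points by the slope they make with it, and record a segment whenever four or more points (including the origin) share the same slope.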

        import java.util.ArrayList;
        import java.util.Arrays;
        import edu.princeton.cs.algs4.In;
        import edu.princeton.cs.algs4.StdDraw;
        import edu.princeton.cs.algs4.StdOut;

        public class FastCollinearPoints {

            /** Records the line segments found. */
            private ArrayList<LineSegment> list;

            /**
             * Finds all line segments containing 4 or more points.
             * @param points the input points
             */
            public FastCollinearPoints(Point[] points) {
                // Validate the input: no null array, no null point, no repeated point.
                if (points == null) throw new IllegalArgumentException();
                for (Point p : points) {
                    if (p == null) throw new IllegalArgumentException();
                }
                for (int i = 0; i < points.length - 1; i++) {
                    for (int j = i + 1; j < points.length; j++) {
                        if (points[i].compareTo(points[j]) == 0)
                            throw new IllegalArgumentException();
                    }
                }

                list = new ArrayList<>();
                int N = points.length;
                // Note: this sorts the caller's array in place.
                for (int i = 0; i < N - 1; i++) {
                    /** get the smallest remaining point */
                    Arrays.sort(points);
                    Point min = points[i];
                    /** sort the remaining points by the slope they make with min */
                    Arrays.sort(points, i, N, points[i].slopeOrder());
                    Point max = null;
                    int count = 0;   // number of adjacent pairs with equal slope in the current run
                    for (int j = i + 1; j < N - 1; j++) {
                        if (min.slopeTo(points[j]) == min.slopeTo(points[j + 1])) {
                            count++;
                            max = points[j + 1];
                        }
                        else if (count != 2) {
                            count = 0;
                        }
                        // Known limitation: the counter is reset as soon as four collinear points
                        // are found, so a run of five or more collinear points is reported as a
                        // shorter segment, and part of the same line is reported again from a
                        // later origin.
                        if (count >= 2) {
                            count = 0;
                            list.add(new LineSegment(min, max));
                        }
                    }
                }
            }

            /**
             * @return the number of line segments
             */
            public int numberOfSegments() {
                return list.size();
            }

            /**
             * @return the line segments
             */
            public LineSegment[] segments() {
                return list.toArray(new LineSegment[0]);
            }

            public static void main(String[] args) {
                // Read the n points from the input file.
                In in = new In(args[0]);
                int n = in.readInt();
                Point[] points = new Point[n];
                for (int i = 0; i < n; i++) {
                    int x = in.readInt();
                    int y = in.readInt();
                    points[i] = new Point(x, y);
                }

                // Draw the points.
                StdDraw.enableDoubleBuffering();
                StdDraw.setXscale(0, 32768);
                StdDraw.setYscale(0, 32768);
                for (Point p : points) {
                    p.draw();
                }
                StdDraw.show();

                // Print and draw the line segments.
                FastCollinearPoints collinear = new FastCollinearPoints(points);
                for (LineSegment segment : collinear.segments()) {
                    StdOut.println(segment);
                    segment.draw();
                }
                StdDraw.show();
            }
        }

    This submission currently scores only 88 points; it probably still falls short on large inputs.
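One common way to report each line exactly once, as a maximal segment, is to scan whole runs of equal slope and accept a run only when the origin is the smallest point on it. The following is a sketch of that idea under stated assumptions, not the author's submission and not an official solution: `P` is a minimal hypothetical stand-in for the assignment's Point class, segments are returned as strings instead of `LineSegment` objects, and the input points are assumed to be distinct.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class MaximalSegmentsSketch {
    // Minimal stand-in for the assignment's Point class (hypothetical, for illustration only).
    static final class P implements Comparable<P> {
        final int x, y;
        P(int x, int y) { this.x = x; this.y = y; }

        double slopeTo(P that) {
            if (x == that.x && y == that.y) return Double.NEGATIVE_INFINITY;
            if (x == that.x) return Double.POSITIVE_INFINITY;
            if (y == that.y) return 0.0;
            return (double) (that.y - y) / (that.x - x);
        }

        // Order by y-coordinate, breaking ties by x-coordinate.
        public int compareTo(P that) {
            return y != that.y ? Integer.compare(y, that.y) : Integer.compare(x, that.x);
        }

        @Override public String toString() { return "(" + x + ", " + y + ")"; }
    }

    /** Returns each maximal segment of 4 or more collinear points once, as "min -> max". */
    static List<String> segments(P[] input) {
        P[] natural = input.clone();               // leave the caller's array untouched
        Arrays.sort(natural);
        List<String> found = new ArrayList<>();
        for (P origin : natural) {
            P[] bySlope = natural.clone();         // starts in natural order
            Comparator<P> order = Comparator.comparingDouble(origin::slopeTo);
            Arrays.sort(bySlope, order);           // stable: equal slopes stay in natural order
            int start = 1;                         // bySlope[0] is the origin (slope -infinity)
            while (start < bySlope.length) {
                double slope = origin.slopeTo(bySlope[start]);
                int end = start;
                while (end + 1 < bySlope.length && origin.slopeTo(bySlope[end + 1]) == slope) end++;
                // A run of 3+ points plus the origin is a line of 4+ points.
                // Report it only when the origin is the smallest point on that line.
                if (end - start + 1 >= 3 && origin.compareTo(bySlope[start]) < 0)
                    found.add(origin + " -> " + bySlope[end]);
                start = end + 1;
            }
        }
        return found;
    }

    public static void main(String[] args) {
        P[] pts = {
            new P(0, 0), new P(1, 1), new P(2, 2), new P(3, 3), new P(4, 4),
            new P(1, 0), new P(2, 0), new P(3, 0)
        };
        for (String s : segments(pts)) System.out.println(s);
    }
}
```

Because the array is in natural order before the stable slope sort, the first point of each run is the smallest point on that line other than the origin, and the last is the largest. That single comparison is enough to suppress both sub-segments and duplicates, and the total cost stays at n sorts of n points.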

    By the way, I have to say that the programming assignments (PAs) in this course are really fun.
