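// A symbol table is a data structure that stores key-value pairs: it associates a key (Key)
// with a value (Value) so that the value can later be looked up by its key.
// Symbol tables are usually implemented with binary search trees or hash tables, like
// java.util.TreeMap and java.util.HashMap in the Java library.
// In a binary search tree every node holds a Comparable key and the value associated with it;
// every node except the root has exactly one parent, and each node has 0, 1 or 2 children.
// Each node's key is greater than every key in its left subtree and smaller than every key in
// its right subtree, and all keys are distinct. (The illustrative diagram that originally
// followed this note is not included in this file.)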
/**
 * Ordered symbol table implemented as an (unbalanced) binary search tree.
 *
 * <p>Each node stores a {@link Comparable} key, its associated value, links to its left and
 * right subtrees, and the number of nodes in the subtree rooted at it. Every node's key is
 * greater than all keys in its left subtree and smaller than all keys in its right subtree, and
 * keys are distinct. Null keys are rejected; null values are never stored (putting a null value
 * deletes the key), so {@code get} returning {@code null} means "absent".
 *
 * <p>Search and insertion cost is proportional to the height of the tree: about lg N when the
 * tree is perfectly balanced, but up to N in the worst case (e.g. keys inserted in sorted order).
 *
 * @param <Key>   key type
 * @param <Value> value type
 */
public class BST<Key extends Comparable<Key>, Value> {

    /** Tree node; {@code size} counts the nodes in the subtree rooted here, itself included. */
    private class Node {
        private Key key;           // sorted by key
        private Value val;         // associated data
        private Node left, right;  // left and right subtrees
        private int size;          // number of nodes in subtree

        public Node(Key key, Value val, int size) {
            this.key = key;
            this.val = val;
            this.size = size;
        }
    }

    private Node root;

    /**
     * Initializes an empty symbol table.
     */
    public BST() {
    }

    /**
     * Returns true if this symbol table is empty.
     *
     * @return {@code true} if this symbol table is empty; {@code false} otherwise
     */
    public boolean isEmpty() {
        return size() == 0;
    }

    /**
     * Returns the number of key-value pairs in this symbol table.
     *
     * @return the number of key-value pairs in this symbol table
     */
    public int size() {
        return size(root);
    }

    // Number of nodes in the subtree rooted at x (0 for an empty subtree).
    private int size(Node x) {
        return (x == null) ? 0 : x.size;
    }

    /**
     * Does this symbol table contain the given key?
     *
     * @param key the key
     * @return {@code true} if this symbol table contains {@code key} and
     *         {@code false} otherwise
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public boolean contains(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to contains() is null");
        return get(key) != null;
    }

    /**
     * Iterative version of {@link #get(Comparable)} (Ex 3.2.13).
     *
     * @param key the key
     * @return the associated value, or {@code null} if the key is not present
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    Value get2(Key key) {
        if (key == null) throw new IllegalArgumentException("calls get() with a null key");
        Node x = root;
        while (x != null) {
            int cmp = key.compareTo(x.key);
            if (cmp < 0) x = x.left;
            else if (cmp > 0) x = x.right;
            else return x.val;
        }
        return null;
    }

    /**
     * Returns the value associated with the given key.
     *
     * @param key the key
     * @return the value associated with the given key if the key is in the symbol table
     *         and {@code null} if the key is not in the symbol table
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Value get(Key key) {
        // Validate once here rather than on every recursive step.
        if (key == null) throw new IllegalArgumentException("calls get() with a null key");
        return get(root, key);
    }

    // Recursive search in the subtree rooted at x. An empty subtree means a miss; otherwise go
    // left if key < x.key, right if key > x.key, and hit when they are equal.
    private Value get(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return get(x.left, key);
        else if (cmp > 0) return get(x.right, key);
        else return x.val;
    }

    /**
     * Iterative version of {@link #put(Comparable, Object)} (Ex 3.2.13) with the same contract.
     * Subtree sizes are kept up to date so that size/select/rank stay correct.
     *
     * @param key the key
     * @param val the value; {@code null} deletes {@code key}
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void put2(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("calls put() with a null key");
        if (val == null) {
            delete(key);
            return;
        }
        if (root == null) {
            root = new Node(key, val, 1);
            return;
        }
        // First pass: walk down to the key; if it already exists just replace its value.
        Node parent = null;
        Node x = root;
        int cmp = 0;
        while (x != null) {
            cmp = key.compareTo(x.key);
            if (cmp == 0) {
                x.val = val;
                return;
            }
            parent = x;
            x = (cmp < 0) ? x.left : x.right;
        }
        Node z = new Node(key, val, 1);
        if (cmp < 0) parent.left = z;
        else parent.right = z;
        // Second pass: every node on the path from the root down to parent gained one node.
        for (Node y = root; y != z; y = (key.compareTo(y.key) < 0) ? y.left : y.right) {
            y.size++;
        }
    }

    /**
     * Inserts the specified key-value pair into the symbol table, overwriting the old
     * value with the new value if the symbol table already contains the specified key.
     * Deletes the specified key (and its associated value) from this symbol table
     * if the specified value is {@code null}.
     *
     * @param key the key
     * @param val the value
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void put(Key key, Value val) {
        if (key == null) throw new IllegalArgumentException("calls put() with a null key");
        // Some implementations do not delete on a null value; this one does.
        if (val == null) {
            delete(key);
            return;
        }
        // Reassign root because inserting into an empty tree creates a new root.
        root = put(root, key, val);
    }

    // Inserts into the subtree rooted at x and returns the (possibly new) subtree root.
    // An empty subtree becomes a new single node; otherwise recurse left or right and relink,
    // or overwrite the value on a hit. Finally recompute x.size, which may have grown by one.
    private Node put(Node x, Key key, Value val) {
        if (x == null) return new Node(key, val, 1);
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = put(x.left, key, val);
        else if (cmp > 0) x.right = put(x.right, key, val);
        else x.val = val;
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    /**
     * Returns the smallest key in the symbol table.
     *
     * @return the smallest key in the symbol table
     * @throws NoSuchElementException if the symbol table is empty
     */
    public Key min() {
        if (isEmpty()) throw new NoSuchElementException("calls min() with empty symbol table");
        return min(root).key;
    }

    // Smaller keys live in the left subtree, so the minimum is the leftmost node.
    private Node min(Node x) {
        if (x.left == null) return x;
        else return min(x.left);
    }

    /**
     * Returns the largest key in the symbol table.
     *
     * @return the largest key in the symbol table
     * @throws NoSuchElementException if the symbol table is empty
     */
    public Key max() {
        if (isEmpty()) throw new NoSuchElementException("calls max() with empty symbol table");
        return max(root).key;
    }

    // Larger keys live in the right subtree, so the maximum is the rightmost node.
    private Node max(Node x) {
        if (x.right == null) return x;
        else return max(x.right);
    }

    /**
     * Removes the smallest key and associated value from the symbol table.
     *
     * @throws NoSuchElementException if the symbol table is empty
     */
    public void deleteMin() {
        if (isEmpty()) throw new NoSuchElementException("Symbol table underflow");
        root = deleteMin(root);
    }

    // Removes the minimum from the subtree rooted at x and returns the new subtree root.
    // If x has no left child, x itself is the minimum and is replaced by its right subtree;
    // otherwise delete from the left subtree, relink it, and recompute x.size.
    private Node deleteMin(Node x) {
        if (x.left == null) return x.right;
        x.left = deleteMin(x.left);
        x.size = 1 + size(x.left) + size(x.right);
        return x;
    }

    /**
     * Removes the largest key and associated value from the symbol table.
     *
     * @throws NoSuchElementException if the symbol table is empty
     */
    public void deleteMax() {
        if (isEmpty()) throw new NoSuchElementException("Symbol table underflow");
        root = deleteMax(root);
    }

    // Mirror image of deleteMin(Node).
    private Node deleteMax(Node x) {
        if (x.right == null) return x.left;
        x.right = deleteMax(x.right);
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Removes the specified key and its associated value from this symbol table
     * (if the key is in this symbol table).
     *
     * @param key the key
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public void delete(Key key) {
        if (key == null) throw new IllegalArgumentException("calls delete() with a null key");
        root = delete(root, key);
    }

    // Deletes key from the subtree rooted at x and returns the new subtree root.
    private Node delete(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) x.left = delete(x.left, key);
        else if (cmp > 0) x.right = delete(x.right, key);
        else {
            // Hit. With at most one child, that child (possibly null) replaces x.
            if (x.right == null) return x.left;
            if (x.left == null) return x.right;
            // Two children: replace x by its successor (the minimum of its right subtree),
            // removing that successor from the right subtree and keeping x's left subtree.
            Node t = x;
            x = min(t.right);
            x.right = deleteMin(t.right);
            x.left = t.left;
        }
        x.size = size(x.left) + size(x.right) + 1;
        return x;
    }

    /**
     * Returns the largest key in the symbol table less than or equal to {@code key}.
     *
     * @param key the key
     * @return the largest key in the symbol table less than or equal to {@code key}, or
     *         {@code null} if every key in the table is greater than {@code key}
     * @throws NoSuchElementException if the symbol table is empty
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Key floor(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        if (isEmpty()) throw new NoSuchElementException("calls floor() with empty symbol table");
        Node x = floor(root, key);
        return (x == null) ? null : x.key;
    }

    // Node with the largest key <= key in the subtree rooted at x, or null.
    // If key == x.key, x is the answer. If key < x.key, neither x nor its right subtree can
    // qualify, so search left. If key > x.key, the right subtree may still hold a closer key;
    // fall back to x only if it does not.
    private Node floor(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp < 0) return floor(x.left, key);
        Node t = floor(x.right, key);
        return (t != null) ? t : x;
    }

    /**
     * Alternative floor that carries the best candidate so far down the tree.
     * Unlike {@link #floor(Comparable)}, returns {@code null} (rather than throwing) on an
     * empty table.
     *
     * @param key the key
     * @return the largest key less than or equal to {@code key}, or {@code null} if none
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Key floor2(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to floor() is null");
        return floor2(root, key, null);
    }

    // best is the largest key <= key seen on the path so far.
    private Key floor2(Node x, Key key, Key best) {
        if (x == null) return best;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return floor2(x.left, key, best);
        else if (cmp > 0) return floor2(x.right, key, x.key);
        else return x.key;
    }

    /**
     * Returns the smallest key in the symbol table greater than or equal to {@code key}.
     *
     * @param key the key
     * @return the smallest key in the symbol table greater than or equal to {@code key}, or
     *         {@code null} if every key in the table is smaller than {@code key}
     * @throws NoSuchElementException if the symbol table is empty
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public Key ceiling(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to ceiling() is null");
        if (isEmpty()) throw new NoSuchElementException("calls ceiling() with empty symbol table");
        Node x = ceiling(root, key);
        return (x == null) ? null : x.key;
    }

    // Mirror image of floor(Node, Key): node with the smallest key >= key, or null.
    private Node ceiling(Node x, Key key) {
        if (x == null) return null;
        int cmp = key.compareTo(x.key);
        if (cmp == 0) return x;
        if (cmp < 0) {
            Node t = ceiling(x.left, key);
            return (t != null) ? t : x;
        }
        return ceiling(x.right, key);
    }

    /**
     * Returns the key in the symbol table whose rank is {@code k}, i.e. the (k+1)st smallest
     * key (exactly {@code k} keys are smaller than it).
     *
     * <p>The subtree sizes maintained by put/delete exist to support this and {@link #rank};
     * an implementation that does not need these two operations could drop them.
     *
     * @param k the order statistic
     * @return the key in the symbol table of rank {@code k}
     * @throws IllegalArgumentException unless {@code k} is between 0 and <em>n</em>-1
     */
    public Key select(int k) {
        if (k < 0 || k >= size()) {
            throw new IllegalArgumentException("argument to select() is invalid: " + k);
        }
        Node x = select(root, k);
        return x.key;
    }

    // Node of rank k in the subtree rooted at x: with t keys in the left subtree, go left if
    // t > k, go right looking for rank k-t-1 if t < k, and x itself if t == k.
    private Node select(Node x, int k) {
        if (x == null) return null;
        int t = size(x.left);
        if (t > k) return select(x.left, k);
        else if (t < k) return select(x.right, k - t - 1);
        else return x;
    }

    /**
     * Returns the number of keys in the symbol table strictly less than {@code key}.
     *
     * @param key the key
     * @return the number of keys in the symbol table strictly less than {@code key}
     * @throws IllegalArgumentException if {@code key} is {@code null}
     */
    public int rank(Key key) {
        if (key == null) throw new IllegalArgumentException("argument to rank() is null");
        return rank(key, root);
    }

    // Number of keys in the subtree rooted at x that are strictly less than key.
    private int rank(Key key, Node x) {
        if (x == null) return 0;
        int cmp = key.compareTo(x.key);
        if (cmp < 0) return rank(key, x.left);
        else if (cmp > 0) return 1 + size(x.left) + rank(key, x.right);
        else return size(x.left);
    }

    /**
     * Iterative in-order traversal (Ex 3.2.13): returns all keys in ascending order.
     *
     * @return all keys in the symbol table, in ascending order
     */
    public Iterable<Key> keys2() {
        java.util.Deque<Node> stack = new java.util.ArrayDeque<Node>();
        java.util.Queue<Key> queue = new java.util.ArrayDeque<Key>();
        Node x = root;
        while (x != null || !stack.isEmpty()) {
            if (x != null) {
                // Descend left, remembering the path so we can come back up.
                stack.push(x);
                x = x.left;
            } else {
                // Left subtree done: visit the node, then traverse its right subtree.
                x = stack.pop();
                queue.add(x.key);
                x = x.right;
            }
        }
        return queue;
    }

    /**
     * Returns all keys in the symbol table as an {@code Iterable}.
     * To iterate over all of the keys in the symbol table named {@code st},
     * use the foreach notation: {@code for (Key key : st.keys())}.
     *
     * @return all keys in the symbol table, in ascending order
     */
    public Iterable<Key> keys() {
        if (isEmpty()) return new java.util.ArrayDeque<Key>();
        return keys(min(), max());
    }

    /**
     * Returns all keys in the symbol table in the given range, as an {@code Iterable}.
     *
     * @param lo minimum endpoint
     * @param hi maximum endpoint
     * @return all keys in the symbol table between {@code lo} (inclusive) and {@code hi}
     *         (inclusive), in ascending order; empty if {@code lo > hi}
     * @throws IllegalArgumentException if either {@code lo} or {@code hi} is {@code null}
     */
    public Iterable<Key> keys(Key lo, Key hi) {
        if (lo == null) throw new IllegalArgumentException("first argument to keys() is null");
        if (hi == null) throw new IllegalArgumentException("second argument to keys() is null");

        java.util.Queue<Key> queue = new java.util.ArrayDeque<Key>();
        keys(root, queue, lo, hi);
        return queue;
    }

    // Pruned in-order traversal: appends the keys of the subtree rooted at x that lie in
    // [lo, hi] to queue, in ascending order.
    private void keys(Node x, java.util.Queue<Key> queue, Key lo, Key hi) {
        if (x == null) return;
        int cmplo = lo.compareTo(x.key);
        int cmphi = hi.compareTo(x.key);
        // Only if lo < x.key can the left subtree contain keys >= lo.
        if (cmplo < 0) keys(x.left, queue, lo, hi);
        if (cmplo <= 0 && cmphi >= 0) queue.add(x.key);
        // Only if x.key < hi can the right subtree contain keys <= hi.
        if (cmphi > 0) keys(x.right, queue, lo, hi);
    }

    /**
     * Returns the height of the BST (for debugging). Ex 3.2.6.
     *
     * @return the height of the BST (a 1-node tree has height 0, an empty tree -1)
     */
    public int height() {
        return height(root);
    }

    private int height(Node x) {
        if (x == null) return -1;
        return 1 + Math.max(height(x.left), height(x.right));
    }

    /**
     * Returns the keys in the BST in level order (for debugging).
     *
     * @return the keys in the BST in level order traversal
     */
    public Iterable<Key> levelOrder() {
        java.util.Queue<Key> keys = new java.util.ArrayDeque<Key>();
        java.util.Queue<Node> queue = new java.util.ArrayDeque<Node>();
        // ArrayDeque rejects nulls, so only non-null nodes are ever enqueued.
        if (root != null) queue.add(root);
        while (!queue.isEmpty()) {
            Node x = queue.remove();
            keys.add(x.key);
            if (x.left != null) queue.add(x.left);
            if (x.right != null) queue.add(x.right);
        }
        return keys;
    }

    /*************************************************************************
     * Check integrity of BST data structure.
     *************************************************************************/
    private boolean check() {
        if (!isBST()) System.out.println("Not in symmetric order");
        if (!isSizeConsistent()) System.out.println("Subtree counts not consistent");
        if (!isRankConsistent()) System.out.println("Ranks not consistent");
        return isBST() && isSizeConsistent() && isRankConsistent();
    }

    // Does this binary tree satisfy symmetric order?
    // Note: this test also ensures that the data structure is a binary tree since order is strict.
    private boolean isBST() {
        return isBST(root, null, null);
    }

    // Is the tree rooted at x a BST with all keys strictly between min and max?
    // (A null min or max means no constraint on that side.)
    // Credit: Bob Dondero's elegant solution.
    private boolean isBST(Node x, Key min, Key max) {
        if (x == null) return true;
        if (min != null && x.key.compareTo(min) <= 0) return false;
        if (max != null && x.key.compareTo(max) >= 0) return false;
        return isBST(x.left, min, x.key) && isBST(x.right, x.key, max);
    }

    // Are the size fields correct?
    private boolean isSizeConsistent() {
        return isSizeConsistent(root);
    }

    private boolean isSizeConsistent(Node x) {
        if (x == null) return true;
        if (x.size != size(x.left) + size(x.right) + 1) return false;
        return isSizeConsistent(x.left) && isSizeConsistent(x.right);
    }

    // Check that select and rank are mutually inverse.
    private boolean isRankConsistent() {
        for (int i = 0; i < size(); i++)
            if (i != rank(select(i))) return false;
        for (Key key : keys())
            if (key.compareTo(select(rank(key))) != 0) return false;
        return true;
    }
}