Lowest Common Ancestor (LCA) Template: Binary Lifting

class Lca {
public:
  // 2^10 = 1024
  // 2^14 > 10'005
  static constexpr int kMaxN = 50'005;
  static constexpr int kMaxP = 17;

  Lca (TreeNode* root) {
    node_to_index[nullptr] = 0;
    // index & depth starts from 1
    int index = 1;
    GetDepth(root, 1, nullptr, index);
    n = index;
    
    // node: [1, n]
    memset(alg, 0, sizeof(alg));
    for (int i = 1; i <= n; ++i) {
      alg[i][0] = parent[i];
    }
    
    // i's 2^j father = i's 2^(j-1) father's 2^(j-1) father
    // alg[i][j] = alg[alg[i][j-1]][j-1]
    for (int j = 1; j <= kMaxP; ++j) {
      for (int i = 1; i <= n; ++i) {
        alg[i][j] = alg[alg[i][j-1]][j-1];
      }
    }
  }
  
//   Lca (vector<int> const& parent_input) {
//     n = parent_input.size();
//     memcpy(parent, parent_input.data(), n * sizeof(int));
//     // node: [1, n]
//     memset(alg, 0, sizeof(alg));
//     for (int i = 0; i < n; ++i) {
//       alg[i + 1][0] = parent[i];
//     }
    
//     // i's 2^j father = i's 2^(j-1) father's 2^(j-1) father
//     // alg[i][j] = alg[alg[i][j-1]][j-1]
//     for (int j = 1; j <= kMaxP; ++j) {
//       for (int i = 1; i <= n; ++i) {
//         alg[i][j] = alg[alg[i][j-1]][j-1];
//       }
//     }
//   }
  
//   int GetKParent(int i, int k) {
//     for (int j = kMaxP; j >= 0 && k >= 1; --j) {
//       if (k == (1 << j)) {
//         return alg[i][j];
//       }
//       if (k > (1 << j)) {
//         k -= (1 << j);
//         i = alg[i][j];
//       }
//     }
//     return 0;
//   }
  
  int GetLca(int root0, int root1) {
    if (depth[root0] < depth[root1]) {
      swap(root0, root1);
    }
    // keep jump untill root0 and root1 is at the same level
    for (int j = kMaxP; j >= 0; --j) {
      if (depth[alg[root0][j]] >= depth[root1]) {
        root0 = alg[root0][j];
      }
    }
    if (root0 == root1) {
      return root0;
    }
    // keep jump untill root0 and root1 has the same father
    for (int j = kMaxP; j >= 0; --j) {
      if (alg[root0][j] != 0 && alg[root0][j] != alg[root1][j]) {
        root0 = alg[root0][j];
        root1 = alg[root1][j];
      }
    }
    return alg[root0][0];
  }
  
  TreeNode* GetLca(TreeNode* root0, TreeNode* root1) {
    return index_to_node[GetLca(node_to_index[root0], node_to_index[root1])];
  }
  
  // Returns the distance between a node and its one parent
  int GetDistance(int root, int parent) {
    int res = 0;
    for (int j = kMaxP; j >= 0; --j) {
      if (depth[alg[root][j]] >= depth[parent]) {
        root = alg[root][j];
        res += 1 << j;
      }
    }
    return res;
  }
  
  int GetDistance(TreeNode* root, TreeNode* parent) {
    return GetDistance(node_to_index[root], node_to_index[parent]);
  }
  
private:
  int n;
  
  static unordered_map<TreeNode*, int> node_to_index;
  static unordered_map<int, TreeNode*> index_to_node;
  
  static int depth[kMaxN];
  static int parent[kMaxN];
  
  // i's 2^j father
  static int alg[kMaxN][kMaxP + 1];
  
  void GetDepth(TreeNode* root, int curr_depth, TreeNode* father, int &index) {
    if (root == nullptr) {
      return;
    }
    node_to_index[root] = index;
    index_to_node[index] = root;
    depth[index] = curr_depth;
    parent[index] = node_to_index[father];
    index++;
    GetDepth(root->left, curr_depth + 1, root, index);
    GetDepth(root->right, curr_depth + 1, root, index);
  }
};


// Out-of-class definitions of Lca's static data members. Because they have
// static storage duration, the arrays start zero-filled.
unordered_map<TreeNode*, int> Lca::node_to_index;
unordered_map<int, TreeNode*> Lca::index_to_node;
int Lca::depth[Lca::kMaxN];
int Lca::parent[Lca::kMaxN];
int Lca::alg[Lca::kMaxN][Lca::kMaxP + 1];

Leave a Reply

Your email address will not be published. Required fields are marked *