Rolling Hash

Summary

Rolling hash is a common technique for checking whether two substrings are equal. After O(n) preprocessing, each comparison takes O(1) time in the best case; a character-by-character check is only needed if hash collisions must be ruled out.

OJ

Problem | Trick | Difficulty
LC 1316. Distinct Echo Substrings | Rolling Hash + string_view, or DP in Linear Structure + string_view | 6 points

Details

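We use two coprime numbers: b as the base and m as the modulus. m should be a large prime, because the substring formula below divides by b^i, which requires b^i to have a modular inverse mod m. We use a polynomial to calculate the hash value. Typical values are b=1e9+7 and m=1e9+9; the code below uses b=31 and m=1e9+7.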

    \[f(s) = \sum s[i] \times b^i \pmod m\]

Given a substring s[i, j], its hash value follows from the difference of two prefix hashes, where dividing by b^i means multiplying by the modular inverse of b^i:

    \[f(s[i, j]) = \frac{(f(s[0, j]) - f(s[0, i - 1]) + m)}{b^i} \pmod m\]

We could preprocess tables for all the f[0, i] and b^i in O(n). After that, the complexity of getting a hash value for any substring is O(1).

We then have the following truth,

    \[f(str0)==f(str1) \centernot\implies str0==str1\ (hash\ collision)\]

    \[f(str0) \neq f(str1) \implies str0 \neq str1\]

namespace {

// All hash arithmetic is done in unsigned 64-bit integers so that the product
// of two residues below kMod cannot overflow.
using ll = unsigned long long;
// Base of the polynomial hash.
constexpr ll kBase = 31;
// Modulus of the polynomial hash. GetInverse() computes inverses as
// x^(kMod - 2), which relies on kMod being prime.
constexpr ll kMod = 1000'000'007;
// Capacity of the lookup tables below; the hashed string must not be longer.
constexpr int kMax = 100'005;

// hashes[i] = sum(s[k] * kBase^k) mod kMod for k in [0, i]
ll hashes[kMax];
// base_to_i[i] = kBase^i mod kMod
ll base_to_i[kMax];
// base_to_i_inverse[i] = modular inverse of kBase^i
ll base_to_i_inverse[kMax];

}  // namespace

// Defined below; declared here so that GetInverse can call it.
ll FastPower(ll val, ll n);

// Returns the modular multiplicative inverse of val under kMod.
//
// The inverse x satisfies val * x = 1 (mod kMod). Fermat's little theorem
// gives val^(kMod - 1) = 1 (mod kMod) for a prime modulus, hence
// x = val^(kMod - 2).
ll GetInverse(ll val) {
  const ll fermat_exponent = kMod - 2;
  return FastPower(val, fermat_exponent);
}

// Computes val^n mod kMod by binary exponentiation (square-and-multiply) in
// O(log n) multiplications.
//
// val is reduced modulo kMod up front, so the result is a proper residue for
// every n >= 1 (n == 0 yields 1) and no intermediate product can exceed
// (kMod - 1)^2, which fits in 64 bits for the 32-bit-sized modulus used here.
ll FastPower(ll val, ll n) {
  ll result = 1;
  val %= kMod;
  while (n > 0) {
    // Multiply in the current power of val whenever the matching bit of n is set.
    if (n % 2 == 1) {
      result = (result * val) % kMod;
    }
    val = (val * val) % kMod;
    n /= 2;
  }
  return result;
}

// Fills the global tables for str in O(n):
//   hashes[i]            = sum(digit(str[k]) * kBase^k) mod kMod, k in [0, i]
//   base_to_i[i]         = kBase^i mod kMod
//   base_to_i_inverse[i] = modular inverse of kBase^i
// where digit(c) = c - 'a' reduced into [0, kMod).
//
// The tables hold kMax entries, so str must not be longer than kMax.
// An empty string leaves the tables untouched.
//
// Note: 'a' maps to digit 0, so substrings of different lengths that consist
// only of 'a' share the same hash; compare substrings of equal length.
void Preprocess(string_view str) {
  const int n = str.size();
  if (n == 0) {
    // Without this guard the inverse setup below would index element -1.
    return;
  }

  ll curr_base = 1;
  ll curr_sum = 0;
  for (int i = 0; i < n; ++i) {
    // Characters below 'a' give a negative offset. Map it to the equivalent
    // residue modulo kMod instead of letting it wrap around 2^64, which would
    // break the prefix-difference identity used by GetHash.
    const int offset = str[i] - 'a';
    const ll digit = offset >= 0 ? static_cast<ll>(offset)
                                 : kMod - static_cast<ll>(-offset);
    curr_sum = (curr_sum + digit * curr_base) % kMod;
    hashes[i] = curr_sum;
    base_to_i[i] = curr_base;
    curr_base = (curr_base * kBase) % kMod;
  }

  // Only one modular exponentiation is needed: invert the largest power and
  // walk down using inverse(kBase^i) = inverse(kBase^(i + 1)) * kBase.
  base_to_i_inverse[n - 1] = GetInverse(base_to_i[n - 1]);
  for (int i = n - 2; i >= 0; --i) {
    base_to_i_inverse[i] = (base_to_i_inverse[i + 1] * kBase) % kMod;
  }
}
  
// Returns the hash of the substring [l, r] (both inclusive) of the string
// last passed to Preprocess.
//
// The prefix hash before l is subtracted from the prefix hash up to r (kMod is
// added first so the unsigned subtraction cannot underflow), and the result is
// multiplied by inverse(kBase^l) so that it does not depend on the position l.
ll GetHash(int l, int r) {
  const ll prefix_before = (l >= 1) ? hashes[l - 1] : 0;
  const ll window = hashes[r] + kMod - prefix_before;
  return (window * base_to_i_inverse[l]) % kMod;
}

Two Hash Functions

In LC 1044. Longest Duplicate Substring, you might find that a single rolling hash function causes collisions. A common remedy is to use two hash functions (two bases and two mods) and to treat two substrings as equal only when both hash values match.

    \[f_k(s) = \sum_i s[i] \times b_k^i \pmod {m_k}, \quad k \in \{1, 2\}\]

    \[f_1(s[i, j]) = \frac{(f_1(s[0, j]) - f_1(s[0, i - 1]) + m_1)}{b_1^i} \pmod {m_1}\]

    \[f_2(s[i, j]) = \frac{(f_2(s[0, j]) - f_2(s[0, i - 1]) + m_2)}{b_2^i} \pmod {m_2}\]

Combining the two values into a single number with f_1 ^ f_2 would reintroduce collisions. Instead, we store the pair (f_1, f_2) in a hash set. The hash set in the standard library resolves bucket collisions by comparing the full pairs, which is sufficient to pass the OJ.

// Hash functor for a pair of rolling-hash values, for use as the hasher of an
// unordered container.
//
// Both components are hashed at full 64-bit width and mixed with an
// order-sensitive combine step, so (a, b) and (b, a) land in different buckets
// and (a, a) does not collapse to 0 as it would with a plain XOR.
struct PairHash {
    using ll = unsigned long long;

    std::size_t operator() (std::pair<ll, ll> const &p) const {
        const std::hash<ll> hasher{};
        std::size_t seed = hasher(p.first);
        // Mixing step in the style of boost::hash_combine; the constant is the
        // 64-bit golden-ratio value (truncated when size_t is narrower).
        seed ^= hasher(p.second) + static_cast<std::size_t>(0x9e3779b97f4a7c15ULL)
                + (seed << 6) + (seed >> 2);
        return seed;
    }
};

// Set of (f_1, f_2) hash pairs seen so far. The keys are stored as int, so both
// hash values must fit in an int; each key is converted to pair<ll, ll> when it
// is hashed.
std::unordered_set<std::pair<int, int>, PairHash> visited;

Another technique you might use here is the Modular Multiplicative Inverse. Please refer to this post.

Rolling Hash Template



// Polynomial rolling hash of one string:
//
//   hash(s) = sum(digit(s[i]) * kBase^i) mod kMod,  digit(c) = c - 'a' mod kMod
//
// Construction is O(n); afterwards the hash of any substring and the
// comparison of two equal-length substrings take O(1).
//
// Requirements:
//  - kMod must be a positive prime and kBase must not be a multiple of kMod,
//    because inverses are computed as x^(kMod - 2) (Fermat's little theorem).
//  - The string must not be longer than kMaxN, the capacity of the tables.
//    NOTE(review): kMaxN is not declared in this snippet; it has to be defined
//    before the class.
//
// The three tables are stored inline, so an instance is large; prefer static
// or heap storage over a local variable when kMaxN is big.
//
// Note: 'a' maps to digit 0, so substrings of different lengths that consist
// only of 'a' share the same hash; compare substrings of equal length.
class RollingHash {
public:
  using ll = unsigned long long;
  // hashes[i] = hash of the prefix s[0, i]
  ll hashes[kMaxN];
  // base_to_i[i] = kBase^i mod kMod
  ll base_to_i[kMaxN];
  // base_to_i_inverse[i] = modular inverse of kBase^i
  ll base_to_i_inverse[kMaxN];

  const int kBase;
  const int kMod;

  // Precomputes the three tables for str. An empty string leaves them
  // untouched.
  RollingHash(string_view str, int kBase, int kMod) : kBase(kBase), kMod(kMod) {
    const int n = str.size();
    if (n == 0) {
      // Without this guard the inverse setup below would index element -1.
      return;
    }

    ll curr_base = 1;
    ll curr_sum = 0;
    for (int i = 0; i < n; ++i) {
      // Characters below 'a' give a negative offset. Reduce it into [0, kMod)
      // instead of letting it wrap around 2^64, which would break the
      // prefix-difference identity used by GetRelativeHash.
      long long digit = (str[i] - 'a') % static_cast<long long>(kMod);
      if (digit < 0) {
        digit += kMod;
      }
      curr_sum = (curr_sum + static_cast<ll>(digit) * curr_base) % kMod;
      hashes[i] = curr_sum;
      base_to_i[i] = curr_base;
      curr_base = (curr_base * kBase) % kMod;
    }

    // Only one modular exponentiation is needed: invert the largest power and
    // walk down using inverse(kBase^i) = inverse(kBase^(i + 1)) * kBase.
    base_to_i_inverse[n - 1] = GetInverse(base_to_i[n - 1]);
    for (int i = n - 2; i >= 0; --i) {
      base_to_i_inverse[i] = (base_to_i_inverse[i + 1] * kBase) % kMod;
    }
  }

  // Compares s[i, i + len - 1] with s[j, j + len - 1] by hash.
  //
  // Instead of dividing each side by its own kBase^start, both sides are
  // scaled to the common factor kBase^(i + j), so no inverse is needed.
  // A false result is exact; a true result can be a hash collision. str is
  // currently unused; it is kept so that callers can fall back to
  // str.substr(i, len) == str.substr(j, len) when collisions must be excluded.
  bool Equal(int i, int j, int len, string_view str) const {
    (void)str;
    if (len == 0) {
      // Two empty ranges are equal; this also avoids reading hashes[-1] when a
      // range starts at index 0.
      return true;
    }
    const ll lhs = (GetRelativeHash(i, i + len - 1) * base_to_i[j]) % kMod;
    const ll rhs = (GetRelativeHash(j, j + len - 1) * base_to_i[i]) % kMod;
    return lhs == rhs;
  }

  // Returns the position-independent hash of s[left, right] (both inclusive).
  int GetHash(int left, int right) const {
    return static_cast<int>(
        (GetRelativeHash(left, right) * base_to_i_inverse[left]) % kMod);
  }

private:
  // Modular inverse via Fermat's little theorem.
  // https://oi-wiki.org/math/fermat/
  // https://oi-wiki.org/math/inverse/
  ll GetInverse(ll val) const {
    return FastPower(val, kMod - 2);
  }

  // val^n mod kMod by binary exponentiation. val is reduced first so that the
  // result is always a proper residue and intermediate products fit in 64 bits.
  ll FastPower(ll val, unsigned int n) const {
    ll result = 1;
    val %= kMod;
    while (n > 0) {
      if (n % 2 == 1) {
        result = (result * val) % kMod;
      }
      val = (val * val) % kMod;
      n /= 2;
    }
    return result;
  }

  // Hash of s[left, right] still scaled by kBase^left:
  // hashes[right] - hashes[left - 1] (mod kMod). kMod is added before the
  // subtraction so the unsigned difference cannot underflow.
  ll GetRelativeHash(int left, int right) const {
    const ll before = (left >= 1) ? hashes[left - 1] : 0;
    return (hashes[right] + kMod - before) % kMod;
  }
};

Leave a Reply

Your email address will not be published. Required fields are marked *