Summary
Rolling hash is a common technique for checking whether two substrings are equal. After O(n)
preprocessing, each comparison takes O(1) time in the best case (when the hash values alone
decide the answer).
OJ
Problem | Trick | Difficulty |
---|---|---|
LC 1316. Distinct Echo Substrings | Rolling Hash + string_view or DP in Linear Structure + string_view | 6 points |
Details
We use two coprime numbers: $B$ for the base and a large prime $M$ for the mod. We use a
polynomial in $B$, reduced modulo $M$, to calculate the hash value. Typical values are
$B = 31$ and $M = 10^9 + 7$.
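Given a substring s[i, j], it is not difficult to get that the hash value of it is,

$$H(i, j) = \left(\sum_{k=i}^{j} s[k] \cdot B^{k-i}\right) \bmod M = \left(P[j] - P[i-1]\right) \cdot B^{-i} \bmod M,$$

where $B$ is the base, $M$ is the mod, $P[k] = \left(\sum_{t=0}^{k} s[t] \cdot B^{t}\right) \bmod M$ is the prefix hash (with $P[-1] = 0$), and $B^{-i}$ is the modular multiplicative inverse of $B^{i}$.

We could preprocess tables for all the prefix hashes $P[k]$ and all the inverse powers $B^{-k}$ in O(n). After that, the complexity of getting a hash value for any substring is O(1).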
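We then have the following truth, where $H(i, j)$ denotes the hash value of s[i, j]:

$$s[i_1, j_1] = s[i_2, j_2] \implies H(i_1, j_1) = H(i_2, j_2).$$

The converse holds only with high probability, because two different substrings may collide on the same hash value.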
namespace {
using ll = unsigned long long;
// Base of the polynomial hash.
constexpr ll kBase = 31;
// Modulus of the polynomial hash. GetInverse assumes it is prime, and it must
// stay below 2^32 so that the product of two residues fits in ll.
constexpr ll kMod = 1'000'000'007;
// Capacity of the tables below, i.e. the longest string they can describe.
constexpr int kMax = 100'005;
// hashes[i] = sum(s[k] * kBase^k) for k in [0, i], modulo kMod
ll hashes[kMax];
// base_to_i[i] = kBase^i modulo kMod
ll base_to_i[kMax];
// base_to_i_inverse[i] = inverse(kBase^i) modulo kMod
ll base_to_i_inverse[kMax];
}
ll FastPower(ll val, ll n);
// Modular multiplicative inverse of val under kMod, i.e. the x with
//   val * x = 1 (mod kMod).
// Fermat's little theorem gives val^(kMod - 1) = 1 (mod kMod) when kMod is
// prime and val is not a multiple of it, so x = val^(kMod - 2).
ll GetInverse(ll val) {
  const ll exponent = kMod - 2;
  return FastPower(val, exponent);
}
// Computes val^n modulo kMod by binary exponentiation (O(log n) multiplications).
//
// val is reduced modulo kMod before any multiplication, so the result is
// always a residue in [0, kMod) - including for n == 1 - and the squarings
// cannot overflow as long as kMod is below 2^32.
ll FastPower(ll val, ll n) {
  ll result = 1 % kMod;
  ll factor = val % kMod;
  while (n > 0) {
    // Multiply in the current power of val whenever its bit is set in n.
    if (n % 2 == 1) {
      result = (result * factor) % kMod;
    }
    factor = (factor * factor) % kMod;
    n /= 2;
  }
  return result;
}
// Fills hashes, base_to_i and base_to_i_inverse for str in O(n) time.
//
// After the call, for every index i of str:
//   hashes[i]            = sum(digit(str[k]) * kBase^k) for k in [0, i], mod kMod
//   base_to_i[i]         = kBase^i mod kMod
//   base_to_i_inverse[i] = inverse(kBase^i) mod kMod
// where digit(c) is c - 'a' taken as a residue modulo kMod.
//
// Precondition: str.size() <= kMax (the tables have a fixed capacity).
// An empty str leaves the tables untouched.
void Preprocess(string_view str) {
  const int n = str.size();
  // Nothing to compute; also avoids writing base_to_i_inverse[-1] below.
  if (n == 0) {
    return;
  }
  ll curr_base = 1;
  ll curr_sum = 0;
  for (int i = 0; i < n; ++i) {
    // Characters below 'a' give a negative offset. Map it to its residue
    // explicitly: converting a negative int to ll wraps modulo 2^64, which
    // is not congruent modulo kMod and would break prefix subtraction.
    const int offset = str[i] - 'a';
    const ll digit = offset >= 0 ? static_cast<ll>(offset) % kMod
                                 : kMod - static_cast<ll>(-offset) % kMod;
    curr_sum = (curr_sum + digit * curr_base) % kMod;
    hashes[i] = curr_sum;
    base_to_i[i] = curr_base;
    curr_base = (curr_base * kBase) % kMod;
  }
  // Only one modular inverse is computed; the rest follow from
  // inverse(kBase^i) = inverse(kBase^(i + 1)) * kBase.
  base_to_i_inverse[n - 1] = GetInverse(base_to_i[n - 1]);
  for (int i = n - 2; i >= 0; --i) {
    base_to_i_inverse[i] = (base_to_i_inverse[i + 1] * kBase) % kMod;
  }
}
// Hash of the substring [l, r] (both inclusive) of the preprocessed string,
// normalised so that it does not depend on the starting position l.
ll GetHash(int l, int r) {
  // Prefix hash of everything before l; the empty prefix contributes 0.
  ll before = 0;
  if (l >= 1) {
    before = hashes[l - 1];
  }
  // Adding kMod keeps the difference non-negative in unsigned arithmetic.
  const ll shifted = hashes[r] + kMod - before;
  // Multiply by inverse(kBase^l) to cancel the kBase^l factor.
  return (shifted * base_to_i_inverse[l]) % kMod;
}
Two Hash Functions
In LC 1044. Longest Duplicate Substring, you might find that a single rolling hash function causes collisions. Another technique for rolling hash is to use two hash functions (two bases and two mods).
Combining the two values into a single key such as f_1 ^ f_2
would reintroduce collisions. Instead, we store the pair (f_1, f_2) in a hash set. The hash set in the standard library resolves bucket collisions by comparing the full pairs, which is sufficient to pass the OJ.
struct PairHash {
size_t operator() (pair<ll, ll> const &p) const {
return hash<int>()(p.first) ^ hash<int>()(p.second);
}
};
// Set of (f_1, f_2) hash pairs seen so far. The element type matches the
// pair type PairHash takes, so hashing needs no converting temporary and the
// stored hash values are not narrowed to int.
unordered_set<pair<unsigned long long, unsigned long long>, PairHash> visited;
Another technique you might use here is the Modular Multiplicative Inverse. Please refer to this post.
Rolling Hash Template
// Polynomial rolling hash of one string; all tables are built in the
// constructor in O(n), after which substring hashes cost O(1).
//
// Preconditions:
//   - str.size() <= kMaxN (the tables have a fixed capacity);
//   - kMod is a positive prime (GetInverse relies on Fermat's little theorem)
//     and kBase is not a multiple of kMod.
// NOTE(review): the object embeds three arrays of kMaxN 64-bit values, so a
// local (stack) instance may be large depending on kMaxN.
class RollingHash {
 public:
  using ll = unsigned long long;
  // hashes[i] = sum(digit(str[k]) * kBase^k) for k in [0, i], modulo kMod.
  ll hashes[kMaxN];
  // base_to_i[i] = kBase^i modulo kMod.
  ll base_to_i[kMaxN];
  // base_to_i_inverse[i] = inverse(kBase^i) modulo kMod.
  ll base_to_i_inverse[kMaxN];
  const int kBase;
  const int kMod;

  // Builds the tables for str. An empty str leaves the tables untouched.
  RollingHash(string_view str, int kBase, int kMod) : kBase(kBase), kMod(kMod) {
    const int n = str.size();
    // Nothing to compute; also avoids writing base_to_i_inverse[-1] below.
    if (n == 0) {
      return;
    }
    const ll base = this->kBase;
    const ll mod = this->kMod;
    ll curr_base = 1;
    ll curr_sum = 0;
    for (int i = 0; i < n; ++i) {
      // Characters below 'a' give a negative offset. Map it to its residue
      // explicitly: converting a negative int to ll wraps modulo 2^64, which
      // is not congruent modulo kMod and would break prefix subtraction.
      const int offset = str[i] - 'a';
      const ll digit = offset >= 0 ? static_cast<ll>(offset) % mod
                                   : mod - static_cast<ll>(-offset) % mod;
      curr_sum = (curr_sum + digit * curr_base) % mod;
      hashes[i] = curr_sum;
      base_to_i[i] = curr_base;
      curr_base = (curr_base * base) % mod;
    }
    // Only one modular inverse is computed; the rest follow from
    // inverse(kBase^i) = inverse(kBase^(i + 1)) * kBase.
    base_to_i_inverse[n - 1] = GetInverse(base_to_i[n - 1]);
    for (int i = n - 2; i >= 0; --i) {
      base_to_i_inverse[i] = (base_to_i_inverse[i + 1] * base) % mod;
    }
  }

  // Compares s[i, i + len - 1] with s[j, j + len - 1] by hash only.
  //
  // Both hashes are brought to the common factor kBase^(i + j) by
  // cross-multiplying, so no modular inverse is needed. A true result can be
  // a collision; str is accepted so that a caller who needs certainty can
  // re-enable the explicit check sketched below. Empty ranges (len <= 0)
  // compare equal.
  bool Equal(int i, int j, int len, string_view str) const {
    (void)str;  // only used by the optional verification below
    if (len <= 0) {
      return true;
    }
    const ll hash0 = (GetRelativeHash(i, i + len - 1) * base_to_i[j]) % kMod;
    const ll hash1 = (GetRelativeHash(j, j + len - 1) * base_to_i[i]) % kMod;
    // if (hash0 != hash1) {
    //   return false;
    // } else {
    //   return str.substr(i, len) == str.substr(j, len);
    // }
    return hash0 == hash1;
  }

  // Hash of s[left, right] (both inclusive), normalised so that it does not
  // depend on the starting position left.
  int GetHash(int left, int right) const {
    return static_cast<int>((GetRelativeHash(left, right) * base_to_i_inverse[left]) % kMod);
  }

 private:
  // https://oi-wiki.org/math/fermat/
  // https://oi-wiki.org/math/inverse/
  // Modular inverse of val: val^(kMod - 2) modulo kMod (kMod prime).
  ll GetInverse(int val) const {
    return FastPower(val, kMod - 2);
  }

  // val^n modulo kMod by binary exponentiation. val is reduced first, so the
  // result is always a residue in [0, kMod), including for n == 1.
  ll FastPower(ll val, unsigned int n) const {
    const ll mod = kMod;
    ll result = 1 % mod;
    ll factor = val % mod;
    while (n > 0) {
      if (n % 2 == 1) {
        result = (result * factor) % mod;
      }
      factor = (factor * factor) % mod;
      n /= 2;
    }
    return result;
  }

  // Hash of s[left, right] still scaled by kBase^left:
  // hashes[right] - hashes[left - 1] modulo kMod.
  ll GetRelativeHash(int left, int right) const {
    const ll before = (left >= 1) ? hashes[left - 1] : 0;
    // Adding kMod keeps the difference non-negative in unsigned arithmetic.
    return (hashes[right] + kMod - before) % kMod;
  }
};