How to use a Trie data structure to find the sum of LCPs for all possible substrings?

Problem description: Fun with Strings (reference: Fun With Strings)

Based on the problem description, a naive approach to finding the sum of LCP lengths over all ordered pairs of substrings of a given string (including each substring paired with itself) is as follows:

#include <iostream>
#include <string>

using std::cout;
using std::cin;
using std::endl;
using std::string;

// Length of the longest common prefix of str1 and str2
int lcp(const string &str1, const string &str2)
{
    string::size_type n = 0;
    while (n < str1.length() && n < str2.length() && str1[n] == str2[n])
        n++;
    return static_cast<int>(n);
}

int main()
{
    string s;
    cin >> s;
    const int n = static_cast<int>(s.length());
    long long sum = 0; // can exceed the int range for longer inputs

    // Every ordered pair of substrings s[i..j] and s[k..l], including (x, x)
    for (int i = 0; i < n; i++)
        for (int j = i; j < n; j++)
            for (int k = 0; k < n; k++)
                for (int l = k; l < n; l++)
                    sum += lcp(s.substr(i, j - i + 1), s.substr(k, l - k + 1));
    cout << sum << endl;
    return 0;
}

Based on further reading about LCPs, I found this document, which describes a way to find an LCP efficiently using an advanced data structure called a trie. I implemented a trie and a compressed trie (suffix tree) as follows:

#include <iostream>
#include <cstring> // memcpy
#include <string>

using std::cout;
using std::cin;
using std::endl;
using std::string;
const int ALPHA_SIZE = 26;

struct TrieNode
{
    struct TrieNode *children[ALPHA_SIZE];
    string label;
    bool isEndOfWord;
};
typedef struct TrieNode Trie;

Trie *getNode(void)
{
    Trie *parent = new Trie;
    parent->isEndOfWord = false;
    parent->label = "";
    for(int i = 0; i <ALPHA_SIZE; i++)
        parent->children[i] = NULL;

    return parent;
}

void insert(Trie *root, string key)
{
    Trie *temp = root;

    for(int i = 0; i < key.length(); i++)
    {
        int index = key[i] - 'a';
        if(!temp->children[index])
        {
            temp->children[index] = getNode();
            temp->children[index]->label = key[i];
        }
        temp = temp->children[index];
        temp->isEndOfWord = false; // NOTE: this also clears the flag of shorter words inserted earlier along this path
    }
    temp->isEndOfWord = true;
}

int countChildren(Trie *node, int *index)
{
    int count = 0;

    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(node->children[i] != NULL)
        {
            count++;
            *index = i;
        }
    }
    return count;
}

void display(Trie *root)
{
    Trie *temp = root;
    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(temp->children[i] != NULL)
        {
            cout<<temp->label<<"->"<<temp->children[i]->label<<endl;
            if(!temp->isEndOfWord)
                display(temp->children[i]);
        }
    }
}

void compress(Trie *root)
{
    Trie *temp = root;
    int index = 0;

    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(temp->children[i])
        {
            Trie *child = temp->children[i];

            if(!child->isEndOfWord)
            {
                if(countChildren(child,&index) >= 2)
                {
                    compress(child);
                }
                else if(countChildren(child,&index) == 1)
                {
                    while(countChildren(child,&index) < 2 and countChildren(child,&index) > 0)
                    {
                        Trie *sub_child = child->children[index];

                        child->label = child->label + sub_child->label;
                        child->isEndOfWord = sub_child->isEndOfWord;
                        memcpy(child->children,sub_child->children,sizeof(sub_child->children));

                        delete(sub_child);
                    }
                    compress(child);
                }
            }
        }
    }
}

bool search(Trie *root, string key)
{
    Trie *temp = root;

    for(int i = 0; i < key.length(); i++)
    {
        int index = key[i] - 'a';
        if(!temp->children[index])
            return false;
        temp = temp->children[index];
    }
    return (temp != NULL && temp->isEndOfWord);
}

int main()
{
    string input;
    cin>>input;

    Trie *root = getNode();

    for(int i = 0; i < input.length(); i++)
        for(int j = i; j < input.length(); j++)
        {
            cout<<"Substring : "<<input.substr(i,j - i + 1)<<endl;
            insert(root, input.substr(i,j - i + 1));
        }

    cout<<"DISPLAY"<<endl;
    display(root);

    compress(root);
    cout<<"AFTER COMPRESSION"<<endl;
    display(root);

    return 0;
}

My question is how to proceed to find the length of the LCP. I can get the LCP from the label field at the branching node, but how do I sum the LCP lengths over all possible pairs of substrings?

One way I thought of was to somehow use the branching node, its label field (which holds the LCP), and the branching node's children to find the sum of all LCP lengths (lowest common ancestor?). But I am still confused. How do I proceed further?

Note: It is also possible that my approach to this problem is wrong, so please suggest other methods for this problem as well (considering time and space complexity).

Links to similar unanswered questions:

  • sum of LCP of all pairs of substrings of a given string
  • Longest common prefix length of all substrings and a string

References for Code and Theory:

  • LCP
  • Trie
  • Compressed Trie

Update 1:

Based on @Adarsh Anurag's answer, I have come up with the following implementation using a trie:

#include <iostream>
#include <string>
#include <stack>

using std::cout;
using std::cin;
using std::endl;
using std::string;
using std::stack;

const int ALPHA_SIZE = 26;
int sum = 0;
stack <int> lcp;

struct TrieNode
{
    struct TrieNode *children[ALPHA_SIZE];
    string label;
    int count;
};
typedef struct TrieNode Trie;

Trie *getNode(void)
{
    Trie *parent = new Trie;
    parent->count = 0;
    parent->label = "";
    for(int i = 0; i <ALPHA_SIZE; i++)
        parent->children[i] = NULL;

    return parent;
}

void insert(Trie *root, string key)
{
    Trie *temp = root;

    for(int i = 0; i < key.length(); i++)
    {
        int index = key[i] - 'a';
        if(!temp->children[index])
        {
            temp->children[index] = getNode();
            temp->children[index]->label = key[i];
        }
        temp = temp->children[index];
    }
    temp->count++;
}

int countChildren(Trie *node, int *index)
{
    int count = 0;

    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(node->children[i] != NULL)
        {
            count++;
            *index = i;
        }
    }
    return count;
}

void display(Trie *root)
{
    Trie *temp = root;
    int index = 0;
    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(temp->children[i] != NULL)
        {
            cout<<temp->label<<"->"<<temp->children[i]->label<<endl;
            cout<<"CountOfChildren:"<<countChildren(temp,&index)<<endl;
            cout<<"Counter:"<<temp->children[i]->count<<endl;

            display(temp->children[i]);
        }
    }
}

void lcp_sum(Trie *root,int counter,string lcp_label)
{
    Trie *temp = root;
    int index = 0;

    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(temp->children[i])
        {
            Trie *child = temp->children[i];

            if(lcp.empty())
            {
                lcp_label = child->label;
                counter = 0;

                lcp.push(child->count*lcp_label.length());
                sum += lcp.top();
                counter += 1;
            }
            else
            {
                lcp_label = lcp_label + child->label;
                stack <int> ancestors = lcp; // copy, so that lcp itself is left untouched

                while(!ancestors.empty())
                {
                    sum = sum + 2 * ancestors.top() * child->count;
                    ancestors.pop();
                }

                lcp.push(child->count*lcp_label.length());
                sum += lcp.top();
                counter += 1;
            }

            if(countChildren(child,&index) > 1)
            {
                lcp_sum(child,0,lcp_label);
            }
            else if (countChildren(child,&index) == 1)
                lcp_sum(child,counter,lcp_label);
            else
            {
                while(counter-- && !lcp.empty())
                    lcp.pop();
            }
        }
    }
}

int main()
{
    string input;
    cin>>input;

    Trie *root = getNode();

    for(int i = 0; i < input.length(); i++)
        for(int j = i; j < input.length(); j++)
        {
            cout<<"Substring : "<<input.substr(i,j - i + 1)<<endl;
            insert(root, input.substr(i,j - i + 1));
            display(root);
        }

    cout<<"DISPLAY"<<endl;
    display(root);

    cout<<"COUNT"<<endl;
    lcp_sum(root,0,"");
    cout<<sum<<endl;

    return 0;
}

From the Trie structure, I have removed the variable isEndOfWord and replaced it with a counter. This counter keeps track of duplicate substrings, which should help in counting LCPs for strings with repeated characters. However, the above implementation works only for strings with distinct characters. I have tried implementing the method suggested by @Adarsh for repeated characters, but it does not pass any test case.

Update 2:

Based on the further updated answer from @Adarsh and trial and error with different test cases, I seem to have made some progress on repeated characters, but it still does not work as expected. Here is the implementation with comments:

// LCP : Longest Common Prefix
// DFS : Depth First Search

#include <iostream>
#include <string>
#include <stack>
#include <queue>

using std::cout;
using std::cin;
using std::endl;
using std::string;
using std::stack;
using std::queue;

const int ALPHA_SIZE = 26;
int sum = 0;     // Global variable for LCP sum
stack <int> lcp; //Keeps track of current LCP

// Trie Data Structure Implementation (See References Section)
struct TrieNode
{
    struct TrieNode *children[ALPHA_SIZE]; // Search space can be further reduced by keeping track of required indices
    string label;
    int count; // Keeps track of repeat substrings
};
typedef struct TrieNode Trie;

Trie *getNode(void)
{
    Trie *parent = new Trie;
    parent->count = 0;
    parent->label = ""; // Root Label at level 0 is an empty string
    for(int i = 0; i <ALPHA_SIZE; i++)
        parent->children[i] = NULL;

    return parent;
}

void insert(Trie *root, string key)
{
    Trie *temp = root;

    for(int i = 0; i < key.length(); i++)
    {
        int index = key[i] - 'a';   // Lowercase alphabets only
        if(!temp->children[index])
        {
            temp->children[index] = getNode();
            temp->children[index]->label = key[i]; // Label represents the character being inserted into the node
        }
        temp = temp->children[index];
    }
    temp->count++;
}

// Returns the count of child nodes for a given node
int countChildren(Trie *node, int *index)
{
    int count = 0;

    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(node->children[i] != NULL)
        {
            count++;
            *index = i; //Not required for this problem, used in compressed trie implementation
        }
    }
    return count;
}

// Displays the Trie in DFS manner
void display(Trie *root)
{
    Trie *temp = root;
    int index = 0;
    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(temp->children[i] != NULL)
        {
            cout<<temp->label<<"->"<<temp->children[i]->label<<endl; // Display in this format : Root->Child
            cout<<"CountOfChildren:"<<countChildren(temp,&index)<<endl; // Count of Child nodes for Root
            cout<<"Counter:"<<temp->children[i]->count<<endl; // Count of repeat substrings for a given node
            display(temp->children[i]);
        }
    }
}

/* COMPRESSED TRIE IMPLEMENTATION
void compress(Trie *root)
{
    Trie *temp = root;
    int index = 0;

    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        if(temp->children[i])
        {
            Trie *child = temp->children[i];

            //if(!child->isEndOfWord)
            {
                if(countChildren(child,&index) >= 2)
                {
                    compress(child);
                }
                else if(countChildren(child,&index) == 1)
                {
                    while(countChildren(child,&index) < 2 and countChildren(child,&index) > 0)
                    {
                        Trie *sub_child = child->children[index];

                        child->label = child->label + sub_child->label;
                        //child->isEndOfWord = sub_child->isEndOfWord;
                        memcpy(child->children,sub_child->children,sizeof(sub_child->children));

                        delete(sub_child);
                    }
                    compress(child);
                }
            }
        }
    }
}
*/

// Calculate LCP Sum recursively
void lcp_sum(Trie *root,int *counter,string lcp_label,queue <int> *s_count)
{
    Trie *temp = root;
    int index = 0;

    // Traverse through this root's children array, to find child nodes
    for(int i = 0; i < ALPHA_SIZE; i++)
    {
        // If child nodes found, then ...
        if(temp->children[i] != NULL)
        {
            Trie *child = temp->children[i];

            // Check if LCP stack is empty
            if(lcp.empty())
            {
                lcp_label = child->label;   // Set LCP label as Child's label
                *counter = 0;               // To make sure counter is not -1 during recursion

                /*
                    * To include LCP of repeat substrings, multiply the count variable with current LCP Label's length
                    * Push this to a stack called lcp
                */
                lcp.push(child->count*lcp_label.length());

                // Add LCP for (a,a)
                sum += lcp.top() * child->count; // Formula to calculate sum for repeat substrings : (child->count) ^ 2 * LCP Label's Length
                *counter += 1; // Increment counter, this is used further to pop elements from the stack lcp, when a branching node is encountered
            }
            else
            {
                lcp_label = lcp_label + child->label; // If not empty, then add Child's label to LCP label
                stack <int> ancestors = lcp; // Temporary copy of the lcp stack

                /*
                    To calculate LCP for different combinations of substrings,
                    2 -> accounts for (a,b) and (b,a)
                    ancestors.top() -> For previous substrings and their combinations with the current substring
                    child->count -> For any repeat substrings for current node/substring
                */
                while(!ancestors.empty())
                {
                    sum = sum + 2 * ancestors.top() * child->count;
                    ancestors.pop();
                }

                // Similar to above explanation for if block
                lcp.push(child->count*lcp_label.length());
                sum += lcp.top() * child->count;
                *counter += 1;
            }

            // If a branching node is encountered
            if(countChildren(child,&index) > 1)
            {
                int lc = 0; // dummy variable
                queue <int> ss_count; // queue to keep track of substrings (counter) from the child node of the branching node
                lcp_sum(child,&lc,lcp_label,&ss_count); // Recursively calculate LCP for child node

                // This part is experimental, does not work for all testcases
                // Used to calculate the LCP count for substrings between child nodes of the branching node
                if(countChildren(child,&index) == 2)
                {
                    int counter_queue = ss_count.front();
                    ss_count.pop();

                    while(counter_queue--)
                    {
                        sum = sum +  2 * ss_count.front() * lcp_label.length();
                        ss_count.pop();
                    }
                }
                else
                {
                    // Unclear what should happen when there are more than 2 children:
                    // should each child node be combined with every other one?
                    while(!ss_count.empty())
                    {
                        sum = sum +  2 * ss_count.front() * lcp_label.length();
                        ss_count.pop();
                    }
                }

                lcp_label = temp->label; // Set LCP label back to Root's Label

                // Empty the stack till counter is 0, so as to restore its state when it first entered the child node from the branching node
                while(*counter)
                {
                    lcp.pop();
                    *counter -=1;
                }
                continue; // Continue to next child of the branching node
            }
            else if (countChildren(child,&index) == 1)
            {
                // If count of children is 1, then recursively calculate LCP for further child node
                lcp_sum(child,counter,lcp_label,s_count);
            }
            else
            {
                // If count of child nodes is 0, then push the counter to the queue for that node
                s_count->push(*counter);
                // Empty the stack till counter is 0, so as to restore its state when it first entered the child node from the branching node
                while(*counter)
                {
                    lcp.pop();
                    *counter -=1;
                }
                lcp_label = temp->label; // Set LCP label back to Root's Label

            }
        }
    }
}

/* SEARCHING A TRIE
bool search(Trie *root, string key)
{
    Trie *temp = root;

    for(int i = 0; i < key.length(); i++)
    {
        int index = key[i] - 'a';
        if(!temp->children[index])
            return false;
        temp = temp->children[index];
    }
    return (temp != NULL );//&& temp->isEndOfWord);
}
*/

int main()
{
    int t;
    cin>>t; // Number of testcases

    while(t--)
    {
        string input;
        int len;
        cin>>len>>input; // Get input length and input string

        Trie *root = getNode();

        for(int i = 0; i < len; i++)
            for(int j = i; j < len; j++)
                insert(root, input.substr(i,j - i + 1)); // Insert all possible substrings into Trie for the given input

        /*
          cout<<"DISPLAY"<<endl;
          display(root);
        */

        //LCP COUNT
        int counter = 0;    //dummy variable
        queue <int> q;      //dummy variable
        lcp_sum(root,&counter,"",&q);
        cout<<sum<<endl;

        sum = 0;

        /*
          compress(root);
          cout<<"AFTER COMPRESSION"<<endl;
          display(root);
        */
    }
    return 0;
}

Also, here are some sample test cases with their expected outputs. Each input is the number of test cases, followed by a length and a string for each case:

1. Input : 2 2 ab 3 zzz

   Output : 6 46

2. Input : 3 1 a 5 afhce 8 ahsfeaa

   Output : 1 105 592

3. Input : 2 15 aabbcceeddeeffa 3 bab

   Output : 7100 26

The above implementation fails for test cases 2 and 3 (only part of the output is correct). Please suggest a way to solve this; any other approach to this problem is also fine.
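A small brute-force harness makes it easier to check any faster implementation against these samples. The sketch below is an illustration, not part of the original code: `lcpSumBrute` and `runTestCases` are hypothetical names, and it assumes the input format read by the Update 2 `main` (a test count, then a length and a string per case). It compares characters in place instead of copying with `substr`, and it uses the string's real length rather than the declared one; note that sample 2 declares length 8 for the seven-character string `ahsfeaa`.

```cpp
#include <sstream>
#include <string>
#include <vector>

// Brute force: sum of LCP lengths over all ordered pairs of substrings.
// Roughly O(n^5), so only suitable for short strings.
long long lcpSumBrute(const std::string& s) {
    const int n = static_cast<int>(s.size());
    long long sum = 0;
    for (int i = 0; i < n; ++i)
        for (int j = i; j < n; ++j)                 // first substring s[i..j]
            for (int k = 0; k < n; ++k)
                for (int l = k; l < n; ++l) {       // second substring s[k..l]
                    int len = 0;                    // compare in place, no substr copies
                    while (i + len <= j && k + len <= l && s[i + len] == s[k + len])
                        ++len;
                    sum += len;
                }
    return sum;
}

// Hypothetical driver for the sample format: t, then t pairs "len string".
// The declared len is read but ignored; the string's own length is used.
std::vector<long long> runTestCases(const std::string& input) {
    std::istringstream in(input);
    int t = 0;
    in >> t;
    std::vector<long long> answers;
    while (t-- > 0) {
        int declaredLen = 0;
        std::string s;
        in >> declaredLen >> s;
        answers.push_back(lcpSumBrute(s));
    }
    return answers;
}
```

For example, `runTestCases("2 2 ab 3 zzz")` gives `{6, 46}`, matching sample 1.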

Asked Feb 04 '23 17:02 by Saurabh P Bhandari


2 Answers

Your intuition is going in the right direction.

Basically, whenever you see a problem involving the LCP of substrings, you should think about suffix data structures such as suffix trees, suffix arrays, and suffix automata. Suffix trees are arguably the most powerful and the easiest to deal with, and they work perfectly for this problem.

A suffix tree is a trie containing all the suffixes of a string, with every non-branching chain of edges compressed into a single long edge. The problem with an ordinary trie of all suffixes is that it has O(N^2) nodes, so it takes O(N^2) memory. Given that you can precompute the LCP of all pairs of suffixes in O(N^2) time and space with a trivial dynamic program, an uncompressed suffix trie is no good. The compressed trie takes O(N) memory, but it is still useless if you build it with an O(N^2) algorithm (as your code does). You should use Ukkonen's algorithm to construct the suffix tree directly in compressed form in O(N) time. Learning and implementing this algorithm is no easy feat; you may find a web visualization helpful. As a last minor note, I'll assume for simplicity that a sentinel character (e.g. a dollar sign $) is appended to the string, so that every suffix ends at an explicit leaf of the suffix tree.
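The "trivial dynamic programming" mentioned above can be sketched as follows, as an illustration rather than the answerer's code; `suffixLcpTable` and `lcpSumFromTable` are hypothetical names. The table gives the LCP of every pair of suffixes in O(N^2) time and space. Because the LCP of s[i..i+a-1] and s[k..k+b-1] is min(a, b, lcp[i][k]), the sum over all substring pairs can then be taken with one extra loop over the shared prefix length t, for O(N^3) total. This is only a quadratic-memory baseline, not the linear-time suffix-tree solution the answer goes on to describe.

```cpp
#include <string>
#include <vector>

// lcp[i][j] = length of the longest common prefix of suffixes s[i..] and s[j..].
// Filled from the back: equal first characters extend the LCP of the next suffixes.
std::vector<std::vector<int>> suffixLcpTable(const std::string& s) {
    const int n = static_cast<int>(s.size());
    std::vector<std::vector<int>> lcp(n + 1, std::vector<int>(n + 1, 0));
    for (int i = n - 1; i >= 0; --i)
        for (int j = n - 1; j >= 0; --j)
            if (s[i] == s[j]) lcp[i][j] = lcp[i + 1][j + 1] + 1;
    return lcp;
}

// Sum of LCP over all ordered pairs of substrings, using the table.
// For start positions i and k with c = lcp[i][k], Li = n - i and Lk = n - k:
// sum over a in [1, Li], b in [1, Lk] of min(a, b, c)
//   = sum over t in [1, c] of (Li - t + 1) * (Lk - t + 1).
long long lcpSumFromTable(const std::string& s) {
    const int n = static_cast<int>(s.size());
    const auto lcp = suffixLcpTable(s);
    long long sum = 0;
    for (int i = 0; i < n; ++i)
        for (int k = 0; k < n; ++k)
            for (int t = 1; t <= lcp[i][k]; ++t)
                sum += static_cast<long long>(n - i - t + 1) * (n - k - t + 1);
    return sum;
}
```

The product comes from counting, for each t from 1 to c, how many lengths on each side are at least t: a pair of lengths (a, b) contributes exactly one for every such t, which is min(a, b, c) in total.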

Note that:

  1. Every suffix of the string is represented as a path from the root to a leaf in the tree (recall the sentinel). This is a one-to-one correspondence.
  2. Every substring of the string is represented as a path from the root to a node in the tree (including implicit nodes "inside" long edges), and vice versa. Moreover, all equal substrings map to the same path. To learn how many equal substrings map to a particular root-node path, count how many leaves there are in the subtree below the node.
  3. To find the LCP of two substrings, find their corresponding root-node paths and take the LCA of the two nodes. The LCP is then the depth of the LCA vertex. When the two paths diverge, the LCA is a physical vertex with several edges going down from it; when one substring is a prefix of the other, the LCA is simply the shorter substring's node.
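Note 2 can be checked on a small example with an uncompressed suffix trie, which has O(N^2) nodes and is meant only as an illustration. Instead of appending a sentinel and counting leaves, this sketch (a hypothetical `SuffixTrie` class, lowercase input assumed) increments a counter on every node a suffix passes through. A suffix starting at position p passes through a node exactly when that node's substring occurs at p, so the counter equals the number of leaves the sentinel version would have below the node.

```cpp
#include <array>
#include <string>
#include <vector>

// Uncompressed suffix trie over 'a'..'z' (O(N^2) nodes; illustration only).
// pass_[v] = number of suffixes whose path goes through node v,
// i.e. the number of occurrences of the substring spelled from the root to v.
class SuffixTrie {
public:
    explicit SuffixTrie(const std::string& s) {
        addNode();                                   // node 0 is the root
        const int n = static_cast<int>(s.size());
        for (int i = 0; i < n; ++i) {                // insert suffix s[i..]
            int v = 0;
            for (int j = i; j < n; ++j) {
                const int c = s[j] - 'a';
                if (next_[v][c] == -1) {
                    const int u = addNode();         // push_back first, then link
                    next_[v][c] = u;
                }
                v = next_[v][c];
                ++pass_[v];
            }
        }
    }

    // Occurrences of t in s; 0 if t is not a substring.
    int occurrences(const std::string& t) const {
        int v = 0;
        for (char ch : t) {
            v = next_[v][ch - 'a'];
            if (v == -1) return 0;
        }
        return pass_[v];
    }

private:
    std::vector<std::array<int, 26>> next_;          // -1 means "no child"
    std::vector<int> pass_;

    int addNode() {
        std::array<int, 26> none;
        none.fill(-1);
        next_.push_back(none);
        pass_.push_back(0);
        return static_cast<int>(next_.size()) - 1;
    }
};
```

For instance, in `SuffixTrie("banana")`, `occurrences("ana")` is 2 and `occurrences("a")` is 3.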

Here is the main idea. Consider all pairs of substrings and classify them into groups by their LCA vertex. In other words, let A[v] := the number of pairs of substrings whose LCA is exactly vertex v. Once you have this number for every vertex v, all that remains is to multiply each number by the depth of its node and sum the results. Also, the array A[*] takes only O(N) space, which means that we haven't yet lost the chance to solve the whole problem in linear time.

Recall that every substring is a root-node path. Consider two nodes (representing two arbitrary substrings) and a vertex v. Let's call the subtree with the root at vertex v a "v-subtree". Then:

  • If both nodes are within the v-subtree, then their LCA is also within the v-subtree.
  • Otherwise, their LCA is outside the v-subtree, so the equivalence works both ways.

Let's introduce another quantity B[v] := the number of pairs of substrings whose LCA vertex lies within the v-subtree. The statement just above gives an efficient way to compute B[v]: it is simply the square of the number of nodes within the v-subtree, because every pair of nodes in it satisfies the criterion. However, multiplicity must be taken into account here: every node must be counted as many times as there are substrings corresponding to it.

Here are the formulas:

    B[v] = Q[v]^2
    Q[v] = sum_s( Q[s] + M[s] * len(vs) )    for s in sons(v)
    M[v] = sum_s( M[s] )                     for s in sons(v)

Here M[v] is the multiplicity of the vertex (i.e. how many leaves are present in the v-subtree), and Q[v] is the number of nodes in the v-subtree with multiplicity taken into account. You can deduce the base case for the leaves yourself. Using these formulas, you can compute M[*], Q[*], and B[*] in a single traversal of the tree in O(N) time.

It only remains to compute the A[*] array from the B[*] array. It can be done in O(N) with a simple exclusion formula:

    A[v] = B[v] - sum_s( B[s] )           for s in sons(v)

If you implement all of this, you will be able to solve the whole problem in perfect O(N) time and space. Or, more precisely, O(N C) time and space, where C is the size of the alphabet.
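As a sanity check of the A/B/Q/M bookkeeping, here is a sketch on the uncompressed suffix trie rather than on a Ukkonen-built suffix tree, so it runs in O(N^2) time and space, not O(N). Every edge has length 1 there, so len(vs) = 1 and every substring ends at a physical node; Q[v] is taken to include v's own M[v] occurrences, which is the reading under which A[v] = B[v] - sum_s( B[s] ) counts exactly the pairs whose LCA is v. On the compressed tree, pairs whose LCA lies strictly inside an edge (when one substring is a prefix of the other, including a substring paired with itself, and it ends in the middle of an edge) would need extra terms; the uncompressed trie sidesteps that. `lcpSumViaLca` is a hypothetical name and the input is assumed to be lowercase.

```cpp
#include <array>
#include <string>
#include <vector>

// A/B/Q/M on an UNCOMPRESSED suffix trie (every edge has length 1).
// M[v]: occurrences of the substring at v (suffixes passing through v).
// Q[v] = M[v] + sum of Q[son];  B[v] = Q[v]^2;  A[v] = B[v] - sum of B[son];
// answer = sum over non-root v of depth(v) * A[v].
long long lcpSumViaLca(const std::string& s) {
    std::array<int, 26> noChildren;
    noChildren.fill(-1);
    std::vector<std::array<int, 26>> next{noChildren};  // node 0 is the root
    std::vector<int> parent{-1}, depth{0};
    std::vector<long long> M{0};

    const int n = static_cast<int>(s.size());
    for (int i = 0; i < n; ++i) {            // insert suffix s[i..]
        int v = 0;
        for (int j = i; j < n; ++j) {
            const int c = s[j] - 'a';
            if (next[v][c] == -1) {
                next.push_back(noChildren);
                parent.push_back(v);
                depth.push_back(depth[v] + 1);
                M.push_back(0);
                next[v][c] = static_cast<int>(next.size()) - 1;
            }
            v = next[v][c];
            ++M[v];                           // one more suffix passes through v
        }
    }

    // A child is always created after its parent, so sweeping indices from
    // high to low finishes every son before its father.
    const int nodes = static_cast<int>(next.size());
    std::vector<long long> Q(M);              // starts as M[v]; sons are added below
    std::vector<long long> sonsB(nodes, 0);   // sum of B[son] for each node
    long long answer = 0;
    for (int v = nodes - 1; v >= 1; --v) {
        const long long B = Q[v] * Q[v];      // B[v] = Q[v]^2
        const long long A = B - sonsB[v];     // A[v] = B[v] - sum B[son]
        answer += depth[v] * A;
        Q[parent[v]] += Q[v];
        sonsB[parent[v]] += B;
    }
    return answer;
}
```

Because depths grow by exactly one per edge here, the weighted sum of A[v] telescopes to the plain sum of Q[v]^2 over all non-root nodes; for "zzz" that is 6^2 + 3^2 + 1^2 = 46.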

Answered Feb 06 '23 07:02 by stgatilov


To solve the problem, proceed as shown below.

I have made a trie of all substrings of abc, shown in the picture [image: Trie for abc].

Since all the substrings are added, every node in the trie has endOfWord set to true.

Now traverse the tree with me in DFS fashion:

  1. sum = 0, stack = {empty}

  2. We encounter a first. For L(A,B), a can form 1 pair with itself, so do sum = sum + length, and sum becomes 1. Now push the length, i.e. 1, onto the stack. stack = {1}

  3. Move to b. The substring is now ab. Like a, ab can form 1 pair with itself, so do sum = sum + length, and sum becomes 3. Copy the stack contents to stack2. We get 1 as the stack top. This means ab and a have LCP 1, and they form two pairs, L(a,ab) and L(ab,a). So do sum = sum + 2 * stack.top(), and sum becomes 3 + 2 = 5. Now copy stack2 back into the stack and push the length, i.e. 2. The stack becomes {2,1}.

  4. Move to c. The substring is abc. It forms 1 pair with itself, so add 3, and sum becomes 5 + 3 = 8. Copy the stack to stack2. At the top we have 2: abc and ab have LCP 2 and form 2 pairs, so sum = sum + 2*2 and sum becomes 12. Pop 2; the stack now has 1 on top. abc and a have LCP 1 and form 2 pairs, so sum becomes 12 + 2 = 14. Copy stack2 back into the stack and push the length, i.e. 3.

  5. We have reached the end of this branch. Clear the stack, start from b at length 1, and continue as above. Sum becomes 14 + 1 = 15 here.

  6. We reach c. The substring here is bc. Sum becomes 15 + 2 + 2*1 (top) = 19.

  7. We have reached the end of the trie. Start from c at length 1. Sum = 19 + 1 = 20.
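The walkthrough above can be sketched in code as follows. This is a hypothetical `DistinctCharLcpSum` class assuming lowercase input, not Adarsh's own code; instead of copying the stack into stack2 and back, it reads the stack without popping and pops once on backtrack, which produces the same sums. As the answer says, it is only correct when all characters are distinct.

```cpp
#include <array>
#include <string>
#include <vector>

// Stack-based DFS from the walkthrough; valid for DISTINCT characters only.
class DistinctCharLcpSum {
public:
    explicit DistinctCharLcpSum(const std::string& s) {
        addNode();                                    // root
        const int n = static_cast<int>(s.size());
        for (int i = 0; i < n; ++i) {                 // walking suffix s[i..] creates
            int v = 0;                                // a node for every substring s[i..j]
            for (int j = i; j < n; ++j) {
                const int c = s[j] - 'a';
                if (kids_[v][c] == -1) {
                    const int u = addNode();
                    kids_[v][c] = u;
                }
                v = kids_[v][c];
            }
        }
    }

    long long sum() const {
        std::vector<int> stack;   // lengths of the prefixes on the current path
        long long total = 0;
        dfs(0, stack, total);
        return total;
    }

private:
    std::vector<std::array<int, 26>> kids_;           // -1 means "no child"

    int addNode() {
        std::array<int, 26> none;
        none.fill(-1);
        kids_.push_back(none);
        return static_cast<int>(kids_.size()) - 1;
    }

    // Entering a node of depth d: the pair (x, x) adds d, and each shorter
    // prefix of length len on the stack forms (x, p) and (p, x) with LCP len.
    void dfs(int v, std::vector<int>& stack, long long& total) const {
        for (int c = 0; c < 26; ++c) {
            const int child = kids_[v][c];
            if (child == -1) continue;
            const int d = static_cast<int>(stack.size()) + 1;
            total += d;
            for (int len : stack) total += 2LL * len;
            stack.push_back(d);
            dfs(child, stack, total);
            stack.pop_back();                         // backtracking replaces "clear the stack"
        }
    }
};
```

`DistinctCharLcpSum("abc").sum()` returns 20, matching step 7, while `DistinctCharLcpSum("zzz").sum()` returns 14 rather than the expected 46: the repeated-character failure the answer describes next.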

Time complexity: O(N^3). It takes O(N^2) to generate the substrings and O(N) time to insert each of them into the trie; node creation takes constant time. Since not all substrings have length N, the actual work is less than N^3, but the time complexity is still O(N^3).

I have tested this method, and it gives the correct output for words with distinct characters only.

For words with repeated characters it fails. To handle repeated characters, you will need to store how many times the words at positions A and B occur for L(A,B). On the stack, push a pair of length and B_count. Then you can find the sum of LCPs using length (on the stack) * B_count (on the stack) * A_count of the current substring. I do not know any method to find the A and B counts without using 4 loops.
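One way to include the occurrence counts without tracking A_count and B_count on a stack is a different identity, offered as an alternative rather than as the method described above: LCP(x, y) equals the number of non-empty strings that are prefixes of both x and y, so the total over all ordered pairs is, for every trie node w, the square of the number of substring occurrences that start with w's string. With a `count` field like the one in the question's Update 1, that number is the sum of `count` over w's subtree. The sketch below (hypothetical `lcpSumByPrefixCounts`, lowercase input assumed) keeps the O(N^3) insertion of all substrings but needs no stack and handles repeated characters.

```cpp
#include <array>
#include <string>
#include <vector>

// Trie of all substrings with a per-node count, as in the question's Update 1:
// count[v] = how many times the substring ending at v was inserted.
// Sum over ordered pairs of LCP = sum over nodes w of (occurrences with prefix w)^2.
long long lcpSumByPrefixCounts(const std::string& s) {
    std::array<int, 26> none;
    none.fill(-1);
    std::vector<std::array<int, 26>> kids{none};   // node 0 is the root
    std::vector<long long> count{0};
    std::vector<int> parent{-1};

    const int n = static_cast<int>(s.size());
    for (int i = 0; i < n; ++i)
        for (int j = i; j < n; ++j) {              // insert s.substr(i, j - i + 1)
            int v = 0;
            for (int k = i; k <= j; ++k) {
                const int c = s[k] - 'a';
                if (kids[v][c] == -1) {
                    kids.push_back(none);
                    count.push_back(0);
                    parent.push_back(v);
                    kids[v][c] = static_cast<int>(kids.size()) - 1;
                }
                v = kids[v][c];
            }
            ++count[v];                            // same role as temp->count++
        }

    // Children have larger indices than parents, so a reverse sweep turns
    // count[v] into the subtree total before v is used.
    long long sum = 0;
    for (int v = static_cast<int>(kids.size()) - 1; v >= 1; --v) {
        sum += count[v] * count[v];                // count[v] is now the subtree total
        count[parent[v]] += count[v];
    }
    return sum;
}
```

For "zzz" the nodes z, zz and zzz have subtree totals 6, 3 and 1, giving 36 + 9 + 1 = 46.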

See the below images for word abb

[Images 1-3: trie for abb]

That's all. Thank you.

Answered Feb 06 '23 07:02 by Adarsh Anurag