
How can I improve the efficiency and/or performance of my relatively simple Java counting method?

I'm building a classifier that has to read through a lot of text documents, but I found that my countWordFrequencies method gets slower the more documents it has processed. The method below takes 60 ms (on my PC), while reading, normalizing, tokenizing, updating my vocabulary and equalizing the different lists of integers take only 3-5 ms in total (on my PC). My countWordFrequencies method is as follows:

public List<Integer> countWordFrequencies(String[] tokens) 
{
    List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
    int counter = 0;

    for (int i = 0; i < vocabulary.size(); i++) 
    {
        for (int j = 0; j < tokens.length; j++)
            if (tokens[j].equals(vocabulary.get(i)))
                counter++;

        wordFreqs.add(i, counter);
        counter = 0;
    }

    return wordFreqs;
}

What is the best way for me to speed this process up? What is the problem with this method?

This is my entire class. There is also another class, Category; would it help to post that here too, or don't you need it?

public class BayesianClassifier 
{
    private Map<String,Integer>  vocabularyWordFrequencies;
    private List<String> vocabulary;
    private List<Category> categories;
    private List<Integer> wordFrequencies;
    private int trainTextAmount;
    private int testTextAmount;
    private GUI gui;

    public BayesianClassifier() 
    {
        this.vocabulary = new ArrayList<>();
        this.categories = new ArrayList<>();
        this.wordFrequencies = new ArrayList<>();
        this.trainTextAmount = 0;
        this.gui = new GUI(this);
        this.testTextAmount = 0;
    }

    public List<Category> getCategories() 
    {
        return categories;
    }

    public List<String> getVocabulary() 
    {
        return this.vocabulary;
    }

    public List<Integer> getWordFrequencies() 
    {
        return  wordFrequencies;
    }

    public int getTextAmount() 
    {
        return testTextAmount + trainTextAmount;
    }

    public void updateWordFrequency(int index, Integer frequency)
    {
        equalizeIntList(wordFrequencies);
        this.wordFrequencies.set(index, wordFrequencies.get(index) + frequency);
    }

    public String readText(String path) 
    {
        BufferedReader br;
        String result = "";

        try 
        {
            br = new BufferedReader(new FileReader(path));

            StringBuilder sb = new StringBuilder();
            String line = br.readLine();

            while (line != null) 
            {
                sb.append(line);
                sb.append("\n");
                line = br.readLine();
            }

            result = sb.toString();
            br.close();
        } 
        catch (IOException e) 
        {
            e.printStackTrace();
        }

        return result;
    }

    public String normalizeText(String text) 
    {
        String fstNormalized = Normalizer.normalize(text, Normalizer.Form.NFD);

        fstNormalized = fstNormalized.replaceAll("[^\\p{ASCII}]","");
        fstNormalized = fstNormalized.toLowerCase();
        fstNormalized = fstNormalized.replace("\n","");
        fstNormalized = fstNormalized.replaceAll("[0-9]","");
        fstNormalized = fstNormalized.replaceAll("[/()!?;:,.%-]","");
        fstNormalized = fstNormalized.trim().replaceAll(" +", " ");

        return fstNormalized;
    }

    public String[] handleText(String path) 
    {
        String text = readText(path);
        String normalizedText = normalizeText(text);

        return tokenizeText(normalizedText);
    }

    public void createCategory(String name, BayesianClassifier bc) 
    {
        Category newCategory = new Category(name, bc);

        categories.add(newCategory);
    }

    public List<String> updateVocabulary(String[] tokens) 
    {
        for (int i = 0; i < tokens.length; i++)
            if (!vocabulary.contains(tokens[i]))
                vocabulary.add(tokens[i]);

        return vocabulary;
    }

    public List<Integer> countWordFrequencies(String[] tokens)
    {
        List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
        int counter = 0;

        for (int i = 0; i < vocabulary.size(); i++) 
        {
            for (int j = 0; j < tokens.length; j++)
                if (tokens[j].equals(vocabulary.get(i)))
                    counter++;

            wordFreqs.add(i, counter);
            counter = 0;
        }

        return wordFreqs;
    }

    public String[] tokenizeText(String normalizedText) 
    {
        return normalizedText.split(" ");
    }

    public void handleTrainDirectory(String folderPath, Category category) 
    {
        File folder = new File(folderPath);
        File[] listOfFiles = folder.listFiles();

        if (listOfFiles != null) 
        {
            for (File file : listOfFiles) 
            {
                if (file.isFile()) 
                {
                    handleTrainText(file.getPath(), category);
                }
            }
        } 
        else 
        {
            System.out.println("There are no files in the given folder" + " " + folderPath.toString());
        }
    }

    public void handleTrainText(String path, Category category) 
    {
        long startTime = System.currentTimeMillis();

        trainTextAmount++;

        String[] text = handleText(path);

        updateVocabulary(text);
        equalizeAllLists();

        List<Integer> wordFrequencies = countWordFrequencies(text);
        long finishTime = System.currentTimeMillis();

        System.out.println("That took 1: " + (finishTime-startTime)+ " ms");

        long startTime2 = System.currentTimeMillis();

        category.update(wordFrequencies);
        updatePriors();

        long finishTime2 = System.currentTimeMillis();

        System.out.println("That took 2: " + (finishTime2-startTime2)+ " ms");
    }

    public void handleTestText(String path) 
    {
        testTextAmount++;

        String[] text = handleText(path);
        List<Integer> wordFrequencies = countWordFrequencies(text);
        Category category = guessCategory(wordFrequencies);
        boolean correct = gui.askFeedback(path, category);

        if (correct) 
        {
            category.update(wordFrequencies);
            updatePriors();
            System.out.println("Kijk eens aan! De tekst is succesvol verwerkt.");
        } 
        else 
        {
            Category correctCategory = gui.askCategory();
            correctCategory.update(wordFrequencies);
            updatePriors();
            System.out.println("Kijk eens aan! De tekst is succesvol verwerkt.");
        }
    }

    public void updatePriors()
    {
        for (Category category : categories)
        {
            category.updatePrior();
        }
    }

    public Category guessCategory(List<Integer> wordFrequencies) 
    {
        List<Double> chances = new ArrayList<>();

        for (int i = 0; i < categories.size(); i++)
        {
            double chance = categories.get(i).getPrior();

            System.out.println("The prior is:" + chance);

            for(int j = 0; j < wordFrequencies.size(); j++)
            {
                chance = chance * categories.get(i).getWordProbabilities().get(j);
            }

            chances.add(chance);
        }

        double max = getMaxValue(chances);
        int index = chances.indexOf(max);

        System.out.println(max);
        System.out.println(index);
        return categories.get(index);
    }

    public double getMaxValue(List<Double> values)
    {
        Double max = 0.0;

        for (Double dubbel : values)
        {
            if(dubbel > max)
            {
                max = dubbel;
            }
        }

        return max;
    }

    public void equalizeAllLists()
    {
        for(Category category : categories)
        {
            if (category.getWordFrequencies().size() < vocabulary.size())
            {
                category.setWordFrequencies(equalizeIntList(category.getWordFrequencies()));
            }
        }

        for(Category category : categories)
        {
            if (category.getWordProbabilities().size() < vocabulary.size())
            {
                category.setWordProbabilities(equalizeDoubleList(category.getWordProbabilities()));
            }
        }
    }

    public List<Integer> equalizeIntList(List<Integer> list)
    {
        while (list.size() < vocabulary.size())
        {
            list.add(0);
        }

        return list;
    }

    public List<Double> equalizeDoubleList(List<Double> list)
    {
        while (list.size() < vocabulary.size())
        {
            list.add(0.0);
        }

        return list;
    }

    public void selectFeatures()
    {
        for(int i = 0; i < wordFrequencies.size(); i++)
        {
            if(wordFrequencies.get(i) < 2)
            {
                vocabulary.remove(i);
                wordFrequencies.remove(i);

                for(Category category : categories)
                {
                    category.removeFrequency(i);
                }
            }
        }
    }
}
Breus asked Dec 28 '15 20:12


4 Answers

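Your method runs in O(n*m) time, where n is the vocabulary size and m is the number of tokens. Because n grows with every document you process, the method keeps getting slower. With hashing, counting the tokens takes O(m), and reading the counts back out for every vocabulary word adds O(n), which is clearly better.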

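Map<String, Integer> map = new HashMap<>();
for (String token : tokens) {
    // Start unseen tokens at zero so the increment below always finds a value
    if (!map.containsKey(token)) {
        map.put(token, 0);
    }
    map.put(token, map.get(token) + 1);
}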
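
To show how the hashed counts could feed the List<Integer> that the question's method returns, here is a minimal sketch. The class name FrequencyCounter and the static signature that takes the vocabulary as a parameter are assumptions made for illustration; in the original class this would be an instance method reading the vocabulary field.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper class, not part of the question's code.
class FrequencyCounter {
    // Returns one count per vocabulary word, in vocabulary order.
    static List<Integer> countWordFrequencies(List<String> vocabulary, String[] tokens) {
        // Pass 1: count every token once, O(m).
        Map<String, Integer> counts = new HashMap<>();
        for (String token : tokens) {
            // merge stores 1 for a new token, otherwise adds 1 to the existing count
            counts.merge(token, 1, Integer::sum);
        }

        // Pass 2: look up each vocabulary word, O(n). Words absent from the text get 0.
        List<Integer> wordFreqs = new ArrayList<>(vocabulary.size());
        for (String word : vocabulary) {
            wordFreqs.add(counts.getOrDefault(word, 0));
        }
        return wordFreqs;
    }
}
```

Tokens that are not in the vocabulary are counted in the map but never read back, so they do not affect the result.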
Sleiman Jneidi answered Nov 18 '22 20:11


Using a Map should dramatically improve performance, as Sleiman Jneidi suggested in his answer. However, this can be done much more elegantly with Java 8's Stream API:

Map<String, Long> frequencies = 
    Arrays.stream(tokens)
          .collect(Collectors.groupingBy(Function.identity(), 
                                         Collectors.counting()));
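
For completeness, a self-contained sketch of the stream approach with its imports. Note that Collectors.counting() produces Long values, while the rest of the question's class works with Integer, so a second variant using Collectors.toMap with Integer::sum is shown as one possible way around that. The class name StreamFrequencies and both method names are assumptions for illustration only.

```java
import java.util.Arrays;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical helper class, not part of the question's code.
class StreamFrequencies {
    // Groups equal tokens together and counts each group; values are Long.
    static Map<String, Long> count(String[] tokens) {
        return Arrays.stream(tokens)
                     .collect(Collectors.groupingBy(Function.identity(),
                                                    Collectors.counting()));
    }

    // Same idea with Integer values: each token maps to 1 and duplicates are summed.
    static Map<String, Integer> countAsInt(String[] tokens) {
        return Arrays.stream(tokens)
                     .collect(Collectors.toMap(Function.identity(),
                                               token -> 1,
                                               Integer::sum));
    }
}
```

Both variants only contain keys for words that actually occur in the text, so a lookup for any other word returns null and needs a default such as getOrDefault(word, 0).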
Mureinik answered Nov 18 '22 19:11


Instead of using one list for the vocabulary and another for the frequencies, I'd use a Map that stores word -> frequency. That way you avoid the nested loop, which I think is what kills your performance.

public Map<String,Integer> countWordFrequencies(String[] tokens) {
    // vocabulary is Map<String,Integer> initialized with all words as keys and 0 as value
    for (String word: tokens)
      if (vocabulary.containsKey(word)) {
        vocabulary.put(word, vocabulary.get(word)+1);
      }
    return vocabulary;
}
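
A variation on this idea, for the case where per-document counts must stay separate from the vocabulary and keep the index-based layout the question's Category class appears to rely on: map each word to its position instead of to its frequency. This is only a sketch under those assumptions; the class IndexedVocabulary and its methods are hypothetical names, not part of the original code. It also makes the vocabulary update O(1) per token, where the question's vocabulary.contains call scans the whole list.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical replacement for the List<String> vocabulary field.
class IndexedVocabulary {
    private final List<String> words = new ArrayList<>();
    private final Map<String, Integer> indexOf = new HashMap<>();

    // Adds unseen tokens to the vocabulary, assigning each the next free index.
    void addAll(String[] tokens) {
        for (String token : tokens) {
            // putIfAbsent returns null only when the token was not present yet
            if (indexOf.putIfAbsent(token, words.size()) == null) {
                words.add(token);
            }
        }
    }

    // Counts the tokens of one document in a single pass; result[i] belongs to word i.
    int[] countWordFrequencies(String[] tokens) {
        int[] freqs = new int[words.size()];
        for (String token : tokens) {
            Integer index = indexOf.get(token);
            if (index != null) { // tokens outside the vocabulary are ignored
                freqs[index]++;
            }
        }
        return freqs;
    }

    int size() {
        return words.size();
    }
}
```

One caveat: removing words, as the question's selectFeatures method does, shifts the positions of later words, so the index map would have to be rebuilt after any removal.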
Nir Levy answered Nov 18 '22 20:11


If you don't want to use Java 8 features, you can try Multiset from Guava.

v.ladynev answered Nov 18 '22 19:11