When m is the number of features and n is the number of samples, the Python scikit-learn site (http://scikit-learn.org/stable/modules/tree.html) states that the runtime to construct a binary decision tree is O(mn log(n)).
I understand that the log(n) comes from the average height of the tree after splitting. I understand that at each split, you have to look at each feature (m) and choose the best one to split on. I understand that this is done by calculating a "best metric" (in my case, Gini impurity) over the samples at that node (n). However, to find the best split, doesn't this mean that you would have to look at each possible way to split the samples for each feature? And wouldn't that be something like 2^n-1 * m rather than just mn? Am I thinking about this wrong? Any advice would help. Thank you.
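One way to build a decision tree would be, at each node, to do something like this: for each feature, find the best possible split on that feature and score how good it is; then, of all the features tried, take the best one and split on it.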
The question is how to perform each step. If you have continuous data, a common technique for finding the best possible split is to sort the data into ascending order along that feature, then consider all possible partition points between adjacent data points and take the one that minimizes the entropy. There are at most n - 1 such partition points per feature, not exponentially many, because a threshold split always separates a prefix of the sorted order from the rest. This sorting step takes time O(n log n), which dominates the runtime. Since we're doing that for each of the O(m) features, the runtime ends up working out to O(mn log n) total work done per node.
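To make the sort-then-scan idea concrete, here is a minimal sketch in plain Python. It is not scikit-learn's actual implementation; the function names (`gini`, `best_split_for_feature`, `best_split`) are made up for illustration, and it scores splits with weighted Gini impurity (the metric from the question) rather than entropy. It assumes X is a list of rows of numeric features and y is a list of class labels.

```python
from collections import Counter


def gini(counts, total):
    """Gini impurity of a node, given its per-class counts (hypothetical helper)."""
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts.values())


def best_split_for_feature(values, labels):
    """Return (weighted_gini, threshold) for the best threshold split on one feature.

    Returns (inf, None) if all values are equal, so no split is possible.
    """
    n = len(values)
    # Sorting is the O(n log n) step that dominates the per-feature cost.
    order = sorted(range(n), key=lambda i: values[i])
    left = Counter()          # class counts to the left of the candidate split
    right = Counter(labels)   # class counts to the right of it
    best = (float("inf"), None)
    # Only n - 1 candidate partition points exist in sorted order.
    for pos in range(1, n):
        moved = order[pos - 1]
        left[labels[moved]] += 1
        right[labels[moved]] -= 1
        lo, hi = values[order[pos - 1]], values[order[pos]]
        if lo == hi:
            continue  # cannot place a threshold between equal values
        # Each evaluation costs O(number of classes), not O(n).
        score = (pos * gini(left, pos) + (n - pos) * gini(right, n - pos)) / n
        if score < best[0]:
            best = (score, (lo + hi) / 2)
    return best


def best_split(X, y):
    """Try all m features; return (weighted_gini, feature_index, threshold)."""
    best = (float("inf"), None, None)
    for j in range(len(X[0])):
        column = [row[j] for row in X]
        score, threshold = best_split_for_feature(column, y)
        if threshold is not None and score < best[0]:
            best = (score, j, threshold)
    return best


X = [[1.0, 5.0], [2.0, 1.0], [3.0, 4.0], [4.0, 2.0]]
y = [0, 0, 1, 1]
print(best_split(X, y))
```

The class counts are updated incrementally as the scan moves one sample from right to left, so the scan over a feature is linear in n (times the number of classes), and the sort is what gives the O(n log n) per feature and O(mn log n) per node.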