
Efficient tree implementation in MATLAB

Tree class in MATLAB

I am implementing a tree data structure in MATLAB. The typical operations I expect to run are adding new child nodes to the tree and assigning or updating the data values associated with nodes. Each node has the same type of data associated with it. Removing nodes is not necessary for me. So far, I've decided on a class implementation inheriting from the handle class, so that I can pass references to nodes around to functions that will modify the tree.

Edit: December 2nd

First of all, thanks for all the suggestions in the comments and answers so far. They have already helped me to improve my tree class.

Someone suggested trying digraph, introduced in R2015b. I have yet to explore this, but since it does not behave as a reference parameter the way a class inheriting from handle does, I am a bit sceptical about how well it will work in my application. It is also not yet clear to me how easy it will be to attach custom data to nodes and edges.
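For reference, here is a minimal sketch of how digraph stores per-node data and why it behaves as a value rather than a reference. The variable name Visits is made up for illustration; it is not part of digraph or of the class in this question.

```matlab
% Sketch only: digraph with a custom per-node variable (Visits is hypothetical).
G = digraph();                           % empty directed graph
G = addnode(G, 1);                       % node 1 acts as the root
G.Nodes.Visits = zeros(numnodes(G), 1);  % custom data lives in the Nodes table

% Add a child that carries its own data, then connect root -> child.
G = addnode(G, table(0, 'VariableNames', {'Visits'}));
G = addedge(G, 1, 2);

children = successors(G, 1);             % node IDs of the root's children

% digraph is a value type: a copy is independent of the original, so a
% function that modifies the graph has to return it to the caller.
H = G;
H.Nodes.Visits(2) = 5;
```

After this runs, G.Nodes.Visits(2) is still 0 while H.Nodes.Visits(2) is 5, which is the difference from a handle class that the paragraph above is concerned about.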

Edit: (Dec 3rd) Further information on the main application: MCTS

Initially, I assumed the details of the main application would only be of marginal interest, but since reading the comments and the answer by @FirefoxMetzger, I realise that it has important implications.

I am implementing a type of Monte Carlo tree search algorithm. A search tree is explored and expanded in an iterative manner. Wikipedia offers a nice graphical overview of the process: Monte Carlo tree search

In my application I perform a large number of search iterations. On every search iteration, I traverse the current tree from the root down to a leaf node, then expand the tree by adding new nodes, and repeat. As the method is based on random sampling, at the start of an iteration I do not know which leaf node I will finish at. Instead, this is determined jointly by the data of the nodes currently in the tree and the outcomes of random samples. All nodes I visit during a single iteration have their data updated.

Example: I am at node n which has a few children. I need to access data in each of the children and draw a random sample that determines which of the children I move to next in the search. This is repeated until a leaf node is reached. Practically I am doing this by calling a search function on the root that will decide which child to expand next, call search on that node recursively, and so on, finally returning a value once a leaf node is reached. This value is used while returning from the recursive functions to update the data of the nodes visited during the search iteration.
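The iteration described above can be sketched with plain loops over flat arrays. This is only an illustration: the sampling rule (pick a child in proportion to a Weight value) and the names Children, Weight, Visits and Total are assumptions, not the actual rule or data used in the application. The update step walks the Parent indices back to the root instead of unwinding a recursion.

```matlab
% Sketch: one search iteration on a small hard-coded tree in flat arrays.
% The sampling rule and the names Weight/Visits/Total are assumptions.
Parent   = [0 1 1 2 2];                 % node 1 is the root (parent index 0)
Children = {[2 3], [4 5], [], [], []};  % child IDs of each node
Weight   = [1 3 1 2 2];                 % per-node data used to pick a child
Visits   = zeros(1, 5);
Total    = zeros(1, 5);

% Selection: descend until a leaf, sampling a child in proportion to Weight.
node = 1;
while ~isempty(Children{node})
    kids = Children{node};
    cdf  = cumsum(Weight(kids)) / sum(Weight(kids));
    idx  = find(rand <= cdf, 1, 'first');
    node = kids(idx);
end
leaf  = node;
value = rand;                           % stand-in for the value found at the leaf

% Update: follow Parent indices up to the root, touching each visited node.
while node ~= 0
    Visits(node) = Visits(node) + 1;
    Total(node)  = Total(node) + value;
    node = Parent(node);
end
```

Only the nodes on the path from the root to the sampled leaf are modified, so the cost of one iteration depends on the depth of that path rather than on the size of the tree.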

The tree may be quite unbalanced such that some branches are very long chains of nodes, while others terminate quickly after the root level and are not expanded further.

Current implementation

Below is my current implementation, including a few of the member functions for adding nodes, querying the depth of the tree or the number of nodes in it, and so on.

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree has a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            % GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            % SET  Set the content of the given node ID and return the modified tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of IDs of the children of the given node ID.
            % The list is returned as a row vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to the root plus all nodes with a nonzero parent index.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH  Return the depth of the tree under ID. If ID is not
            % given, use the root, whose ID is 1.
            if nargin == 1
                ID = 1;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end

However, performance is at times slow, with operations on the tree taking up most of my computation time. What specific changes would make the implementation more efficient? I am open to replacing the handle-based class with a different implementation if that brings performance gains.
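One thing visible in the class above is that getchildren and isleaf both compare the whole Parent(1:num_nodes) vector against ID, so each call costs time proportional to the number of nodes in the tree. A possible alternative, sketched here with hypothetical array names FirstChild and NextSibling, is to keep child links alongside Parent so that the children of a node can be listed by following links only.

```matlab
% Sketch: first-child / next-sibling links kept alongside Parent, so that
% children can be listed without scanning every node. Names are hypothetical.
cap         = 8;
Parent      = zeros(cap, 1);
FirstChild  = zeros(cap, 1);     % 0 means "no child"
NextSibling = zeros(cap, 1);     % 0 means "no further sibling"
n = 1;                           % node 1 is the root

% Add nodes 2..5: two children of the root, then two children of node 2.
newParents = [1 1 2 2];
for p = newParents
    n = n + 1;
    Parent(n)      = p;
    NextSibling(n) = FirstChild(p);   % new node goes to the front of the list
    FirstChild(p)  = n;
end

% List the children of node 2 by following the links (newest first).
kids = [];
c = FirstChild(2);
while c ~= 0
    kids(end+1) = c; %#ok<SAGROW>
    c = NextSibling(c);
end

% A leaf test becomes a single lookup.
rootIsLeaf  = FirstChild(1) == 0;
node3IsLeaf = FirstChild(3) == 0;
```

In this layout adding a node is a constant number of assignments and a leaf test is one lookup. The children come back in reverse insertion order, which may or may not matter for the search. Whether this is faster in practice would need to be checked with the profiler.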

Profiling results with current implementation

As adding new nodes to the tree is the most typical operation (along with updating the data of a node), I did some profiling on that. I ran the profiler on the following benchmarking code with Nd=6, Ns=10.

function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end
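To put the benchmark in perspective, the arithmetic below counts how many nodes it creates for Nd=6, Ns=10, and how many times addnode has to enlarge its arrays when the initial size and increment are both 10000 (as in the call to stree above).

```matlab
% Size of the benchmark tree and number of array enlargements in addnode.
Nd = 6; Ns = 10; increment = 10000;
totalNodes = sum(Ns .^ (0:Nd));                 % root plus Ns^k nodes at depth k
numGrowths = ceil(totalNodes / increment) - 1;  % the initial block holds the first 10000
fprintf('%d nodes, %d enlargements\n', totalNodes, numGrowths);
```

This works out to 1111111 nodes and 111 enlargements. Each enlargement builds a new, larger array by concatenation, so the later ones handle arrays with around a million elements.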

Results from the profiler: [image: profiler results]

R2015b performance


It has been discovered that R2015b improves the performance of MATLAB's OOP features. I redid the above benchmark and indeed observed an improvement in performance:

[image: R2015b profiler result]

So this is already good news, although further improvements are of course accepted ;)

Reserving memory differently

It was also suggested in the comments to use

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];

to reserve more memory, rather than the current approach with repmat. This improved performance slightly. I should note that my benchmark code uses dummy data; since the actual data is more complicated, the change is likely to help more there. Thanks! Profiler results below:
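The difference between the two reservation styles can be seen on a small example. The struct field visits and the increment of 4 are placeholders chosen for illustration. Wrapping data in braces is an assumption made here so that the new node always occupies exactly one cell, whatever shape data has.

```matlab
% Sketch: the two ways of enlarging the Node cell array, side by side.
Node = {struct('visits', 0)};            % existing storage with one node
data = struct('visits', 3);              % data of the node being added
inc  = 4;                                % placeholder for size_increment

A = [Node; repmat({data}, inc, 1)];      % current style: every new slot holds data
B = [Node; {data}; cell(inc - 1, 1)];    % suggested style: unused slots stay empty
```

Both arrays have the same length and the same content in the slot of the new node. They differ only in the unused slots, which hold copies of data in A and [] in B.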

[image: profiler results, zeeMonkeez memory reserve style]

Questions on even further increasing performance

  1. Perhaps there is an alternative way to maintain memory for the tree that is more efficient? Sadly, I typically don't know ahead of time how many nodes there will be in the tree.
  2. Adding new nodes and modifying the data of existing nodes are the most typical operations I do on the tree. As of now, they actually take up most of the processing time of my main application. Any improvements on these functions would be most welcome.
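On the first question, one common technique when the final size is unknown is to double the capacity each time it runs out, instead of adding a fixed increment. This is a general dynamic-array idea, not something suggested in the comments, and the variable names below are illustrative only.

```matlab
% Sketch: grow storage by doubling instead of by a fixed increment.
capacity   = 4;
Node       = cell(capacity, 1);
Parent     = zeros(capacity, 1);
numNodes   = 0;
numGrowths = 0;

for k = 1:1000
    if numNodes == capacity
        % Out of room: append as many empty slots as already exist.
        Node     = [Node; cell(capacity, 1)];      %#ok<AGROW>
        Parent   = [Parent; zeros(capacity, 1)];   %#ok<AGROW>
        capacity = 2 * capacity;
        numGrowths = numGrowths + 1;
    end
    numNodes = numNodes + 1;
    Node{numNodes}   = k;              % dummy data
    Parent(numNodes) = numNodes - 1;   % simple chain: parent is the previous node
end
```

Storing 1000 nodes this way takes 8 enlargements starting from a capacity of 4, and the number of enlargements grows only logarithmically with the number of nodes. The price is that up to half of the reserved slots may be unused at any time.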

Just as a final note, I would ideally like to keep the implementation in pure MATLAB. However, options such as MEX or some of the integrated Java functionality may be acceptable.
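As an illustration of the Java option, the sketch below uses java.util.ArrayList from MATLAB as growable storage. It assumes MATLAB is running with the JVM enabled. Whether it would actually be faster than the cell array approach is an open question, since every call crosses the MATLAB/Java boundary.

```matlab
% Sketch: a Java collection used from MATLAB as growable storage (needs the JVM).
list = java.util.ArrayList();
for k = 1:5
    list.add(k * 0.5);       % MATLAB doubles are stored as java.lang.Double
end
n     = list.size();         % number of stored elements
third = list.get(2);         % Java indices start at 0, so this is the third element
```

Java objects behave like references in MATLAB, so such a list can be passed to a function and modified there without returning it.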