
Finding an optimal selection in a 2D matrix with given constraints

Problem statement

Given an m x n matrix where m <= n, select entries so that their sum is maximal.

  • However, you must select exactly one entry per row and at most one per column.
  • Performance is also a major factor, so it's OK to accept selections that are not optimal in order to reduce complexity (as long as they are better than selecting random entries).

Example

  • Valid selections:

    [images: a valid selection; a valid selection that leaves a column empty]

  • Invalid selections (the rule is exactly one entry per row and at most one per column):

    [images: invalid, an empty row and a duplicate within a row; invalid, a duplicate within a column]
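The constraints illustrated above can be expressed as a small check. This is only a sketch: the function name `is_valid_selection` and the representation of a selection as a list of `(row, col)` pairs are assumptions made for illustration, not part of the original question.

```python
def is_valid_selection(selection, m, n):
    """Check a selection given as (row, col) pairs on an m x n matrix.

    Hypothetical helper: valid means exactly one entry per row
    and at most one entry per column.
    """
    rows = [r for r, _ in selection]
    cols = [c for _, c in selection]
    in_bounds = all(0 <= r < m and 0 <= c < n for r, c in selection)
    every_row_once = sorted(rows) == list(range(m))  # no empty or repeated row
    cols_unique = len(set(cols)) == len(cols)        # no repeated column
    return in_bounds and every_row_once and cols_unique


print(is_valid_selection([(0, 1), (1, 0)], m=2, n=3))  # leaves one column empty
print(is_valid_selection([(0, 1), (1, 1)], m=2, n=3))  # duplicate column
```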

My Approaches

  1. Select best of k random permutations

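    A = createRandomMatrix(m,n)
    selections = list()
    
    for attempt in range(k):  # `try` is a reserved word in Python
      cols = createRandomIndexPermutation(m) # with no duplicates
      total = 0  # reset the running sum for each permutation
      for row in range(m):
        total += A[row, cols[row]]
      selections.append(total)  # record one sum per permutation
    
    result = max(selections)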
    

    This approach performs poorly when n is significantly larger than m.

  2. Best possible (not yet taken) column per row

    A = createRandomMatrix(m,n)
    takenCols = set()
    
    result = 0
    for row in range(m):
      col = getMaxColPossible(row, takenCols, A)
      result += A[row, col]
      takenCols.add(col)
    

    This approach always favors the rows (or columns) that are processed first, which could lead to worse-than-average results.
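For reference, the two heuristics above could be written as runnable NumPy functions roughly as follows. This is a sketch under assumptions: the function names `best_of_k_random` and `greedy_by_row` are made up here, and the random variant draws m distinct columns from all n columns, which may differ from what the pseudocode's `createRandomIndexPermutation(m)` does.

```python
import numpy as np


def best_of_k_random(A, k, rng):
    """Approach 1: best total among k random valid selections."""
    m, n = A.shape
    best = -np.inf
    for _ in range(k):
        # Assumption: sample m distinct column indices out of all n columns.
        cols = rng.choice(n, size=m, replace=False)
        best = max(best, A[np.arange(m), cols].sum())
    return best


def greedy_by_row(A):
    """Approach 2: each row takes its best column that is still free."""
    m, n = A.shape
    taken = np.zeros(n, dtype=bool)
    total = 0.0
    for row in range(m):
        # Mask out columns already used by earlier rows.
        candidates = np.where(taken, -np.inf, A[row])
        col = int(np.argmax(candidates))
        total += A[row, col]
        taken[col] = True
    return total


rng = np.random.default_rng(0)
A = rng.random((4, 10))
print(best_of_k_random(A, k=100, rng=rng), greedy_by_row(A))
```

Both functions return only the total; tracking the chosen columns as well would be a small extension.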

RobinW asked Jan 29 '26 09:01



1 Answer

This sounds exactly like the rectangular linear assignment problem (RLAP). It can be solved efficiently to a global optimum (efficient in terms of asymptotic complexity, roughly cubic time), and a lot of software is available.

The basic approaches are LAP + dummy-vars, LAP-modifications or more general algorithms like network-flows (min-cost max-flow).
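As a rough illustration of the "LAP + dummy-vars" idea, the rectangular matrix can be padded with zero-profit dummy rows until it is square, then solved as an ordinary LAP. The sketch below assumes a SciPy version whose `linear_sum_assignment` accepts `maximize=True` (1.4 or newer), and the variable names are chosen only for this example. Since SciPy's solver already accepts rectangular input, the padding is shown purely to demonstrate the reduction.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
profit = rng.random((3, 5))  # m = 3 rows, n = 5 columns
m, n = profit.shape

# Pad with n - m dummy rows of zero profit to obtain a square problem.
padded = np.vstack([profit, np.zeros((n - m, n))])
rows_sq, cols_sq = linear_sum_assignment(padded, maximize=True)

# Keep only the assignments that belong to real (non-dummy) rows.
real = rows_sq < m
total_padded = padded[rows_sq[real], cols_sq[real]].sum()

# Solve the rectangular problem directly for comparison.
rows_rect, cols_rect = linear_sum_assignment(profit, maximize=True)
total_rect = profit[rows_rect, cols_rect].sum()

print(total_padded, total_rect)
```

The dummy rows contribute zero to every assignment, so both totals should agree.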

You can start with (pdf):

Bijsterbosch, J., and A. Volgenant. "Solving the Rectangular assignment problem and applications." Annals of Operations Research 181.1 (2010): 443-462.

A small Python example using Python's common scientific stack:

Edit: as mentioned in the comments, negating the cost-matrix (which I did, motivated by the LP-description) is not what's done in the Munkres/Hungarian-method literature. The strategy there is to build a profit-matrix from the cost-matrix by subtracting every entry from the maximum entry, which is now reflected in the example. This approach leads to a non-negative matrix (sometimes assumed; whether that matters depends on the implementation). More information is available in this question.

Code

import numpy as np
import scipy.optimize as sopt    # RLAP solver
import matplotlib.pyplot as plt  # visualization
import seaborn as sns            # visualization
np.random.seed(1)

# Example data from
# https://matplotlib.org/gallery/images_contours_and_fields/image_annotated_heatmap.html
# removed a row; will be shuffled to make it more interesting!
harvest = np.array([[0.8, 2.4, 2.5, 3.9, 0.0, 4.0, 0.0],
                    [2.4, 0.0, 4.0, 1.0, 2.7, 0.0, 0.0],
                    [1.1, 2.4, 0.8, 4.3, 1.9, 4.4, 0.0],
                    [0.6, 0.0, 0.3, 0.0, 3.1, 0.0, 0.0],
                    [0.7, 1.7, 0.6, 2.6, 2.2, 6.2, 0.0],
                    [1.3, 1.2, 0.0, 0.0, 0.0, 3.2, 5.1]])
harvest = harvest[:, np.random.permutation(harvest.shape[1])]

# scipy: linear_sum_assignment -> able to take rectangular-problem!
# assumption: minimize -> cost-matrix to profit-matrix:
#                         remove original cost from maximum-costs
#                         Kuhn, Harold W.:
#                         "Variants of the Hungarian method for assignment problems."
max_cost = np.amax(harvest)
harvest_profit = max_cost - harvest

row_ind, col_ind = sopt.linear_sum_assignment(harvest_profit)
sol_map = np.zeros(harvest.shape, dtype=bool)
sol_map[row_ind, col_ind] = True

# Visualize
f, ax = plt.subplots(2, figsize=(9, 6))
sns.heatmap(harvest, annot=True, linewidths=.5, ax=ax[0], cbar=False,
            linecolor='black', cmap="YlGnBu")
sns.heatmap(harvest, annot=True, mask=~sol_map, linewidths=.5, ax=ax[1],
            linecolor='black', cbar=False, cmap="YlGnBu")
plt.tight_layout()
plt.show()

Output

[image: two stacked heatmaps; top: the full shuffled matrix; bottom: only the selected entries, all others masked]
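As a side note on the transformation used in the code above: because every feasible selection picks exactly m entries, minimizing the sum of `max - value` is equivalent to maximizing the sum of the values. The sketch below checks this on a small matrix and also uses the `maximize=True` argument of `linear_sum_assignment`, assuming SciPy 1.4 or newer. The matrix and variable names are chosen only for this illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

profit = np.array([[0.8, 2.4, 2.5, 3.9],
                   [2.4, 0.0, 4.0, 1.0],
                   [1.1, 2.4, 0.8, 4.3]])

# Variant 1: subtract from the maximum entry and minimize (as in the answer).
r1, c1 = linear_sum_assignment(profit.max() - profit)
total_transformed = profit[r1, c1].sum()

# Variant 2: let the solver maximize directly.
r2, c2 = linear_sum_assignment(profit, maximize=True)
total_maximize = profit[r2, c2].sum()

print(total_transformed, total_maximize)
```

The two totals should match. The chosen columns could differ if several selections are tied for the optimum.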

sascha answered Jan 31 '26 01:01
