The K-Median Problem

The k-median problem asks us to partition data points into k clusters so as to minimize the total distance between each point and the center of the cluster it belongs to, where each center must itself be one of the data points. It can be regarded as a variant of k-means clustering: k-means represents each cluster by its mean, whereas k-median represents it by its median. The problem is known to be NP-hard. Here we describe how to formulate the k-median problem as a mathematical model with JijModeling and solve it with JijZeptSolver.

Mathematical Model

Let us consider a mathematical model for the k-median problem.

Decision variables

Let \(x_{i, j}\) be a binary variable that is 1 if the \(i\)-th data point is assigned to the \(j\)-th data point as its median, and 0 otherwise. We also use a binary variable \(y_j\) that is 1 if the \(j\)-th data point is selected as a median, and 0 otherwise.

\[\begin{split} x_{i,j} = \begin{cases} 1,~\text{Node $i$ is covered by median $j$}\\ 0,~\text{Otherwise} \end{cases} \end{split}\]
\[\begin{split} y_j = \begin{cases} 1,~\text{Node $j$ is a median.}\\ 0,~\text{Otherwise} \end{cases} \end{split}\]
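To make the encoding concrete, the following minimal sketch builds the arrays \(x\) and \(y\) from a plain assignment list using NumPy. The helper name `encode_assignment` and the 5-point example are assumptions made for illustration only; they are not part of JijModeling.

```python
import numpy as np

def encode_assignment(assigned_median, n_points):
    """Build x (N x N) and y (N,) from a list in which
    assigned_median[i] is the index of the median covering point i."""
    x = np.zeros((n_points, n_points), dtype=int)
    y = np.zeros(n_points, dtype=int)
    for i, j in enumerate(assigned_median):
        x[i, j] = 1  # point i is covered by median j
        y[j] = 1     # so point j must be marked as a median
    return x, y

# Hypothetical example: 5 points, with points 1 and 3 acting as medians.
x_demo, y_demo = encode_assignment([1, 1, 3, 3, 3], n_points=5)
print(y_demo)               # indicator of the medians
print(x_demo.sum(axis=1))   # number of medians covering each point
```

Each row of `x_demo` contains a single 1, which reflects that every point is covered by exactly one median.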

Mathematical Model

Our goal is to minimize the total distance between each data point \(i\) and the median \(j\) it is assigned to. We impose three constraints:

  1. Each data point must be assigned to exactly one median,

  2. Exactly \(k\) data points are selected as medians,

  3. A data point can only be assigned to a point that has been selected as a median.

These can be expressed in a mathematical model as follows.

\[\begin{split} \begin{align} \min_x &\sum_{i}\sum_j d_{i,j}x_{i,j} \notag\\ \mathrm{s.t.}~&\sum_{j} x_{i,j} = 1,~\forall i \notag\\ &\sum_j y_j = k \notag\\ &x_{i,j} \leq y_j, ~\forall i, j \notag\\ &x_{i,j} \in \{0, 1\} ~\forall i, j \notag\\ &y_j \in \{0, 1\}~\forall j \tag{1} \end{align} \end{split}\]
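As a sanity check on model (1), a tiny instance can be solved by exhaustive search. The sketch below is not the method used in the rest of this tutorial. It enumerates every set of \(k\) medians and relies on the observation that, once \(y\) is fixed, the cheapest feasible \(x\) assigns each point to its nearest selected median. The function name `brute_force_k_median` and the 1-D toy data are assumptions for illustration.

```python
from itertools import combinations

import numpy as np

def brute_force_k_median(d, k):
    """Enumerate every set of k medians; return (best_cost, best_medians)."""
    n = d.shape[0]
    best_cost, best_medians = float("inf"), None
    for medians in combinations(range(n), k):
        # With the medians fixed, each point goes to its nearest median.
        cost = d[:, list(medians)].min(axis=1).sum()
        if cost < best_cost:
            best_cost, best_medians = cost, medians
    return best_cost, best_medians

# Hypothetical 1-D instance: two well-separated groups of three points.
pts = np.array([0.0, 1.0, 2.0, 10.0, 11.0, 12.0])
d_small = np.abs(pts[:, None] - pts[None, :])
cost, medians = brute_force_k_median(d_small, k=2)
print(cost, medians)
```

On this toy instance the search selects points 1 and 4, the middle point of each group, with a total cost of 4. The number of candidate sets grows combinatorially with \(N\), so this approach is only practical for very small instances.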

Modeling by JijModeling

Here, we show an implementation using JijModeling. We first define variables for the mathematical model described above.

import jijmodeling as jm

problem = jm.Problem("k-median")

N = problem.Natural("N")
d = problem.Float("d", shape=(N, N))
k = problem.Natural("k")
x = problem.BinaryVar("x", shape=(N, N))
y = problem.BinaryVar("y", shape=(N,))

N represents the number of data points, d is a two-dimensional array holding the distance between each pair of data points, and k is the number of medians to select. In addition, we define the binary decision variables x and y introduced above.

Next, we implement the objective function and constraints of model (1).

problem += jm.sum(jm.product(N, N), lambda i, j: d[i, j] * x[i, j])
problem += problem.Constraint("onehot", lambda i: x[i, :].sum() == 1, domain=N)
problem += problem.Constraint("k-median", y.sum() == k)
problem += problem.Constraint("cover", lambda i, j: x[i, j] <= y[j], domain=(N, N))

The constraint problem.Constraint("onehot", lambda i: x[i, :].sum() == 1, domain=N) imposes \(\sum_j x_{i, j} = 1\) for all \(i\). problem.Constraint("k-median", y.sum() == k) represents \(\sum_j y_j = k\). Finally, problem.Constraint("cover", lambda i, j: x[i, j] <= y[j], domain=(N, N)) requires that \(x_{i, j} \leq y_j\) holds for all \(i, j\).
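The three constraint families can also be checked directly on candidate arrays with plain NumPy, independently of JijModeling. The following is a minimal sketch. The helper `check_k_median_constraints` and the 3-point examples are hypothetical, and the dictionary keys simply reuse the constraint names from the code above.

```python
import numpy as np

def check_k_median_constraints(x, y, k):
    """Report whether each constraint family of model (1) holds."""
    x = np.asarray(x)
    y = np.asarray(y)
    return {
        "onehot": bool(np.all(x.sum(axis=1) == 1)),  # sum_j x[i, j] == 1 for all i
        "k-median": bool(y.sum() == k),              # sum_j y[j] == k
        "cover": bool(np.all(x <= y[None, :])),      # x[i, j] <= y[j] for all i, j
    }

# Hypothetical example with k = 1: every point is covered by point 0.
y_ok = np.array([1, 0, 0])
x_ok = np.array([[1, 0, 0], [1, 0, 0], [1, 0, 0]])
result_ok = check_k_median_constraints(x_ok, y_ok, k=1)
print(result_ok)

# Point 2 is assigned to point 1, which is not a median: "cover" fails.
x_bad = np.array([[1, 0, 0], [1, 0, 0], [0, 1, 0]])
result_bad = check_k_median_constraints(x_bad, y_ok, k=1)
print(result_bad)
```

The second example still satisfies "onehot" and "k-median", which illustrates why the "cover" constraint is needed to link \(x\) to \(y\).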

We can check the implementation of the mathematical model in a Jupyter Notebook.

problem
\[\begin{split}\begin{array}{rl} \text{Problem}\colon &\text{k-median}\\\displaystyle \min &\displaystyle \sum _{i=0}^{N-1}{\sum _{j=0}^{N-1}{{d}_{i,j}\cdot {x}_{i,j}}}\\&\\\text{s.t.}&\\&\begin{aligned} \text{cover}&\quad \displaystyle {x}_{i,j}\leq {y}_{j}\quad \forall \left(i,j\right)\;\text{s.t.}\;i\in \left\{0,\ldots ,N-1\right\},j\in \left\{0,\ldots ,N-1\right\}\\\text{k-median}&\quad \displaystyle \sum _{\vec{\imath }}{{{\left(y\right)}}_{\vec{\imath }}}=k\\\text{onehot}&\quad \displaystyle \sum _{{\vec{\imath }}_{1}}{{{\left({x}_{i,\left(\colon \right)}\right)}}_{{\vec{\imath }}_{1}}}=1\quad \forall i\;\text{s.t.}\;i\in \left\{0,\ldots ,N-1\right\}\end{aligned} \\&\\\text{where}&\\&\text{Decision Variables:}\\&\qquad \begin{alignedat}{2}x&\in \mathop{\mathrm{Array}}\left[N\times N;\left\{0, 1\right\}\right]&\quad &2\text{-dim binary variable}\\y&\in \mathop{\mathrm{Array}}\left[N;\left\{0, 1\right\}\right]&\quad &1\text{-dim binary variable}\\\end{alignedat}\\&\\&\text{Placeholders:}\\&\qquad \begin{alignedat}{2}d&\in \mathop{\mathrm{Array}}\left[N\times N;\mathbb{R}\right]&\quad &2\text{-dimensional array of placeholders with elements in }\mathbb{R}\\k&\in \mathbb{N}&\quad &\text{A scalar placeholder in }\mathbb{N}\\N&\in \mathbb{N}&\quad &\text{A scalar placeholder in }\mathbb{N}\\\end{alignedat}\end{array} \end{split}\]

Prepare instance

We prepare and visualize data points.

import matplotlib.pyplot as plt
import numpy as np

inst_N = 30
X, Y = np.random.uniform(0, 1, (2, inst_N))

plt.plot(X, Y, "o")
../_images/77f21310506632ed4216633a3641d707e1fad89d431a3ea234bc02314768c3a0.png

We compute the pairwise Euclidean distances between the data points.

XX, XX_T = np.meshgrid(X, X)
YY, YY_T = np.meshgrid(Y, Y)
inst_d = np.sqrt((XX - XX_T)**2 + (YY - YY_T)**2)
inst_k = 4

instance_data = {"N": inst_N, "d": inst_d, "k": inst_k}
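The same distance matrix can be obtained with NumPy broadcasting instead of np.meshgrid. The sketch below compares the two constructions on a small random sample. The fixed seed and the variable names (`px`, `py`, `d_mesh`, `d_bcast`) are assumptions chosen for illustration and are independent of the instance prepared above.

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed, assumed here for reproducibility
px, py = rng.uniform(0, 1, (2, 5))

# meshgrid-based construction, following the code in the text
gx, gx_t = np.meshgrid(px, px)
gy, gy_t = np.meshgrid(py, py)
d_mesh = np.sqrt((gx - gx_t) ** 2 + (gy - gy_t) ** 2)

# broadcasting-based construction
coords = np.column_stack([px, py])               # shape (5, 2)
diff = coords[:, None, :] - coords[None, :, :]   # shape (5, 5, 2)
d_bcast = np.linalg.norm(diff, axis=-1)          # shape (5, 5)

same = np.allclose(d_mesh, d_bcast)              # both constructions agree
symmetric = np.allclose(d_bcast, d_bcast.T)      # d[i, j] == d[j, i]
zero_diag = np.allclose(np.diag(d_bcast), 0.0)   # d[i, i] == 0
print(same, symmetric, zero_diag)
```

Checking symmetry and the zero diagonal is a cheap way to catch indexing mistakes before passing the matrix to the solver.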

Solve with JijZeptSolver

We solve the problem using jijzept_solver.

import jijzept_solver

instance = problem.eval(instance_data)
solution = jijzept_solver.solve(instance, solve_limit_sec=1.0)

Visualize the solution

We visualize the solution obtained.

df = solution.decision_variables_df
# Indices j with y_j = 1, i.e. the data points selected as medians.
y_indices = np.ravel(df[(df["name"] == "y") & (df["value"] == 1.0)]["subscripts"].to_list())
# Pairs (i, j) with x_{i,j} = 1 (not used in the plot below).
x_indices = df[(df["name"] == "x") & (df["value"] == 1.0)]["subscripts"].to_list()

median_X, median_Y = X[y_indices], Y[y_indices]
# Distance from every data point to every selected median.
d_from_m = np.sqrt((X[:, None] - median_X)**2 + (Y[:, None] - median_Y)**2)
# Link each point to its nearest selected median.
cover_median = y_indices[np.argmin(d_from_m, axis=1)]
plt.plot(X, Y, "o")
plt.plot(median_X, median_Y, "o", markersize=10)
plt.plot(np.column_stack([X, X[cover_median]]).T, np.column_stack([Y, Y[cover_median]]).T, c="gray")
plt.show()
../_images/17d8b098b461e70f1475c6a4af3c34f9160ea89949388cf7c427545612d58805.png

This figure shows the resulting clusters. The orange points are the medians and the blue points are the remaining data points. Each gray line connects a data point to the median of its cluster.
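Besides the plot, it can be useful to summarize a solution numerically. The following sketch computes the total assignment cost and the size of each cluster from a distance matrix and a list of median indices, assigning every point to its nearest median. The helper `summarize_clusters` and the four 2-D points are hypothetical. With the tutorial's data one would pass the distance matrix and the median indices extracted above.

```python
import numpy as np

def summarize_clusters(d, median_indices):
    """Assign each point to its nearest median; return (total_cost, sizes)."""
    median_indices = np.asarray(median_indices)
    # Position (within median_indices) of the nearest median for each point.
    nearest = np.argmin(d[:, median_indices], axis=1)
    total_cost = d[np.arange(d.shape[0]), median_indices[nearest]].sum()
    sizes = {int(m): int(np.sum(nearest == p)) for p, m in enumerate(median_indices)}
    return total_cost, sizes

# Hypothetical instance: two pairs of nearby points in the plane.
demo_pts = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
d_demo = np.linalg.norm(demo_pts[:, None, :] - demo_pts[None, :, :], axis=-1)
total, sizes = summarize_clusters(d_demo, [0, 2])
print(total, sizes)
```

Here each median covers itself at distance 0 and one neighbor at distance 1, so the total cost is 2 and both clusters contain two points.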