3462. Maximum Sum With at Most K Elements
You are given a 2D integer matrix grid of size n x m, an integer array limits of length n, and an integer k. The task is to find the maximum sum of at most k elements from the matrix grid such that:
- The number of elements taken from the i-th row of grid does not exceed limits[i].
Return the maximum sum.
Example 1:
Input: grid = [[1,2],[3,4]], limits = [1,2], k = 2
Output: 7
Explanation:
- From the second row, we can take at most 2 elements. The elements taken are 4 and 3.
- The maximum possible sum of at most 2 selected elements is 4 + 3 = 7.
Example 2:
Input: grid = [[5,3,7],[8,2,6]], limits = [2,2], k = 3
Output: 21
Explanation:
- From the first row, we can take at most 2 elements. The element taken is 7.
- From the second row, we can take at most 2 elements. The elements taken are 8 and 6.
- The maximum possible sum of at most 3 selected elements is 7 + 8 + 6 = 21.
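One way to sanity-check both examples directly against the definition is an exhaustive search over every subset of cells, keeping only subsets with at most k cells overall and at most limits[i] cells in row i. The class name BruteForceMaxSum below is hypothetical, and the sketch is only practical for tiny grids because it enumerates all 2^(n*m) subsets; it is a checking aid, not the intended solution.

```java
class BruteForceMaxSum {
    // Exhaustive check of the problem definition: try every subset of cells.
    // Only feasible for tiny grids (roughly n * m <= 20 cells).
    static long bruteForce(int[][] grid, int[] limits, int k) {
        int n = grid.length, m = grid[0].length;
        int cells = n * m;
        long best = 0; // selecting nothing is always allowed
        for (int mask = 0; mask < (1 << cells); mask++) {
            if (Integer.bitCount(mask) > k) continue; // at most k elements overall
            int[] taken = new int[n];                 // elements taken per row
            long sum = 0;
            boolean ok = true;
            for (int c = 0; c < cells && ok; c++) {
                if ((mask & (1 << c)) != 0) {
                    int row = c / m, col = c % m;
                    taken[row]++;
                    if (taken[row] > limits[row]) ok = false; // row limit exceeded
                    sum += grid[row][col];
                }
            }
            if (ok) best = Math.max(best, sum);
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(bruteForce(new int[][]{{1, 2}, {3, 4}}, new int[]{1, 2}, 2));       // 7
        System.out.println(bruteForce(new int[][]{{5, 3, 7}, {8, 2, 6}}, new int[]{2, 2}, 3)); // 21
    }
}
```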
Constraints:
n == grid.length == limits.length
m == grid[i].length
1 <= n, m <= 500
0 <= grid[i][j] <= 10^5
0 <= limits[i] <= m
0 <= k <= min(n * m, sum(limits))
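These bounds are why the solution returns long rather than int: in the worst case all n * m = 250,000 cells may be selected, each worth 10^5, and that total no longer fits in a 32-bit int. The small sketch below (the class and method names are hypothetical) just does that arithmetic and shows the silent wrap-around an int product would produce.

```java
class SumBound {
    // Largest possible answer: every one of n * m cells selected, each equal to maxValue.
    // Casting n to long before multiplying keeps the whole product in 64-bit arithmetic.
    static long worstCaseSum(int n, int m, int maxValue) {
        return (long) n * m * maxValue;
    }

    public static void main(String[] args) {
        long bound = worstCaseSum(500, 500, 100_000);
        System.out.println(bound);                     // 25000000000
        System.out.println(bound > Integer.MAX_VALUE); // true
        // The same product computed in int arithmetic wraps around silently.
        int wrapped = 500 * 500 * 100_000;
        System.out.println(wrapped);                   // -769803776
    }
}
```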
Solution:
import java.util.Arrays;
import java.util.PriorityQueue;

class Solution {
    public long maxSum(int[][] grid, int[] limits, int k) {
        int n = grid.length;
        // Max-heap ordered by value; entries are {value, row, column}
        PriorityQueue<int[]> maxHeap = new PriorityQueue<>((a, b) -> Integer.compare(b[0], a[0]));
        // Sort each row in descending order (this mutates grid in place)
        for (int i = 0; i < n; i++) {
            Arrays.sort(grid[i]);
            reverse(grid[i]); // Reverse to get descending order
        }
        // Seed the heap with the largest element of every row that allows at least one pick
        for (int i = 0; i < n; i++) {
            if (limits[i] > 0) {
                maxHeap.offer(new int[]{grid[i][0], i, 0}); // {value, row, column}
            }
        }
        long maxSum = 0;
        // Take the largest remaining candidate until k elements are used or no candidates remain
        while (k > 0 && !maxHeap.isEmpty()) {
            int[] top = maxHeap.poll();
            int value = top[0], row = top[1], col = top[2];
            maxSum += value;
            k--;
            // If the row's limit allows another element, push the row's next-largest value
            if (col + 1 < limits[row]) {
                maxHeap.offer(new int[]{grid[row][col + 1], row, col + 1});
            }
        }
        return maxSum;
    }

    // Reverses the array in place
    private void reverse(int[] arr) {
        int left = 0, right = arr.length - 1;
        while (left < right) {
            int temp = arr[left];
            arr[left] = arr[right];
            arr[right] = temp;
            left++;
            right--;
        }
    }
}
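The heap works because every grid value is non-negative: within row i, the best choice is always that row's limits[i] largest values, and once those candidates are fixed, only the global cap k remains. A simpler variant (a hypothetical alternative, not the original author's solution; the class name MaxSumSortPool is made up) collects all candidates into one pool, sorts it once, and sums the k largest. It costs O(n·m log m) for the row sorts plus O(P log P) for the pool, where P = sum(limits) ≤ n·m, whereas the heap version needs only O(k log n) after the row sorts.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class MaxSumSortPool {
    // Gather each row's limits[i] largest values into one pool,
    // then sum the k largest values of that pool.
    static long maxSum(int[][] grid, int[] limits, int k) {
        List<Integer> pool = new ArrayList<>();
        for (int i = 0; i < grid.length; i++) {
            int[] row = grid[i].clone(); // copy so the caller's grid is not mutated
            Arrays.sort(row);            // ascending: the largest values are at the end
            for (int j = 0; j < limits[i]; j++) {
                pool.add(row[row.length - 1 - j]);
            }
        }
        pool.sort(Collections.reverseOrder()); // descending
        long sum = 0;
        for (int j = 0; j < k && j < pool.size(); j++) {
            sum += pool.get(j);
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(maxSum(new int[][]{{1, 2}, {3, 4}}, new int[]{1, 2}, 2));       // 7
        System.out.println(maxSum(new int[][]{{5, 3, 7}, {8, 2, 6}}, new int[]{2, 2}, 3)); // 21
    }
}
```

Unlike the original solution, this sketch clones each row so the input is left unchanged. The boxed List<Integer> keeps it short; an int[] pool sorted ascending and summed from the end would avoid boxing.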