Last updated on February 11, 2025 pm
写在前面
本文记录自己实现外部排序所遇到的问题以及所有思考。
实现原理
当数据量过大,内存无法一次性容纳所有数据时,我们就需要用到外部排序。既然内存无法容纳全部数据,那必然要涉及到磁盘的读写问题。
磁盘空间是以“块”为单位进行空间管理的。同样,读和写也是以“块”为单位。数据只有读入内存才能进行修改,修改完还需要写回磁盘。
第一阶段:部分排序阶段
根据内存大小,将待排序的文件拆成多个部分,使得每个部分都是能够存入内存中。然后选择合适的内排序算法(比如快排),将这部分进行排序,并输出到外存临时文件中。这样得到的每个临时文件都是有序排列的,我们将其称之为一个顺段。
第二阶段:归并阶段
对前面的多个“顺段”进行合并,以2路归并为例,每次都将两个连续的顺段合并成一个更大的顺段。但因为内存限制,每次可能只能读入两个顺段的部分内容,所以我们需要一部分一部分读入,在内存里将进行排序,并输出到外存里的文件中,不断重复这个过程,直至两个顺段被完整遍历。这样经过多层的归并之后,最终会得到一个完整的顺序文件。
存在的问题
上述算法的整体时间开销 = 读写外存所需要的时间 + 内部排序所需要的时间 + 内部归并所需要的时间
其中,读写外存所需要的时间占比非常大!!!我们需要尽可能减少IO次数,IO读写次数为
$$
log_kn
$$
n是第一阶段结束后,所形成的顺段个数;k是指第二阶段采用的K路归并。
通过增大k或者减小n,可以提高排序效率。两者都是以空间换时间。
但是,增大k会导致归并时每挑选一个关键字需要对比 k-1 次,内部归并所需要的时间又增加了。这个问题可以使用败者树来解决。
败者树
有了败者树,选出最小元素,只需要对比关键字 log2k 次。
代码实现
首先生成一些测试数据,我这里生成了50000个随机数,写在磁盘上。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
| public static void generateData(String filename) { File file = new File(filename); int numIntegers = 50000; try { if (file.createNewFile()) { try (FileWriter writer = new FileWriter(file)) { for (int i = 0; i < numIntegers; i++) { Random random = new Random(); int randomNumber = random.nextInt(100000); writer.write(randomNumber + "\n"); } } } else { System.out.println("文件已存在: " + file.getAbsolutePath()); } } catch (IOException e) { System.err.println("创建文件时出现错误: " + e.getMessage()); } }
|
由前面的介绍可知,此算法一共有两步。
Step 1:
**Step 2:**k路归并文件,首先获取到每个输入文件的首个值,用这些数据去初始化败者树,然后开始归并,结果写入磁盘。
代码太长不贴了,这里要注意:
- 在使用完之后关闭输入流和输出流,防止资源泄露,一个潜在的安全问题;
- 在合并完成后删除那些不再需要文件,减少不必要的文件对磁盘容量的占用。这里删的是我们程序中间生成的临时文件,不是原本的数据文件,最后程序运行完之后,磁盘上应该只留下原本的数据文件和我们排序完之后的文件,严谨!
败者树的结构体这里放一下。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
| class LoserTree { int k; int[] losers; int[] keys;
LoserTree(int K) { k = K; losers = new int[k]; for (int i = 0; i < k; i++) { losers[i] = -1; } keys = new int[k]; for (int i = 0; i < k; i++) { keys[i] = Integer.MAX_VALUE; } }
void initialize(Integer[] initialKeys) { for (int i = 0; i < k; i++) { keys[i] = initialKeys[i]; }
for (int i = 0; i < k; i++) { adjust(i); } }
int winner() { return losers[0]; }
void update(int index, int newKey) { keys[index] = newKey; adjust(index); }
int getKey(int index) { return keys[index]; }
void adjust(int s) { int parent = (s + k) / 2; int temp = s; while (parent > 0) { if (temp != -1 && (losers[parent] == -1 || keys[temp] > keys[losers[parent]])) { int temp2 = temp; temp = losers[parent]; losers[parent] = temp2; } parent /= 2; } losers[0] = temp; } }
|
结果对比
k为2 用时为969毫秒
k为8 用时为306毫秒
果然k越大,运行越快。
碎碎念
中间一度想要在循环中去动态创建变量,比如命名为 a_0
、a_1
、a_2
这样形式。但是java不能这样去做,Java 是一种静态类型语言,变量名需要在编译时确定,无法在运行时动态生成。
好在我这里提前知道循环的次数,最后将变量创建在了数组中。
reader.readLine() 每次读取一行数据,我先用这个方法取了数据,紧接着又调用这个方法判断是否还有数据,结果导致最后的结果只有一半的数据,还以为是K路合并的代码有问题,排查了好久。。。(因为第一步部分排序的时候只用了一次readLine()方法,k路合并时读数据用了两次)
反思:其实数据正好少了一半,应该想到是取数据的问题。
最后
贴一下全部代码
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326
| import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.*;
class LoserTree { int k; int[] losers; int[] keys;
LoserTree(int K) { k = K; losers = new int[k]; for (int i = 0; i < k; i++) { losers[i] = -1; } keys = new int[k]; for (int i = 0; i < k; i++) { keys[i] = Integer.MAX_VALUE; } }
void initialize(Integer[] initialKeys) { for (int i = 0; i < k; i++) { keys[i] = initialKeys[i]; }
for (int i = 0; i < k; i++) { adjust(i); } }
int winner() { return losers[0]; }
void update(int index, int newKey) { keys[index] = newKey; adjust(index); }
int getKey(int index) { return keys[index]; }
void adjust(int s) { int parent = (s + k) / 2; int temp = s; while (parent > 0) { if (temp != -1 && (losers[parent] == -1 || keys[temp] > keys[losers[parent]])) { int temp2 = temp; temp = losers[parent]; losers[parent] = temp2; } parent /= 2; } losers[0] = temp; } }
public class Main {
public static List<String> tempFiles = new ArrayList<>();
public static void generateData(String filename) { File file = new File(filename); int numIntegers = 50000; try { if (file.createNewFile()) { try (FileWriter writer = new FileWriter(file)) { for (int i = 0; i < numIntegers; i++) { Random random = new Random(); int randomNumber = random.nextInt(100000); writer.write(randomNumber + "\n"); } } } else { System.out.println("文件已存在: " + file.getAbsolutePath()); } } catch (IOException e) { System.err.println("创建文件时出现错误: " + e.getMessage()); } }
public static void readAndSortChunk(String filename, int chunkSize) { try (BufferedReader reader = new BufferedReader(new FileReader(filename))) { int[] chunk = new int[chunkSize]; int chunkIndex = 0; String line; int index = 0; while ((line = reader.readLine()) != null) { try { int number = Integer.parseInt(line); chunk[index] = number; index++; if (index == chunkSize) { sortAndWriteToTempFile(chunk, chunkIndex++); index = 0; } } catch (NumberFormatException e) { System.err.println("文件中包含非数字内容: " + line); } } if (index > 0) { int[] lastchunk = new int[index]; System.arraycopy(chunk, 0, lastchunk, 0, index); sortAndWriteToTempFile(lastchunk, chunkIndex); } } catch (IOException e) { System.err.println("读取文件时出错: " + e.getMessage()); } }
public static void sortAndWriteToTempFile(int[] array, int chunkIndex) { quickSort(array, 0, array.length - 1); String tempFileName = "D:\\temp_" + ++chunkIndex + ".txt"; File file = new File(tempFileName); try { if (file.createNewFile()) { try (FileWriter writer = new FileWriter(file)) { for (int i = 0; i < array.length; i++) { writer.write(array[i] + "\n"); } } tempFiles.add(tempFileName); } else { System.out.println("文件已存在: " + file.getAbsolutePath()); } } catch (IOException e) { System.err.println("创建文件时出现错误: " + e.getMessage()); } }
public static void quickSort(int[] array, int low, int high) { if(low < high) { int position = partition(array, low, high); quickSort(array, low, position - 1); quickSort(array, position+1, high); } }
public static int partition(int[] array,int low,int high) { int pivot = array[high]; int pointer = low; for(int i = low;i<high;i++) { if(array[i]<=pivot){ int temp = array[i]; array[i] = array[pointer]; array[pointer] = temp; pointer++; } } int temp = array[pointer]; array[pointer] = array[high]; array[high] = temp; return pointer; }
public static void mergeFiles(List<String> filesToMerge, String outputTempFile) { int k = filesToMerge.size(); Integer[] values = new Integer[k]; BufferedReader[] reader = new BufferedReader[k]; for (int i = 0; i < filesToMerge.size(); i++) { String filePath = filesToMerge.get(i); Path path = Paths.get(filePath); try { reader[i] = Files.newBufferedReader(path); String line; if ((line = reader[i].readLine()) != null) { values[i] = Integer.valueOf(line); } } catch (IOException e) { e.printStackTrace(); } }
File file = new File(outputTempFile); FileWriter writer = null; try { file.createNewFile(); writer = new FileWriter(file); } catch (IOException e) { System.err.println("创建文件时出现错误: " + e.getMessage()); }
LoserTree lt = new LoserTree(k); lt.initialize(values); int count = 0; while (true) { count++; int winnerIndex = lt.winner(); int winnerValue = lt.getKey(winnerIndex); if (winnerValue == Integer.MAX_VALUE) { break; }
try { writer.write(winnerValue + "\n"); } catch (IOException e) { throw new RuntimeException(e); }
try { String line = reader[winnerIndex].readLine(); if (line != null) { values[winnerIndex] = Integer.valueOf(line); lt.update(winnerIndex, values[winnerIndex]); } else { lt.update(winnerIndex, Integer.MAX_VALUE); } } catch (IOException e) { e.printStackTrace(); } }
System.out.println("countcount" + count);
for (int i = 0; i < filesToMerge.size(); i++) { try { reader[i].close(); } catch (IOException e) { throw new RuntimeException(e); } }
for (String filedelete : filesToMerge) { File filed = new File(filedelete); filed.delete(); } if (writer != null) { try { writer.close(); } catch (IOException e) { e.printStackTrace(); } } }
public static void externalSort(String filename, String outputFile, int chunkSize, int k) { System.out.println("STEP 1 开始"); readAndSortChunk(filename, chunkSize); System.out.println("STEP 1 完成");
System.out.println("STEP 2 开始"); int chunkIndex = 0; while (tempFiles.size() > 1) { List<String> newTempFiles = new ArrayList<>(); for (int i = 0; i < tempFiles.size(); i += k) { List<String> filesToMerge = new ArrayList<>(); for (int j = 0; j < k && i + j < tempFiles.size(); j++) { filesToMerge.add(tempFiles.get(i + j)); } if (filesToMerge.isEmpty()) break; if (filesToMerge.size() > 1) { String outputTempFile = "D:\\temp_merge_" + ++chunkIndex + ".txt"; mergeFiles(filesToMerge, outputTempFile); newTempFiles.add(outputTempFile); } else { newTempFiles.add(filesToMerge.get(0)); } } tempFiles = newTempFiles; System.out.println("tempFiles:" + tempFiles); } System.out.println("STEP 2 完成");
if (!tempFiles.isEmpty()) { File originalFile = new File(tempFiles.get(0)); File renamedFile = new File(outputFile); if (originalFile.renameTo(renamedFile)) { System.out.println("文件重命名成功,新文件路径为: " + renamedFile.getAbsolutePath()); } else { System.err.println("文件重命名失败,请检查文件是否存在或是否有足够的权限。"); } } }
public static void main(String[] args) { String filename = "D:\\data.txt"; String outputFile = "D:\\out13.txt"; generateData(filename);
int chunkSize = 1000; int k = 2;
long startTime = System.currentTimeMillis(); externalSort(filename, outputFile, chunkSize, k); long endTime = System.currentTimeMillis(); long executionTime = endTime - startTime; System.out.println("k为" + k + " 用时为" + executionTime + "毫秒"); }
}
|