01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
|
// Exercise 2.2.19 (Solution published at http://algs4.cs.princeton.edu/)
package algs22;
import stdlib.*;
/* ***********************************************************************
* Compilation: javac Inversions.java
* Execution: java Inversions N
*
* Generate N pseudo-random numbers between 0 and 1 and count
* the number of inversions in O(N log N) time.
*
*************************************************************************/
public class XInversions {
// merge and count
private static <T extends Comparable<? super T>> int merge(T[] a, T[] aux, int lo, int mid, int hi) {
int inversions = 0;
// copy to aux[]
for (int k = lo; k <= hi; k++) {
aux[k] = a[k];
}
// merge back to a[]
int i = lo, j = mid+1;
for (int k = lo; k <= hi; k++) {
if (i > mid) a[k] = aux[j++];
else if (j > hi) a[k] = aux[i++];
else if (less(aux[j], aux[i])) { a[k] = aux[j++]; inversions += (mid - i + 1); }
else a[k] = aux[i++];
}
return inversions;
}
// return the number of inversions in the subarray b[lo..hi]
// side effect b[lo..hi] is rearranged in ascending order
private static <T extends Comparable<? super T>> int count(T[] a, T[] b, T[] aux, int lo, int hi) {
int inversions = 0;
if (hi <= lo) return 0;
int mid = lo + (hi - lo) / 2;
inversions += count(a, b, aux, lo, mid);
inversions += count(a, b, aux, mid+1, hi);
inversions += merge(b, aux, lo, mid, hi);
assert inversions == brute(a, lo, hi);
return inversions;
}
// count number of inversions in the array a[] - do not overwrite a[]
@SuppressWarnings("unchecked")
public static <T extends Comparable<? super T>> int count(T[] a) {
T[] b = (T[]) new Comparable[a.length];
T[] aux = (T[]) new Comparable[a.length];
for (int i = 0; i < a.length; i++) b[i] = a[i];
int inversions = count(a, b, aux, 0, a.length - 1);
return inversions;
}
// is v < w ?
private static <T extends Comparable<? super T>> boolean less(T v, T w) {
return (v.compareTo(w) < 0);
}
// count number of inversions in a[lo..hi] via brute force (for debugging only)
private static <T extends Comparable<? super T>> int brute(T[] a, int lo, int hi) {
int inversions = 0;
for (int i = lo; i <= hi; i++)
for (int j = i + 1; j <= hi; j++)
if (less(a[j], a[i])) inversions++;
return inversions;
}
// generate N real numbers between 0 and 1, and mergesort them
public static void main(String[] args) {
args = new String[] { "20" };
int N = Integer.parseInt(args[0]);
Double[] a = new Double[N];
for (int i = 0; i < N; i++)
a[i] = Math.random();
StdOut.println(brute(a, 0, N-1));
StdOut.println(count(a));
for (int i = 0; i < N; i++)
StdOut.println(a[i]);
}
}
|