kmeansiou.c 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. //usr/bin/cc -Ofast -lm "${0}" -o "${0%.c}" && ./"${0%.c}" "$@"; s=$?; rm ./"${0%.c}"; exit $s
  2. #include <math.h>
  3. #include <stdio.h>
  4. #include <stdlib.h>
  5. #include <string.h>
  6. #include <time.h>
  7. typedef struct matrix{
  8. int rows, cols;
  9. double **vals;
  10. } matrix;
  11. matrix csv_to_matrix(char *filename, int header);
  12. matrix make_matrix(int rows, int cols);
  13. void zero_matrix(matrix m);
  14. void copy(double *x, double *y, int n);
  15. double dist(double *x, double *y, int n);
  16. int *sample(int n);
  17. int find_int_arg(int argc, char **argv, char *arg, int def);
  18. int find_arg(int argc, char* argv[], char *arg);
  19. int closest_center(double *datum, matrix centers)
  20. {
  21. int j;
  22. int best = 0;
  23. double best_dist = dist(datum, centers.vals[best], centers.cols);
  24. for(j = 0; j < centers.rows; ++j){
  25. double new_dist = dist(datum, centers.vals[j], centers.cols);
  26. if(new_dist < best_dist){
  27. best_dist = new_dist;
  28. best = j;
  29. }
  30. }
  31. return best;
  32. }
  33. double dist_to_closest_center(double *datum, matrix centers)
  34. {
  35. int ci = closest_center(datum, centers);
  36. return dist(datum, centers.vals[ci], centers.cols);
  37. }
  38. int kmeans_expectation(matrix data, int *assignments, matrix centers)
  39. {
  40. int i;
  41. int converged = 1;
  42. for(i = 0; i < data.rows; ++i){
  43. int closest = closest_center(data.vals[i], centers);
  44. if(closest != assignments[i]) converged = 0;
  45. assignments[i] = closest;
  46. }
  47. return converged;
  48. }
  49. void kmeans_maximization(matrix data, int *assignments, matrix centers)
  50. {
  51. int i,j;
  52. int *counts = calloc(centers.rows, sizeof(int));
  53. zero_matrix(centers);
  54. for(i = 0; i < data.rows; ++i){
  55. ++counts[assignments[i]];
  56. for(j = 0; j < data.cols; ++j){
  57. centers.vals[assignments[i]][j] += data.vals[i][j];
  58. }
  59. }
  60. for(i = 0; i < centers.rows; ++i){
  61. if(counts[i]){
  62. for(j = 0; j < centers.cols; ++j){
  63. centers.vals[i][j] /= counts[i];
  64. }
  65. }
  66. }
  67. }
  68. double WCSS(matrix data, int *assignments, matrix centers)
  69. {
  70. int i, j;
  71. double sum = 0;
  72. for(i = 0; i < data.rows; ++i){
  73. int ci = assignments[i];
  74. sum += (1 - dist(data.vals[i], centers.vals[ci], data.cols));
  75. }
  76. return sum / data.rows;
  77. }
  78. typedef struct{
  79. int *assignments;
  80. matrix centers;
  81. } model;
  82. void smart_centers(matrix data, matrix centers) {
  83. int i,j;
  84. copy(data.vals[rand()%data.rows], centers.vals[0], data.cols);
  85. double *weights = calloc(data.rows, sizeof(double));
  86. int clusters = centers.rows;
  87. for (i = 1; i < clusters; ++i) {
  88. double sum = 0;
  89. centers.rows = i;
  90. for (j = 0; j < data.rows; ++j) {
  91. weights[j] = dist_to_closest_center(data.vals[j], centers);
  92. sum += weights[j];
  93. }
  94. double r = sum*((double)rand()/RAND_MAX);
  95. for (j = 0; j < data.rows; ++j) {
  96. r -= weights[j];
  97. if(r <= 0){
  98. copy(data.vals[j], centers.vals[i], data.cols);
  99. break;
  100. }
  101. }
  102. }
  103. free(weights);
  104. }
  105. void random_centers(matrix data, matrix centers){
  106. int i;
  107. int *s = sample(data.rows);
  108. for(i = 0; i < centers.rows; ++i){
  109. copy(data.vals[s[i]], centers.vals[i], data.cols);
  110. }
  111. free(s);
  112. }
  113. model do_kmeans(matrix data, int k)
  114. {
  115. matrix centers = make_matrix(k, data.cols);
  116. int *assignments = calloc(data.rows, sizeof(int));
  117. smart_centers(data, centers);
  118. //random_centers(data, centers);
  119. if(k == 1) kmeans_maximization(data, assignments, centers);
  120. while(!kmeans_expectation(data, assignments, centers)){
  121. kmeans_maximization(data, assignments, centers);
  122. }
  123. model m;
  124. m.assignments = assignments;
  125. m.centers = centers;
  126. return m;
  127. }
  128. int main(int argc, char *argv[])
  129. {
  130. if(argc < 3){
  131. fprintf(stderr, "usage: %s <csv-file> [points/centers/stats]\n", argv[0]);
  132. return 0;
  133. }
  134. int i,j;
  135. srand(time(0));
  136. matrix data = csv_to_matrix(argv[1], 0);
  137. int k = find_int_arg(argc, argv, "-k", 2);
  138. int header = find_arg(argc, argv, "-h");
  139. int count = find_arg(argc, argv, "-c");
  140. if(strcmp(argv[2], "assignments")==0){
  141. model m = do_kmeans(data, k);
  142. int *assignments = m.assignments;
  143. for(i = 0; i < k; ++i){
  144. if(i != 0) printf("-\n");
  145. for(j = 0; j < data.rows; ++j){
  146. if(!(assignments[j] == i)) continue;
  147. printf("%f, %f\n", data.vals[j][0], data.vals[j][1]);
  148. }
  149. }
  150. }else if(strcmp(argv[2], "centers")==0){
  151. model m = do_kmeans(data, k);
  152. printf("WCSS: %f\n", WCSS(data, m.assignments, m.centers));
  153. int *counts = 0;
  154. if(count){
  155. counts = calloc(k, sizeof(int));
  156. for(j = 0; j < data.rows; ++j){
  157. ++counts[m.assignments[j]];
  158. }
  159. }
  160. for(j = 0; j < m.centers.rows; ++j){
  161. if(count) printf("%d, ", counts[j]);
  162. printf("%f, %f\n", m.centers.vals[j][0], m.centers.vals[j][1]);
  163. }
  164. }else if(strcmp(argv[2], "scan")==0){
  165. for(i = 1; i <= k; ++i){
  166. model m = do_kmeans(data, i);
  167. printf("%f\n", WCSS(data, m.assignments, m.centers));
  168. }
  169. }
  170. return 0;
  171. }
  172. // Utility functions
  173. int *sample(int n)
  174. {
  175. int i;
  176. int *s = calloc(n, sizeof(int));
  177. for(i = 0; i < n; ++i) s[i] = i;
  178. for(i = n-1; i >= 0; --i){
  179. int swap = s[i];
  180. int index = rand()%(i+1);
  181. s[i] = s[index];
  182. s[index] = swap;
  183. }
  184. return s;
  185. }
  186. double dist(double *x, double *y, int n)
  187. {
  188. int i;
  189. double mw = (x[0] < y[0]) ? x[0] : y[0];
  190. double mh = (x[1] < y[1]) ? x[1] : y[1];
  191. double inter = mw*mh;
  192. double sum = x[0]*x[1] + y[0]*y[1];
  193. double un = sum - inter;
  194. double iou = inter/un;
  195. return 1-iou;
  196. }
  197. void copy(double *x, double *y, int n)
  198. {
  199. int i;
  200. for(i = 0; i < n; ++i) y[i] = x[i];
  201. }
  202. void error(char *s){
  203. fprintf(stderr, "Error: %s\n", s);
  204. exit(-1);
  205. }
  206. char *fgetl(FILE *fp)
  207. {
  208. if(feof(fp)) return 0;
  209. int size = 512;
  210. char *line = malloc(size*sizeof(char));
  211. if(!fgets(line, size, fp)){
  212. free(line);
  213. return 0;
  214. }
  215. int curr = strlen(line);
  216. while(line[curr-1]!='\n'){
  217. size *= 2;
  218. line = realloc(line, size*sizeof(char));
  219. if(!line) error("Malloc");
  220. fgets(&line[curr], size-curr, fp);
  221. curr = strlen(line);
  222. }
  223. line[curr-1] = '\0';
  224. return line;
  225. }
  226. // Matrix stuff
  227. int count_fields(char *line)
  228. {
  229. int count = 0;
  230. int done = 0;
  231. char *c;
  232. for(c = line; !done; ++c){
  233. done = (*c == '\0');
  234. if(*c == ',' || done) ++count;
  235. }
  236. return count;
  237. }
  238. double *parse_fields(char *l, int n)
  239. {
  240. int i;
  241. double *field = calloc(n, sizeof(double));
  242. for(i = 0; i < n; ++i){
  243. field[i] = atof(l);
  244. l = strchr(l, ',')+1;
  245. }
  246. return field;
  247. }
  248. matrix make_matrix(int rows, int cols)
  249. {
  250. matrix m;
  251. m.rows = rows;
  252. m.cols = cols;
  253. m.vals = calloc(m.rows, sizeof(double *));
  254. int i;
  255. for(i = 0; i < m.rows; ++i) m.vals[i] = calloc(m.cols, sizeof(double));
  256. return m;
  257. }
  258. void zero_matrix(matrix m)
  259. {
  260. int i, j;
  261. for(i = 0; i < m.rows; ++i){
  262. for(j = 0; j < m.cols; ++j) m.vals[i][j] = 0;
  263. }
  264. }
  265. matrix csv_to_matrix(char *filename, int header)
  266. {
  267. FILE *fp = fopen(filename, "r");
  268. if(!fp) error(filename);
  269. matrix m;
  270. m.cols = -1;
  271. char *line;
  272. int n = 0;
  273. int size = 1024;
  274. m.vals = calloc(size, sizeof(double*));
  275. if(header) fgetl(fp);
  276. while((line = fgetl(fp))){
  277. if(m.cols == -1) m.cols = count_fields(line);
  278. if(n == size){
  279. size *= 2;
  280. m.vals = realloc(m.vals, size*sizeof(double*));
  281. }
  282. m.vals[n] = parse_fields(line, m.cols);
  283. free(line);
  284. ++n;
  285. }
  286. m.vals = realloc(m.vals, n*sizeof(double*));
  287. m.rows = n;
  288. return m;
  289. }
// Argument parsing
  291. void del_arg(int argc, char **argv, int index)
  292. {
  293. int i;
  294. for(i = index; i < argc-1; ++i) argv[i] = argv[i+1];
  295. argv[i] = 0;
  296. }
  297. int find_arg(int argc, char* argv[], char *arg)
  298. {
  299. int i;
  300. for(i = 0; i < argc; ++i) {
  301. if(!argv[i]) continue;
  302. if(0==strcmp(argv[i], arg)) {
  303. del_arg(argc, argv, i);
  304. return 1;
  305. }
  306. }
  307. return 0;
  308. }
  309. int find_int_arg(int argc, char **argv, char *arg, int def)
  310. {
  311. int i;
  312. for(i = 0; i < argc-1; ++i){
  313. if(!argv[i]) continue;
  314. if(0==strcmp(argv[i], arg)){
  315. def = atoi(argv[i+1]);
  316. del_arg(argc, argv, i);
  317. del_arg(argc, argv, i);
  318. break;
  319. }
  320. }
  321. return def;
  322. }
  323. float find_float_arg(int argc, char **argv, char *arg, float def)
  324. {
  325. int i;
  326. for(i = 0; i < argc-1; ++i){
  327. if(!argv[i]) continue;
  328. if(0==strcmp(argv[i], arg)){
  329. def = atof(argv[i+1]);
  330. del_arg(argc, argv, i);
  331. del_arg(argc, argv, i);
  332. break;
  333. }
  334. }
  335. return def;
  336. }
  337. char *find_char_arg(int argc, char **argv, char *arg, char *def)
  338. {
  339. int i;
  340. for(i = 0; i < argc-1; ++i){
  341. if(!argv[i]) continue;
  342. if(0==strcmp(argv[i], arg)){
  343. def = argv[i+1];
  344. del_arg(argc, argv, i);
  345. del_arg(argc, argv, i);
  346. break;
  347. }
  348. }
  349. return def;
  350. }