log_parser.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. # -*- coding: utf-8 -*-
  2. # @Time : 2018/4/25 20:28
  3. # @Author : Adesun
  4. # @Site : https://github.com/Adesun
  5. # @File : log_parser.py
  6. import argparse
  7. import logging
  8. import os
  9. import platform
  10. import re
  11. import sys
  12. # set non-interactive backend default when os is not windows
  13. if sys.platform != 'win32':
  14. import matplotlib
  15. matplotlib.use('Agg')
  16. import matplotlib.pyplot as plt
  17. from matplotlib.ticker import MultipleLocator, FormatStrFormatter
  18. def get_file_name_and_ext(filename):
  19. (file_path, temp_filename) = os.path.split(filename)
  20. (file_name, file_ext) = os.path.splitext(temp_filename)
  21. return file_name, file_ext
  22. def show_message(message, stop=False):
  23. print(message)
  24. if stop:
  25. sys.exit(0)
  26. def parse_args():
  27. parser = argparse.ArgumentParser(description="training log parser by DeepKeeper ")
  28. parser.add_argument('--source-dir', dest='source_dir', type=str, default='./',
  29. help='the log source directory')
  30. parser.add_argument('--save-dir', dest='save_dir', type=str, default='./',
  31. help='the directory to be saved')
  32. parser.add_argument('--csv-file', dest='csv_file', type=str, default="",
  33. help='training log file')
  34. parser.add_argument('--log-file', dest='log_file', type=str, default="",
  35. help='training log file')
  36. parser.add_argument('--show', dest='show_plot', type=bool, default=False,
  37. help='whether to show')
  38. return parser.parse_args()
  39. def log_parser(args):
  40. if not args.log_file:
  41. show_message('log file must be specified.', True)
  42. log_path = os.path.join(args.source_dir, args.log_file)
  43. if not os.path.exists(log_path):
  44. show_message('log file does not exist.', True)
  45. file_name, _ = get_file_name_and_ext(log_path)
  46. log_content = open(log_path).read()
  47. iterations = []
  48. losses = []
  49. fig, ax = plt.subplots()
  50. # set area we focus on
  51. ax.set_ylim(0, 8)
  52. major_locator = MultipleLocator()
  53. minor_locator = MultipleLocator(0.5)
  54. ax.yaxis.set_major_locator(major_locator)
  55. ax.yaxis.set_minor_locator(minor_locator)
  56. ax.yaxis.grid(True, which='minor')
  57. pattern = re.compile(r"([\d].*): .*?, (.*?) avg")
  58. # print(pattern.findall(log_content))
  59. matches = pattern.findall(log_content)
  60. # print(type(matches[0]))
  61. counter = 0
  62. log_count = len(matches)
  63. if args.csv_file != '':
  64. csv_path = os.path.join(args.save_dir, args.csv_file)
  65. out_file = open(csv_path, 'w')
  66. else:
  67. csv_path = os.path.join(args.save_dir, file_name + '.csv')
  68. out_file = open(csv_path, 'w')
  69. for match in matches:
  70. counter += 1
  71. if log_count > 200:
  72. if counter % 200 == 0:
  73. print('parsing {}/{}'.format(counter, log_count))
  74. else:
  75. print('parsing {}/{}'.format(counter, log_count))
  76. iteration, loss = match
  77. iterations.append(int(iteration))
  78. losses.append(float(loss))
  79. out_file.write(iteration + ',' + loss + '\n')
  80. ax.plot(iterations, losses)
  81. plt.xlabel('Iteration')
  82. plt.ylabel('Loss')
  83. plt.tight_layout()
  84. # saved as svg
  85. save_path = os.path.join(args.save_dir, file_name + '.svg')
  86. plt.savefig(save_path, dpi=300, format="svg")
  87. if args.show_plot:
  88. plt.show()
  89. if __name__ == "__main__":
  90. args = parse_args()
  91. log_parser(args)