# retinaface.py — MaixPy face detection & recognition demo
  1. from maix import nn, camera, image, display
  2. from maix.nn.app.face import FaceRecognize
  3. import time
  4. from evdev import InputDevice
  5. from select import select
  6. score_threshold = 70 #识别分数阈值
  7. input_size = (224, 224, 3) #输入图片尺寸
  8. input_size_fe = (128, 128, 3) #输入人脸数据
  9. feature_len = 256 #人脸数据宽度
  10. steps = [8, 16, 32] #
  11. channel_num = 0 #通道数量
  12. users = [] #初始化用户列表
  13. names = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"] #人脸标签定义
  14. model = {
  15. "param": "/home/model/face_recognize/model_int8.param",
  16. "bin": "/home/model/face_recognize/model_int8.bin"
  17. }
  18. model_fe = {
  19. "param": "/home/model/face_recognize/fe_res18_117.param",
  20. "bin": "/home/model/face_recognize/fe_res18_117.bin"
  21. }
  22. for i in range(len(steps)):
  23. channel_num += input_size[1] / steps[i] * (input_size[0] / steps[i]) * 2
  24. channel_num = int(channel_num) #统计通道数量
  25. options = { #准备人脸输出参数
  26. "model_type": "awnn",
  27. "inputs": {
  28. "input0": input_size
  29. },
  30. "outputs": {
  31. "output0": (1, 4, channel_num) ,
  32. "431": (1, 2, channel_num) ,
  33. "output2": (1, 10, channel_num)
  34. },
  35. "mean": [127.5, 127.5, 127.5],
  36. "norm": [0.0078125, 0.0078125, 0.0078125],
  37. }
  38. options_fe = { #准备特征提取参数
  39. "model_type": "awnn",
  40. "inputs": {
  41. "inputs_blob": input_size_fe
  42. },
  43. "outputs": {
  44. "FC_blob": (1, 1, feature_len)
  45. },
  46. "mean": [127.5, 127.5, 127.5],
  47. "norm": [0.0078125, 0.0078125, 0.0078125],
  48. }
  49. keys = InputDevice('/dev/input/event0')
  50. threshold = 0.5 #人脸阈值
  51. nms = 0.3
  52. max_face_num = 1 #输出的画面中的人脸的最大个数
  53. print("-- load model:", model)
  54. m = nn.load(model, opt=options)
  55. print("-- load ok")
  56. print("-- load model:", model_fe)
  57. m_fe = nn.load(model_fe, opt=options_fe)
  58. print("-- load ok")
  59. face_recognizer = FaceRecognize(m, m_fe, feature_len, input_size, threshold, nms, max_face_num)
  60. def get_key(): #按键检测函数
  61. r,w,x = select([keys], [], [],0)
  62. if r:
  63. for event in keys.read():
  64. if event.value == 1 and event.code == 0x02: # 右键
  65. return 1
  66. elif event.value == 1 and event.code == 0x03: # 左键
  67. return 2
  68. elif event.value == 2 and event.code == 0x03: # 左键连按
  69. return 3
  70. return 0
  71. def map_face(box,points): #将224*224空间的位置转换到240*240或320*240空间内
  72. # print(box,points)
  73. if display.width() == display.height():
  74. def tran(x):
  75. return int(x/224*display.width())
  76. box = list(map(tran, box))
  77. def tran_p(p):
  78. return list(map(tran, p))
  79. points = list(map(tran_p, points))
  80. else:
  81. # 168x224(320x240) > 224x224(240x240) > 320x240
  82. s = (224*display.height()/display.width()) # 168x224
  83. w, h, c = display.width()/224, display.height()/224, 224/s
  84. t, d = c*h, (224 - s) // 2 # d = 224 - s // 2 == 28
  85. box[0], box[1], box[2], box[3] = int(box[0]*w), int((box[1]-28)*t), int(box[2]*w), int((box[3])*t)
  86. def tran_p(p):
  87. return [int(p[0]*w), int((p[1]-d)*t)] # 224 - 168 / 2 = 28 so 168 / (old_h - 28) = 240 / new_h
  88. points = list(map(tran_p, points))
  89. # print(box,points)
  90. return box,points
  91. def darw_info(draw, box, points, disp_str, bg_color=(255, 0, 0), font_color=(255, 255, 255)): #画框函数
  92. box,points = map_face(box,points)
  93. font_wh = image.get_string_size(disp_str)
  94. for p in points:
  95. draw.draw_rectangle(p[0] - 1, p[1] -1, p[0] + 1, p[1] + 1, color=bg_color)
  96. draw.draw_rectangle(box[0], box[1], box[0] + box[2], box[1] + box[3], color=bg_color, thickness=2)
  97. draw.draw_rectangle(box[0], box[1] - font_wh[1], box[0] + font_wh[0], box[1], color=bg_color, thickness = -1)
  98. draw.draw_string(box[0], box[1] - font_wh[1], disp_str, color=font_color)
  99. def recognize(feature): #进行人脸匹配
  100. def _compare(user): #定义映射函数
  101. return face_recognizer.compare(user, feature) #推测匹配分数 score相关分数
  102. face_score_l = list(map(_compare,users)) #映射特征数据在记录中的比对分数
  103. return max(enumerate(face_score_l), key=lambda x: x[-1]) #提取出人脸分数最大值和最大值所在的位置
  104. def run():
  105. img = camera.capture() #获取224*224*3的图像数据
  106. AI_img = img.copy().resize(224, 224)
  107. if not img:
  108. time.sleep(0.02)
  109. return
  110. faces = face_recognizer.get_faces(AI_img.tobytes(),False) #提取人脸特征信息
  111. if faces:
  112. for prob, box, landmarks, feature in faces:
  113. key_val = get_key()
  114. if key_val == 1: # 右键添加人脸记录
  115. if len(users) < len(names):
  116. print("add user:", len(users))
  117. users.append(feature)
  118. else:
  119. print("user full")
  120. elif key_val == 2: # 左键删除人脸记录
  121. if len(users) > 0:
  122. print("remove user:", names[len(users) - 1])
  123. users.pop()
  124. else:
  125. print("user empty")
  126. if len(users): #判断是否记录人脸
  127. maxIndex = recognize(feature)
  128. if maxIndex[1] > score_threshold: #判断人脸识别阈值,当分数大于阈值时认为是同一张脸,当分数小于阈值时认为是相似脸
  129. darw_info(img, box, landmarks, "{}:{:.2f}".format(names[maxIndex[0]], maxIndex[1]), font_color=(0, 0, 255, 255), bg_color=(0, 255, 0, 255))
  130. print("user: {}, score: {:.2f}".format(names[maxIndex[0]], maxIndex[1]))
  131. else:
  132. darw_info(img, box, landmarks, "{}:{:.2f}".format(names[maxIndex[0]], maxIndex[1]), font_color=(255, 255, 255, 255), bg_color=(255, 0, 0, 255))
  133. print("maybe user: {}, score: {:.2f}".format(names[maxIndex[0]], maxIndex[1]))
  134. else: #没有记录脸
  135. darw_info(img, box, landmarks, "error face", font_color=(255, 255, 255, 255), bg_color=(255, 0, 0, 255))
  136. display.show(img)
  137. if __name__ == "__main__":
  138. import signal
  139. def handle_signal_z(signum,frame):
  140. print("APP OVER")
  141. exit(0)
  142. signal.signal(signal.SIGINT,handle_signal_z)
  143. while True:
  144. run()