Coding 공부/Python

[Python] 드론제어 예제코드(상세) Python3, ROS2, OpenCV, YOLO, DJITelloPy

CBJH 2024. 7. 20.
728x90
반응형
  • 코드 간략 설명 : 드론으로부터 실시간 비디오 스트림을 받아 YOLO를 통해 사람 객체를 탐지하고, 탐지된 객체를 표시 및 저장하며, 결과 이미지를 ROS 주제로 퍼블리싱하는 코드
import cv2
import numpy as np
import rospy
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
from djitellopy import Tello

# YOLO 모델 설정
YOLO_WEIGHTS = "yolov3.weights"  # YOLO 가중치 파일 경로
YOLO_CONFIG = "yolov3.cfg"       # YOLO 구성 파일 경로
YOLO_CLASSES = "coco.names"      # 객체 클래스 이름 파일 경로

# YOLO 네트워크 로드
net = cv2.dnn.readNet(YOLO_WEIGHTS, YOLO_CONFIG)
with open(YOLO_CLASSES, "r") as f:
    classes = [line.strip() for line in f.readlines()]  # 클래스 이름 목록 읽기
layer_names = net.getLayerNames()  # YOLO 네트워크의 레이어 이름 가져오기
output_layers = [layer_names[i[0] - 1] for i in net.getUnconnectedOutLayers()]  # 출력 레이어 설정

# ROS 노드 초기화
rospy.init_node('drone_controller', anonymous=True)
bridge = CvBridge()  # OpenCV와 ROS 간의 이미지 변환을 위한 브릿지

# Tello 드론 초기화
drone = Tello()
drone.connect()  # 드론 연결
drone.streamon()  # 비디오 스트림 시작

def capture_and_detect():
    frame_read = drone.get_frame_read()  # 드론으로부터 프레임 읽기
    frame = frame_read.frame  # 프레임 추출

    height, width, channels = frame.shape  # 프레임의 높이, 너비, 채널 수 얻기
    # YOLO에 입력할 블롭 생성
    blob = cv2.dnn.blobFromImage(frame, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
    net.setInput(blob)  # YOLO 네트워크에 블롭 설정
    outs = net.forward(output_layers)  # YOLO 네트워크에서 결과 추론

    class_ids = []  # 클래스 ID 리스트
    confidences = []  # 신뢰도 리스트
    boxes = []  # 바운딩 박스 리스트

    # YOLO의 각 출력 레이어 결과 처리
    for out in outs:
        for detection in out:
            scores = detection[5:]  # 각 클래스에 대한 점수
            class_id = np.argmax(scores)  # 최고 점수를 가진 클래스 ID
            confidence = scores[class_id]  # 최고 점수 (신뢰도)
            if confidence > 0.5 and classes[class_id] == "person":  # 신뢰도가 0.5 이상이고 클래스가 'person'일 경우
                center_x = int(detection[0] * width)  # 바운딩 박스 중심 x 좌표
                center_y = int(detection[1] * height)  # 바운딩 박스 중심 y 좌표
                w = int(detection[2] * width)  # 바운딩 박스 너비
                h = int(detection[3] * height)  # 바운딩 박스 높이
                x = int(center_x - w / 2)  # 바운딩 박스 왼쪽 상단 x 좌표
                y = int(center_y - h / 2)  # 바운딩 박스 왼쪽 상단 y 좌표
                boxes.append([x, y, w, h])  # 바운딩 박스 좌표 추가
                confidences.append(float(confidence))  # 신뢰도 추가
                class_ids.append(class_id)  # 클래스 ID 추가

    # Non-maximum Suppression 적용하여 중복 박스 제거
    indexes = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4)

    for i in range(len(boxes)):
        if i in indexes:  # 유효한 바운딩 박스일 경우
            x, y, w, h = boxes[i]
            label = str(classes[class_ids[i]])  # 클래스 이름 레이블
            cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)  # 바운딩 박스 그리기
            cv2.putText(frame, label, (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)  # 레이블 텍스트 표시
            person_img = frame[y:y+h, x:x+w]  # 사람 이미지 추출
            cv2.imwrite(f"captured_person_{i}.jpg", person_img)  # 추출된 이미지 저장

    cv2.imshow("Frame", frame)  # 프레임 화면에 표시
    cv2.waitKey(1)  # 1ms 대기

    img_msg = bridge.cv2_to_imgmsg(frame, "bgr8")  # OpenCV 이미지를 ROS 이미지 메시지로 변환
    image_pub.publish(img_msg)  # ROS 주제로 이미지 메시지 퍼블리싱

# ROS 이미지 퍼블리셔 설정
image_pub = rospy.Publisher('/drone/image_raw', Image, queue_size=10)

if __name__ == "__main__":
    try:
        while not rospy.is_shutdown():  # ROS 노드가 종료될 때까지 반복
            capture_and_detect()  # 캡처 및 객체 인식 함수 호출
    except rospy.ROSInterruptException:
        pass
    finally:
        drone.streamoff()  # 드론 스트림 종료
        cv2.destroyAllWindows()  # 모든 OpenCV 창 닫기
함수 동작 설명
  1. 프레임 읽기 및 블롭 생성:
    • frame_read = drone.get_frame_read(): 드론으로부터 비디오 스트림을 읽습니다.
    • frame = frame_read.frame: 현재 프레임을 가져옵니다.
    • blob = cv2.dnn.blobFromImage(...): 프레임을 YOLO 네트워크의 입력 형식에 맞게 변환합니다.
  2. YOLO 네트워크를 통한 객체 탐지:
    • net.setInput(blob): YOLO 네트워크에 블롭 데이터를 설정합니다.
    • outs = net.forward(output_layers): YOLO 네트워크를 통해 객체 탐지를 수행하고 결과를 얻습니다.
  3. 탐지된 객체 처리:
    • scores = detection[5:]: 각 클래스에 대한 점수를 추출합니다.
    • class_id = np.argmax(scores): 최고 점수를 가진 클래스 ID를 얻습니다.
    • confidence = scores[class_id]: 해당 클래스의 신뢰도를 얻습니다.
    • 신뢰도가 0.5 이상이고 클래스가 'person'인 경우 바운딩 박스를 계산하여 추가합니다.
  4. Non-maximum Suppression:
    • indexes = cv2.dnn.NMSBoxes(...): 중복된 바운딩 박스를 제거하여 최종 바운딩 박스를 선택합니다.
  5. 탐지 결과 표시 및 저장:
    • 유효한 바운딩 박스에 대해:
      • cv2.rectangle(...): 프레임에 바운딩 박스를 그립니다.
      • cv2.putText(...): 클래스 이름 레이블을 표시합니다.
      • person_img = frame[y:y+h, x:x+w]: 사람 이미지를 추출하여 저장합니다.
    • cv2.imshow("Frame", frame): 프레임을 화면에 표시합니다.
    • cv2.waitKey(1): 1ms 동안 대기하여 화면을 갱신합니다.
  6. ROS 주제로 이미지 퍼블리싱:
    • img_msg = bridge.cv2_to_imgmsg(...): OpenCV 이미지를 ROS 이미지 메시지로 변환합니다.
    • image_pub.publish(img_msg): ROS 주제로 이미지 메시지를 퍼블리싱합니다.

이렇게 작성된 코드는 드론으로부터 실시간 비디오 스트림을 받아 YOLO를 통해 사람 객체를 탐지하고, 탐지된 객체를 표시 및 저장하며, 결과 이미지를 ROS 주제로 퍼블리싱합니다. 추가로 궁금한 점이 있으면 언제든지 말씀해 주세요!

댓글