-
2. 신경망(6)AI 모델(딥러닝 기초)/2. 신경망 2023. 1. 5. 14:00728x90
※ Batch 처리
- 입력 데이터를 하나로 묶어 처리함.
- 수치 계산 라이브러리 대부분이 큰 배열을 효율적으로 처리할 수 있도록 고도로 최적화 되어 있어 시간 효율 ↑
- 신경망의 데이터 전송 병목 현상을 줄여줌. 데이터 리딩 횟수가 줄어 CPU나 GPU로 순수 계산 수행을 할 수 있음.
- 배치 처리를 통해 큰 배열로 이루어진 계산을 하게 되는데, 컴퓨터에서는 큰 배열을 한꺼번에 계산하는 것이 효율적임.
x, t = get_data() network = init_network() batch_size = 100 accuracy_cnt = 0 for i in range(0, len(x), batch_size): x_batch = x[i:i+batch_size] y_batch = predict(network, x_batch) p= np.argmax(y_batch, axis=1) # 확률이 가장 높은 원소의 인덱스를 얻는다. accuracy_cnt += np.sum(p == t[i:i+batch_size]) print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
이전 신경망(5)에서 mnist 데이터셋을 predict하여 accuracy를 출력하는 과정을 batch 처리를 통해 개선시켰다.
■ x[i:i+batch_size] => 입력 데이터의 i~i+batch_size번째까지의 데이터를 묶는 것. batch_size = 100이므로 x[0:100], x[100:200]과 같은 의미로 100장씩 묶어서 꺼내게 된다.
■ argmax(y_batch, axis=1) 의 경우 1번째 차원을 축으로 최댓값의 index를 찾도록 한 것.
■ np.sum(p == t[i:i+batch_size]) => () 안의 조건에 맞는 index를 boolean 값으로 불러오고 True의 갯수를 세서 더해준 값
이렇게 데이터를 batch 형태로 처리함으로써 효율적이고 더 빠른 구현이 가능하다.
728x90