https://github.com/X-PLUG/mPLUG-Owl

Architecture overview

Untitled

처음으로는 f_v로 부터 dense image representation을 얻는다. 그러나, 이 dense feature는 fine-grained image 정보를 조각화(fragment)할 수 있고, f_l에 넣기에 긴 시퀀스 때문에 큰 연산을 가져온다. 이러한 현상을 해소하기 위해 visual abstractor module f_k를 차용했다. f_k는 visual information을 몇몇의 learnable token으로 요약하며, 그렇기 때문에 더 높은 visual representation을 얻을 수 있고 계산 비용도 절약이 가능하다. visual representation은 text query와 함께 병합되어 language모델에 입력된다.

코드 뜯어보기

processing

def _tokenize_prompt(prompt, tokenizer, add_BOS=False, media_info={"<image>": 65}, **kwargs):
		# dongyong: <image> 토큰 들어오면 -1로 매핑
    media_tokens = {k: -int(i + 1) for i, k in enumerate(media_info.keys())}
    media_lengths = media_info.copy()

		# dongyong: BOS를 쓰기 때문에 맨 앞에 BOS 추가
    if add_BOS:
        prompt_chunk = [tokenizer.bos_token_id]
    else:
        prompt_chunk = []
		
		# dongyong: media 토큰 없으면 걍 퓨어 토크나이즈
    # Pure Text
    if all([media_token not in prompt for media_token in media_tokens.keys()]):
        enc_chunk = prompt_chunk + tokenizer(prompt, add_special_tokens=False, **kwargs)["input_ids"]
	
		# dongyong: media 토큰 있으면
    # Multi-Modal Text
    else:
        enc_chunk = prompt_chunk
        # dongyong: media token들을 or 패턴 처리
        pattern = "|".join(map(re.escape, list(media_tokens.keys())))
        # dongyong: media token을 기준으로 split
        chunk_strs = re.split(f"({pattern})", prompt) 
        chunk_strs = [x for x in chunk_strs if len(x) > 0]
        for idx, chunk_str in enumerate(chunk_strs):
						# dongyong: chunk_str가 media_token이라면 image는 65칸 만큼 -1 token 추가
            if chunk_str in media_tokens:
                enc_chunk += [media_tokens[chunk_str]] * media_lengths[chunk_str]
            else:
                tmp_chunk = tokenizer(chunk_str, add_special_tokens=False)["input_ids"]
                # if idx < len(chunk_strs) - 1: # Last chunk should not have eos
                #     tmp_chunk += [tokenizer.eod_id]
                enc_chunk += tmp_chunk
    return enc_chunk

Model

class MplugOwlForConditionalGeneration(MplugOwlPreTrainedModel):
		...

    def __init__(self, config: MplugOwlConfig):
        super().__init__(config)

        self.vision_model = MplugOwlVisionModel(config.vision_config)
				# dongyong: [1, 64, 1024]
        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
        )
        self.abstractor = MplugOwlVisualAbstractorModel(
            config.visual_abstractor_config, config.text_config.hidden_size
        )
        language_model = AutoModelForCausalLM.from_config(config.text_config)
        self.language_model = language_model

        # Initialize weights and apply final processing
        self.post_init()
        self.main_input_name = "input_ids"
        from transformers import GenerationConfig

        self.generation_config = GenerationConfig(
            max_length=512, do_sample=True, top_k=3, pad_token_id=0, unk_token_id=0, bos_token_id=1, eos_token_id=2
        )