https://github.com/X-PLUG/mPLUG-Owl
처음으로는 f_v로 부터 dense image representation을 얻는다. 그러나, 이 dense feature는 fine-grained image 정보를 조각화(fragment)할 수 있고, f_l에 넣기에 긴 시퀀스 때문에 큰 연산을 가져온다. 이러한 현상을 해소하기 위해 visual abstractor module f_k를 차용했다. f_k는 visual information을 몇몇의 learnable token으로 요약하며, 그렇기 때문에 더 높은 visual representation을 얻을 수 있고 계산 비용도 절약이 가능하다. visual representation은 text query와 함께 병합되어 language모델에 입력된다.
def _tokenize_prompt(prompt, tokenizer, add_BOS=False, media_info={"<image>": 65}, **kwargs):
# dongyong: <image> 토큰 들어오면 -1로 매핑
media_tokens = {k: -int(i + 1) for i, k in enumerate(media_info.keys())}
media_lengths = media_info.copy()
# dongyong: BOS를 쓰기 때문에 맨 앞에 BOS 추가
if add_BOS:
prompt_chunk = [tokenizer.bos_token_id]
else:
prompt_chunk = []
# dongyong: media 토큰 없으면 걍 퓨어 토크나이즈
# Pure Text
if all([media_token not in prompt for media_token in media_tokens.keys()]):
enc_chunk = prompt_chunk + tokenizer(prompt, add_special_tokens=False, **kwargs)["input_ids"]
# dongyong: media 토큰 있으면
# Multi-Modal Text
else:
enc_chunk = prompt_chunk
# dongyong: media token들을 or 패턴 처리
pattern = "|".join(map(re.escape, list(media_tokens.keys())))
# dongyong: media token을 기준으로 split
chunk_strs = re.split(f"({pattern})", prompt)
chunk_strs = [x for x in chunk_strs if len(x) > 0]
for idx, chunk_str in enumerate(chunk_strs):
# dongyong: chunk_str가 media_token이라면 image는 65칸 만큼 -1 token 추가
if chunk_str in media_tokens:
enc_chunk += [media_tokens[chunk_str]] * media_lengths[chunk_str]
else:
tmp_chunk = tokenizer(chunk_str, add_special_tokens=False)["input_ids"]
# if idx < len(chunk_strs) - 1: # Last chunk should not have eos
# tmp_chunk += [tokenizer.eod_id]
enc_chunk += tmp_chunk
return enc_chunk
class MplugOwlForConditionalGeneration(MplugOwlPreTrainedModel):
...
def __init__(self, config: MplugOwlConfig):
super().__init__(config)
self.vision_model = MplugOwlVisionModel(config.vision_config)
# dongyong: [1, 64, 1024]
self.query_tokens = nn.Parameter(
torch.zeros(1, config.num_query_tokens, config.visual_abstractor_config.hidden_size)
)
self.abstractor = MplugOwlVisualAbstractorModel(
config.visual_abstractor_config, config.text_config.hidden_size
)
language_model = AutoModelForCausalLM.from_config(config.text_config)
self.language_model = language_model
# Initialize weights and apply final processing
self.post_init()
self.main_input_name = "input_ids"
from transformers import GenerationConfig
self.generation_config = GenerationConfig(
max_length=512, do_sample=True, top_k=3, pad_token_id=0, unk_token_id=0, bos_token_id=1, eos_token_id=2
)