# ably.do/hft.py — 215 lines, 7.3 KiB, Python (page metadata from the original paste)

import os
import torch
import torch.nn as nn
from transformers import GPTNeoForCausalLM, Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
from datasets import Dataset
from PIL import Image
import re
import pytesseract
import docx2txt
import PyPDF2
import json
from torch.cuda.amp import autocast
from collections import defaultdict
from huggingface_hub import login
# Runtime configuration; both variables are set before any tokenizer work or
# CUDA allocation happens in this script.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
# Log in to the Hugging Face Hub only if a token is supplied via the
# environment. Never hard-code tokens in source: anyone who can read the file
# can use them (the token previously embedded here should be revoked).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
def free_memory():
    """Release memory held by unreachable objects and PyTorch's CUDA cache.

    Runs a garbage-collection pass first so tensors that are no longer
    referenced are freed, then returns cached CUDA blocks to the driver.
    The CUDA calls are skipped when CUDA is unavailable, because
    ``torch.cuda.ipc_collect()`` would try to initialise CUDA and fail on a
    CPU-only machine.
    """
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
class SourceMapper:
    """Bidirectional mapping between source labels and integer indices.

    Index 0 is reserved for "Unknown" (text without a known source). Known
    sources receive consecutive indices starting at 1, in the order they are
    first added.
    """

    def __init__(self):
        # Plain dict on purpose: unknown labels are resolved via get_idx(),
        # which returns 0 without inserting anything. A defaultdict would
        # insert 0 on item access, after which add_source() would treat the
        # label as already registered and never give it a real index.
        self.source_to_idx = {}
        self.idx_to_source = {0: "Unknown"}
        self.next_idx = 1  # next index to hand out; known sources start at 1

    def add_source(self, source):
        """Register *source* if it is new; empty/falsy labels are ignored.

        Returns:
            The index of *source* (existing or newly assigned), or 0 when
            *source* is falsy.
        """
        if not source:
            return 0
        if source not in self.source_to_idx:
            idx = self.next_idx
            self.source_to_idx[source] = idx
            self.idx_to_source[idx] = source
            self.next_idx += 1
        return self.source_to_idx[source]

    def get_idx(self, source):
        """Return the index of *source*, or 0 ("Unknown") if it was never added."""
        return self.source_to_idx.get(source, 0)

    def get_source(self, idx):
        """Return the label for *idx*, or "Unknown" for unassigned indices."""
        return self.idx_to_source.get(idx, "Unknown")
def load_file_catalog(catalog_path):
    """Decode the UTF-8 JSON file catalog stored at *catalog_path*.

    Returns whatever object the JSON document contains (the rest of the
    script uses it as a filename -> document-label mapping).
    """
    with open(catalog_path, encoding='utf-8') as handle:
        catalog = json.load(handle)
    return catalog
def identify_legal_document(filename, file_catalog):
    """Look up the document label for *filename* in *file_catalog*.

    Files absent from the catalog are labelled "Opracowanie własne"
    (the author's own material).
    """
    fallback_label = "Opracowanie własne"
    return file_catalog.get(filename, fallback_label)
def extract_text_from_file(file_path):
    """Extract plain text from *file_path*, dispatching on its extension.

    Supported (case-insensitive extension):
      * .txt, .md  - read as UTF-8 text;
      * .pdf       - PyPDF2 text layer of every page, concatenated;
      * .doc/.docx - docx2txt;
      * .jpg/.jpeg/.png/.bmp/.tiff - Tesseract OCR via pytesseract.

    Returns:
        The extracted text, or "" for any other extension.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext in ('.txt', '.md'):
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    if ext == '.pdf':
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            # extract_text() may return None for image-only pages.
            return "".join(page.extract_text() or "" for page in reader.pages)
    if ext in ('.doc', '.docx'):
        # NOTE(review): docx2txt reads the zip-based .docx format; a legacy
        # binary .doc file will most likely raise here - confirm and convert
        # such files beforehand.
        return docx2txt.process(file_path)
    if ext in ('.jpg', '.jpeg', '.png', '.bmp', '.tiff'):
        # Image.open loads lazily and keeps the file open; the context
        # manager closes it once OCR is done.
        with Image.open(file_path) as image:
            return pytesseract.image_to_string(image)
    return ""
def prepare_dataset(directory, catalog_path, source_mapper, chunk_size=512):
    """Walk *directory* recursively and turn every readable file into records.

    Files whose name is in the catalog are split on article headings
    ("Art. <n>.") and each article becomes one record tagged with the source
    "<label>, Art. <n>.". A catalogued file with no such headings is cut into
    chunks tagged with the document label itself, so its text is kept. Files
    not in the catalog ("Opracowanie własne") are cut into chunks with
    source_idx 0 ("Unknown").

    Args:
        directory: root folder scanned with os.walk.
        catalog_path: path of the JSON file catalog (filename -> label).
        source_mapper: SourceMapper that receives every new source label.
        chunk_size: characters per chunk for non-article text (default 512).

    Returns:
        A list of {"text": str, "source_idx": int} dicts.

    Raises:
        ValueError: if chunk_size is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    own_work = "Opracowanie własne"  # label identify_legal_document uses for uncatalogued files
    # The capturing group keeps the headings in split() output:
    # [preamble, heading1, body1, heading2, body2, ...].
    article_heading = re.compile(r'(Art\.\s+\d+\.)')
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, dirs, files in os.walk(directory):
        # os.walk order is filesystem-dependent; sorting (dirs in place, so
        # the walk itself follows it) makes source indices reproducible.
        dirs.sort()
        for file in sorted(files):
            text = extract_text_from_file(os.path.join(root, file))
            if not text:
                continue
            doc_type = identify_legal_document(file, file_catalog)
            parts = article_heading.split(text) if doc_type != own_work else [text]
            if len(parts) > 1:
                # Text before the first heading (parts[0]) is not used.
                for heading, body in zip(parts[1::2], parts[2::2]):
                    article_number = heading.strip()
                    source = f"{doc_type}, {article_number}"
                    source_mapper.add_source(source)
                    data.append({
                        "text": f"{article_number} {body.strip()}",
                        "source_idx": source_mapper.get_idx(source),
                    })
                continue
            source_idx = 0
            if doc_type != own_work:
                # Catalogued document without article headings: attribute
                # its chunks to the whole document rather than dropping them.
                source_mapper.add_source(doc_type)
                source_idx = source_mapper.get_idx(doc_type)
            data.extend(
                {"text": text[start:start + chunk_size], "source_idx": source_idx}
                for start in range(0, len(text), chunk_size)
            )
    return data
def tokenize_function(examples):
    """Tokenize a batch of records for causal-LM training.

    Meant for ``Dataset.map(..., batched=True)``; relies on the module-level
    ``tokenizer``. Texts are truncated/padded to 512 tokens. ``labels`` copy
    ``input_ids`` except at padded positions, which are set to -100 so the
    cross-entropy loss ignores them. ``source_idx`` is passed through
    unchanged.
    """
    tokenized = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt"
    )
    labels = tokenized["input_ids"].clone()
    # pad_token is set to eos in this script; without masking, every short
    # sample would train the model to emit long runs of EOS tokens.
    labels[tokenized["attention_mask"] == 0] = -100
    tokenized["labels"] = labels
    tokenized["source_idx"] = examples["source_idx"]
    return tokenized
def custom_collate_fn(batch):
    """Stack a list of tokenized examples into a batch of tensors.

    Each example must have equal-length ``input_ids``, ``attention_mask`` and
    ``labels`` (lists or tensors). ``source_idx`` defaults to 0 ("Unknown"),
    which is a valid row of the source embedding; -1 would be an invalid
    embedding index.

    Returns:
        dict with stacked ``input_ids``, ``attention_mask``, ``labels`` and a
        1-D int64 ``source_idx`` tensor.
    """
    def stack(key):
        # as_tensor avoids the copy/warning torch.tensor() emits for tensors.
        return torch.stack([torch.as_tensor(example[key]) for example in batch])

    source_idx = torch.tensor(
        [example.get("source_idx", 0) for example in batch], dtype=torch.long
    )
    return {
        "input_ids": stack("input_ids"),
        "attention_mask": stack("attention_mask"),
        "labels": stack("labels"),
        "source_idx": source_idx,
    }
class CustomModel(GPTNeoForCausalLM):
    """GPT-Neo causal LM with a learned per-source bias on the logits.

    Every sequence carries a ``source_idx`` (0 = "Unknown"). The index is
    embedded, projected to vocabulary size and added to the logits of every
    position of that sequence. The language-modelling loss is computed on the
    biased logits, so the source embedding and projection receive gradients.
    """

    def __init__(self, config):
        super().__init__(config)
        # Maximum number of distinct source indices; overridable through
        # config.num_sources, defaulting to the original 1000.
        num_sources = getattr(config, "num_sources", 1000)
        self.source_embedding = nn.Embedding(
            num_embeddings=num_sources,
            embedding_dim=config.hidden_size,
            # Row 0 ("Unknown") stays zero, so index 0 only receives the
            # projection's shared bias term below.
            padding_idx=0
        )
        # NOTE(review): sized from config.vocab_size at construction; it is
        # not resized by resize_token_embeddings().
        self.source_proj = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
        """Run GPT-Neo, add the source bias, and compute the shifted LM loss.

        Args:
            input_ids, attention_mask: forwarded to GPTNeoForCausalLM.
            labels: target token ids (-100 = ignored); the loss is computed
                only when given.
            source_idx: per-sequence source indices, presumably shape
                (batch,) as built by the collate function - TODO confirm.
                None skips the bias (e.g. inside generate()).
            **kwargs: forwarded to GPTNeoForCausalLM.forward, except
                ``num_items_in_batch`` (sent by recent Trainer versions; when
                present the summed loss is divided by it) and ``return_dict``.

        Raises:
            ValueError: if a source index is outside the embedding table.
        """
        from transformers.modeling_outputs import CausalLMOutputWithPast

        return_dict = kwargs.pop("return_dict", None)
        if return_dict is None:
            return_dict = self.config.use_return_dict
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        # Labels are withheld from the base model: its internal loss would
        # be computed before the source bias is added, leaving the source
        # layers without gradients. Mixed precision is handled by the
        # Trainer (fp16=True), so no autocast is opened here.
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
            **kwargs
        )
        logits = outputs.logits
        if source_idx is not None:
            source_idx = source_idx.to(logits.device)
            num_sources = self.source_embedding.num_embeddings
            if source_idx.numel() and (
                int(source_idx.min()) < 0 or int(source_idx.max()) >= num_sources
            ):
                raise ValueError(
                    f"source_idx values must lie in [0, {num_sources}), got "
                    f"[{int(source_idx.min())}, {int(source_idx.max())}]"
                )
            source_bias = self.source_proj(self.source_embedding(source_idx))
            # Out-of-place add; the bias is broadcast over the sequence axis.
            logits = logits + source_bias.unsqueeze(1).to(logits.dtype)

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].float()
            shift_labels = labels[..., 1:].to(logits.device)
            loss = nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
                reduction="sum" if num_items_in_batch is not None else "mean",
            )
            if num_items_in_batch is not None:
                loss = loss / num_items_in_batch

        result = CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        if not return_dict:
            return result.to_tuple()
        return result
# --- Data preparation --------------------------------------------------------
source_mapper = SourceMapper()
model_name = "EleutherAI/gpt-neo-1.3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-Neo's tokenizer has no pad token; reuse EOS for fixed-length padding.
tokenizer.pad_token = tokenizer.eos_token
data = prepare_dataset("files", "file_catalog.json", source_mapper)
if not data:
    raise RuntimeError("No text could be extracted from ./files - nothing to train on.")
dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=16)

# --- Model -------------------------------------------------------------------
# Weights are loaded once; the config is read from that model instead of
# loading a second full 1.3B-parameter copy just to access `.config`.
model = CustomModel.from_pretrained(model_name)
config = model.config
# The KV cache is not used with gradient checkpointing during training.
model.config.use_cache = False
model.resize_token_embeddings(len(tokenizer))
model.gradient_checkpointing_enable()

# --- Training ----------------------------------------------------------------
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    # fp16 mixed precision requires CUDA; fall back to fp32 elsewhere
    # instead of failing at argument validation.
    fp16=torch.cuda.is_available(),
    logging_steps=50,
    save_strategy="steps",
    save_steps=500,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    logging_dir='./logs'
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=custom_collate_fn
)
trainer.train()
free_memory()
# Answer generation with a best-guess source label.
def _guess_source_idx(model, token_ids, source_mapper):
    """Return the known source index whose learned bias best favours *token_ids*.

    Heuristic: for source s the model adds bias_s[v] = W[v] . E_s + b[v]
    to the logit of token v (W, b from ``source_proj``, E from
    ``source_embedding``). Summed over the generated tokens, the b term is
    the same for every source, so sources are ranked by
    E_s . sum_t W[token_t]. The score is unnormalised; treat the result as
    a hint, not a citation. Returns 0 when the model has no source layers,
    nothing was generated, or no known sources fit the embedding table.
    """
    embedding = getattr(model, "source_embedding", None)
    projection = getattr(model, "source_proj", None)
    if embedding is None or projection is None or token_ids.numel() == 0:
        return 0
    known = [idx for idx in source_mapper.idx_to_source if 0 < idx < embedding.num_embeddings]
    if not known:
        return 0
    candidates = torch.tensor(known, device=embedding.weight.device)
    token_direction = projection.weight[token_ids.to(projection.weight.device)].float().sum(dim=0)
    scores = embedding(candidates).float() @ token_direction.to(embedding.weight.device)
    return known[int(torch.argmax(scores))]


def generate_answer(question, model, tokenizer, source_mapper, max_length=200):
    """Generate a continuation of *question* and append a best-guess source.

    Args:
        question: prompt text (truncated to 512 tokens).
        model: the trained model; without ``source_embedding``/``source_proj``
            the source falls back to "Opracowanie własne".
        tokenizer: tokenizer matching *model*.
        source_mapper: the SourceMapper used while building the dataset.
        max_length: total length (prompt + generated tokens) for generate().

    Returns:
        "<decoded sequence>\\n\\nŹródło: <source>". The decoded sequence
        includes the prompt. The source is "Opracowanie własne" when no known
        source can be attributed.
    """
    # The model may still be in training mode after Trainer.train();
    # eval() disables dropout for generation.
    model.eval()
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device("cpu")
    encoded = tokenizer(question, return_tensors="pt", truncation=True, max_length=512)
    # After training the model sits on the training device (e.g. a GPU);
    # the prompt tensors must live on the same device.
    inputs = {name: tensor.to(device) for name, tensor in encoded.items()}
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            return_dict_in_generate=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
        # Attribute the source from the generated tokens only, not the prompt.
        source_idx = _guess_source_idx(model, outputs.sequences[0, prompt_length:], source_mapper)
    source = source_mapper.get_source(source_idx) if source_idx else "Opracowanie własne"
    return f"{answer}\n\nŹródło: {source}"
# Example usage: ask one question of the freshly trained model and print the
# generated text together with its source line.
question = "Ile dni urlopu przysługuje pracownikowi?"
answer = generate_answer(
    question,
    model=model,
    tokenizer=tokenizer,
    source_mapper=source_mapper,
)
print(answer)