import os
import json
import re

import torch
import torch.nn as nn
import docx2txt
import PyPDF2
import pytesseract
from PIL import Image
from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer, Gemma2ForCausalLM, TrainingArguments, Trainer
from huggingface_hub import login

# Authenticate against the Hugging Face Hub; read the token from the environment
# instead of hard-coding it in the source (HF_TOKEN is an assumed variable name).
login(token=os.environ.get("HF_TOKEN"))


def load_file_catalog(catalog_path):
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def identify_legal_document(filename, file_catalog):
    # Files missing from the catalog get an empty document type
    return file_catalog.get(filename, "")


# Extract text from the supported file types
def extract_text_from_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ['.txt', '.md']:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    elif ext == '.pdf':
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            for page in reader.pages:
                text += page.extract_text()
        return text
    elif ext in ['.doc', '.docx']:
        return docx2txt.process(file_path)
    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
        return pytesseract.image_to_string(Image.open(file_path))
    else:
        return ""


# Data preparation
def prepare_dataset(directory, catalog_path):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            text = extract_text_from_file(file_path)
            if text:
                # Check how the file is classified in the catalog;
                # "Opracowanie własne" ("own elaboration") marks non-statute documents
                doc_type = identify_legal_document(file, file_catalog)
                if doc_type != "Opracowanie własne":
                    # Statutes: split the text into articles ("Art. N.")
                    articles = re.split(r'(Art\.\s+\d+\.)', text)[1:]
                    for i in range(0, len(articles), 2):
                        if i + 1 < len(articles):
                            article_number = articles[i].strip()
                            article_content = articles[i + 1].strip()
                            data.append({
                                "text": f"{article_number} {article_content}",
                                "source": f"{doc_type}, {article_number}"
                            })
                else:
                    # Ordinary documents: split into fixed-size chunks
                    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source": ""
                        })
    return data


# Tokenization
def tokenize_function(examples):
    inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
    # Causal LM objective: labels are the input ids, with padding masked out (-100)
    inputs["labels"] = [
        [(tok if mask == 1 else -100) for tok, mask in zip(ids, attn)]
        for ids, attn in zip(inputs["input_ids"], inputs["attention_mask"])
    ]
    # The embedding layer expects integer source ids, not raw source strings
    inputs["source"] = [source_to_id[s] for s in examples["source"]]
    return inputs


# Custom model: AutoModelForCausalLM is a factory and cannot be subclassed directly,
# so the concrete Gemma-2 class is extended instead
class CustomModel(Gemma2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Assume at most 1000 distinct sources; embeddings live in vocabulary space
        # so they can be added to the logits as a per-source bias
        self.source_embedding = nn.Embedding(1000, config.vocab_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source=None, **kwargs):
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
        if source is not None:
            source_embeds = self.source_embedding(source)
            outputs.logits = outputs.logits + source_embeds.unsqueeze(1)
            if labels is not None:
                # Recompute the loss from the adjusted logits so the source embedding is trained
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                loss_fct = nn.CrossEntropyLoss()
                outputs.loss = loss_fct(
                    shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
                )
        return outputs


# Custom Trainer that forwards the source ids to the model
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) in newer transformers
        labels = inputs.pop("labels")
        source = inputs.pop("source")
        outputs = model(**inputs, labels=labels, source=source)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


# Model and tokenizer
model_name = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = CustomModel.from_pretrained(model_name)

# Dataset preparation
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path)

# Map each distinct source string to an integer id for the source embedding
source_to_id = {}
for example in data:
    if example["source"] not in source_to_id:
        source_to_id[example["source"]] = len(source_to_id)

dataset = DatasetDict({"train": Dataset.from_list(data)})
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
)

# Trainer initialization
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
)

# Train the model
trainer.train()

# Save the model
trainer.save_model("./gemma2_finetuned")
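# The fine-tuned checkpoint saved above can later be reloaded for inference without
# re-running training. A minimal sketch, assuming the same CustomModel class is
# importable and the "./gemma2_finetuned" directory is used unchanged:
#
#   model = CustomModel.from_pretrained("./gemma2_finetuned")
#   model.eval()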
# Generate an answer and attach a source citation
def generate_answer(question, model, tokenizer, source_to_id):
    inputs = tokenizer(question, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=200,
        num_return_sequences=1,
        output_scores=True,
        return_dict_in_generate=True,
    )
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Heuristic: score each known source by the similarity between its
    # vocabulary-space embedding and the logits of the last generated token
    last_scores = outputs.scores[-1]  # shape: (batch, vocab_size)
    num_sources = len(source_to_id)
    source_scores = last_scores @ model.source_embedding.weight[:num_sources].T
    most_likely_source_id = torch.argmax(source_scores).item()
    id_to_source = {idx: src for src, idx in source_to_id.items()}
    most_likely_source = id_to_source.get(most_likely_source_id, "")
    return f"{answer}\n\nŹródło: {most_likely_source}"  # "Źródło" = "Source"


# Example usage ("How many days of leave is an employee entitled to?")
question = "Ile dni urlopu przysługuje pracownikowi?"
answer = generate_answer(question, model, tokenizer, source_to_id)
print(answer)
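# file_catalog.json is assumed to map file names to document descriptions, with the
# special value "Opracowanie własne" ("own elaboration") marking non-statute documents.
# A hypothetical example (file names invented for illustration):
#
#   {
#       "kodeks_pracy.pdf": "Kodeks pracy",
#       "notatki_urlopowe.docx": "Opracowanie własne"
#   }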