import os
import json
import re

import torch
import torch.nn as nn
from transformers import AutoTokenizer, Gemma2ForCausalLM, TrainingArguments, Trainer
from datasets import Dataset
from PIL import Image
import pytesseract
import docx2txt
import PyPDF2
from huggingface_hub import login

# Read the Hugging Face token from the environment instead of hard-coding it in
# the source (the HF_TOKEN variable name is a convention chosen here, not a requirement).
login(token=os.environ.get("HF_TOKEN"))

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def load_file_catalog(catalog_path):
    with open(catalog_path, 'r', encoding='utf-8') as file:
        return json.load(file)


def identify_legal_document(filename, file_catalog):
    # Files that are not listed in the catalog get an empty document type.
    return file_catalog.get(filename, "")


# Extract text from the supported file types.
def extract_text_from_file(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()
    if ext in ['.txt', '.md']:
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read()
    elif ext == '.pdf':
        text = ""
        with open(file_path, 'rb') as file:
            reader = PyPDF2.PdfReader(file)
            for page in reader.pages:
                # extract_text() can return None for pages without a text layer.
                text += page.extract_text() or ""
        return text
    elif ext in ['.doc', '.docx']:
        return docx2txt.process(file_path)
    elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
        return pytesseract.image_to_string(Image.open(file_path))
    else:
        return ""


# Data preparation
def prepare_dataset(directory, catalog_path):
    file_catalog = load_file_catalog(catalog_path)
    data = []
    for root, _, files in os.walk(directory):
        for file in files:
            file_path = os.path.join(root, file)
            text = extract_text_from_file(file_path)
            if text:
                # Check whether the file is listed in the catalog.
                doc_type = identify_legal_document(file, file_catalog)
                if doc_type and doc_type != "Opracowanie własne":  # "own elaboration"
                    # Legal acts: split the text into individual articles.
                    articles = re.split(r'(Art\.\s+\d+\.)', text)[1:]
                    for i in range(0, len(articles), 2):
                        if i + 1 < len(articles):
                            article_number = articles[i].strip()
                            article_content = articles[i + 1].strip()
                            data.append({
                                "text": f"{article_number} {article_content}",
                                "source": f"{doc_type}, {article_number}"
                            })
                else:
                    # Regular documents: split into fixed-size character chunks.
                    chunks = [text[i:i + 512] for i in range(0, len(text), 512)]
                    for chunk in chunks:
                        data.append({
                            "text": chunk,
                            "source": ""
                        })
    return data


# Tokenization
def tokenize_function(examples):
    inputs = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)
    inputs["labels"] = inputs["input_ids"].copy()
    # The source embedding expects integer ids, not raw source strings.
    inputs["source"] = [source_to_id[s] for s in examples["source"]]
    return inputs


# Custom model: Gemma 2 with an additional per-source bias over the vocabulary.
class CustomModel(Gemma2ForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # One learned bias vector over the vocabulary per source
        # (at most 1000 distinct sources are assumed).
        self.source_embedding = nn.Embedding(1000, config.vocab_size)

    def forward(self, input_ids=None, attention_mask=None, labels=None, source=None, **kwargs):
        # Run the base model without labels so the loss is computed on the
        # source-adjusted logits below.
        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        if source is not None:
            # Add the per-source bias to the logits at every position.
            source_embeds = self.source_embedding(source)
            outputs.logits = outputs.logits + source_embeds.unsqueeze(1)
        if labels is not None:
            # Standard causal-LM loss: predict token t+1 from tokens up to t.
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            outputs.loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
        return outputs


# Custom Trainer that forwards the source ids to the model.
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        source = inputs.pop("source")
        outputs = model(**inputs, labels=labels, source=source)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss


# Model and tokenizer setup
model_name = "google/gemma-2-2b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = CustomModel.from_pretrained(model_name)

# Dataset preparation
catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path)

# Map every distinct source string to an integer id for the source embedding.
source_to_id = {s: i for i, s in enumerate(sorted({item["source"] for item in data}))}

dataset = Dataset.from_list(data)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
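# The catalog file is assumed to map file names to document-type labels; the
# entries below are a hypothetical illustration of the expected layout of
# file_catalog.json, not real data:
#
#   {
#       "kodeks_pracy.pdf": "Kodeks pracy",
#       "notatka_wewnetrzna.docx": "Opracowanie własne"
#   }
#
# Optional sanity check (a minimal sketch, not required by the pipeline):
# fail early if nothing usable was extracted from the "files" directory.
if len(data) == 0:
    raise ValueError("No training examples were produced; check the 'files' directory and file_catalog.json.")
print(f"Prepared {len(data)} examples from {len(source_to_id)} distinct sources.")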
# Training configuration
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    save_steps=10_000,
    save_total_limit=2,
)

# Trainer initialization
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Train the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./gemma2_finetuned")


# Generate an answer together with a source citation.
def generate_answer(question, model, tokenizer, dataset):
    # Move the inputs to the model's device (the Trainer may have moved the model to GPU).
    inputs = tokenizer(question, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_length=200,
        num_return_sequences=1,
        output_scores=True,
        return_dict_in_generate=True,
    )
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Rough citation heuristic: take the argmax of the final generation step's
    # logits (skipping as many leading ids as there are source-embedding slots)
    # and map that index onto a dataset row. This does not recover the true
    # provenance of the answer; it only picks a deterministic candidate source.
    source_probs = outputs.scores[-1][:, model.source_embedding.num_embeddings:]
    most_likely_source_idx = torch.argmax(source_probs).item()
    most_likely_source = dataset[most_likely_source_idx % len(dataset)]["source"]
    return f"{answer}\n\nŹródło: {most_likely_source}"  # "Źródło" = "Source"


# Example usage ("How many days of leave is an employee entitled to?")
question = "Ile dni urlopu przysługuje pracownikowi?"
answer = generate_answer(question, model, tokenizer, dataset)
print(answer)
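# Optional: a minimal sketch of reloading the fine-tuned model later for
# inference only. It assumes the tokenizer is saved next to the weights, so it
# is stored explicitly here (trainer.save_model above saves the model weights
# and config, but not the tokenizer, because no tokenizer was passed to the Trainer).
tokenizer.save_pretrained("./gemma2_finetuned")
reloaded_tokenizer = AutoTokenizer.from_pretrained("./gemma2_finetuned")
reloaded_model = CustomModel.from_pretrained("./gemma2_finetuned")
print(generate_answer(question, reloaded_model, reloaded_tokenizer, dataset))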