This commit is contained in:
l.gabrysiak 2025-02-25 21:30:01 +01:00
parent 0b00a502db
commit f97eeea435
3 changed files with 124 additions and 63 deletions

View File

@ -2,7 +2,7 @@ USTAWA
z dnia 26 czerwca 1974 r. z dnia 26 czerwca 1974 r.
Kodeks pracy1) Kodeks pracy
(Dz. U. z 2023 r. poz. 1465 oraz z 2024 r. poz. 878, 1222, 1871 i 1965) (Dz. U. z 2023 r. poz. 1465 oraz z 2024 r. poz. 878, 1222, 1871 i 1965)
@ -11,8 +11,6 @@ obowiązuje od dnia 1 stycznia 1975 r.
historia od dnia 16 lutego 1998 r. historia od dnia 16 lutego 1998 r.
Preambuła (uchylona)
DZIAŁ PIERWSZY DZIAŁ PIERWSZY
Przepisy ogólne Przepisy ogólne

119
hft.py
View File

@ -15,7 +15,7 @@ from huggingface_hub import login
# Konfiguracja # Konfiguracja
os.environ['TORCH_USE_CUDA_DSA'] = '1' os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
login(token="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") login(token="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") # Zastąp swoim tokenem HF
class SourceMapper: class SourceMapper:
def __init__(self): def __init__(self):
@ -34,14 +34,19 @@ class SourceMapper:
return self.idx_to_source.get(idx, "Unknown") return self.idx_to_source.get(idx, "Unknown")
def load_file_catalog(catalog_path): def load_file_catalog(catalog_path):
try:
with open(catalog_path, 'r', encoding='utf-8') as file: with open(catalog_path, 'r', encoding='utf-8') as file:
return json.load(file) return json.load(file)
except Exception as e:
print(f"Błąd wczytywania katalogu plików: {str(e)}")
return {}
def identify_legal_document(filename, file_catalog): def identify_legal_document(filename, file_catalog):
base_name = os.path.splitext(filename)[0] base_name = os.path.splitext(filename)[0]
return file_catalog.get(base_name, "Opracowanie własne") return file_catalog.get(base_name, "Opracowanie własne")
def extract_text_from_file(file_path): def extract_text_from_file(file_path):
try:
_, ext = os.path.splitext(file_path) _, ext = os.path.splitext(file_path)
ext = ext.lower() ext = ext.lower()
@ -60,40 +65,91 @@ def extract_text_from_file(file_path):
elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']: elif ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']:
return pytesseract.image_to_string(Image.open(file_path)) return pytesseract.image_to_string(Image.open(file_path))
else: else:
print(f"Nieobsługiwany format pliku: {ext}")
return ""
except Exception as e:
print(f"Błąd ekstrakcji tekstu: {str(e)}")
return "" return ""
def prepare_dataset(directory, catalog_path, source_mapper): def prepare_dataset(directory, catalog_path, source_mapper):
file_catalog = load_file_catalog(catalog_path) file_catalog = load_file_catalog(catalog_path)
data = [] data = []
print(f"\n{'='*50}\nDIAGNOSTYKA DANYCH\n{'='*50}")
if not os.path.exists(directory):
print(f"Brak katalogu: {directory}")
return data
for root, _, files in os.walk(directory): for root, _, files in os.walk(directory):
for file in files: if not files:
file_path = os.path.join(root, file) print(f"Brak plików w katalogu: {root}")
text = extract_text_from_file(file_path)
if not text:
continue continue
for file in files:
file_path = os.path.join(root, file)
print(f"\nPrzetwarzanie pliku: {file_path}")
try:
text = extract_text_from_file(file_path)
if not text.strip():
print("Pominięto - brak tekstu")
continue
print(f"Długość tekstu: {len(text)} znaków")
doc_type = identify_legal_document(file, file_catalog) doc_type = identify_legal_document(file, file_catalog)
print(f"Rozpoznany typ dokumentu: {doc_type}")
if doc_type != "Opracowanie własne": if doc_type != "Opracowanie własne":
articles = re.split(r'(#+\s*Art\.\s*\d+[\.\s]?)', text) articles = re.split(r'(?i)(#+\s*art\.?\s*\d+[\.\s]?)', text)
for i in range(1, len(articles), 2): print(f"Znaleziono {len(articles)} fragmentów")
article_number = re.sub(r'#+\s*', '', articles[i].strip())
article_content = articles[i+1].strip() if i+1 < len(articles) else ""
source = f"{doc_type}, {article_number}"
source_mapper.add_source(source)
if len(articles) < 2:
print("Brak artykułów w dokumencie prawnym!")
continue
for i in range(1, len(articles), 2):
article_number = re.sub(r'#+\s*', '', articles[i].strip(), flags=re.IGNORECASE)
article_content = articles[i+1].strip() if i+1 < len(articles) else ""
if not article_content:
print(f"Pominięto pusty artykuł: {article_number}")
continue
source = f"{doc_type}, {article_number}"
print(f"Dodano artykuł: {source}")
source_mapper.add_source(source)
data.append({ data.append({
"text": f"{article_number} {article_content}", "text": f"{article_number} {article_content}",
"source_idx": source_mapper.get_idx(source) "source_idx": source_mapper.get_idx(source)
}) })
else: else:
chunks = [text[i:i+512] for i in range(0, len(text), 512)] print("Traktowanie jako opracowanie własne")
clean_text = re.sub(r'\s+', ' ', text).strip()
chunks = [clean_text[i:i+512] for i in range(0, len(clean_text), 512)]
chunks = [c for c in chunks if c.strip()]
for chunk in chunks: for chunk in chunks:
data.append({ data.append({
"text": chunk, "text": chunk,
"source_idx": -1 "source_idx": -1
}) })
print(f"Dodano {len(chunks)} chunków")
except Exception as e:
print(f"Błąd podczas przetwarzania pliku: {str(e)}")
continue
print(f"\nPodsumowanie przygotowania danych:")
print(f"Łączna liczba przykładów: {len(data)}")
if data:
print("Przykładowy wpis:")
print(json.dumps(data[0], indent=2, ensure_ascii=False))
else:
print("BRAK DANYCH - sprawdź diagnostykę powyżej")
return data return data
class CustomModel(nn.Module): class CustomModel(nn.Module):
@ -104,8 +160,8 @@ class CustomModel(nn.Module):
def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs): def forward(self, input_ids=None, attention_mask=None, labels=None, source_idx=None, **kwargs):
if source_idx is not None: if source_idx is not None:
source_idx = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1) valid_indices = torch.clamp(source_idx, 0, self.source_embedding.num_embeddings-1)
source_embeds = self.source_embedding(source_idx).unsqueeze(1).expand(-1, input_ids.size(1), -1) source_embeds = self.source_embedding(valid_indices).unsqueeze(1).expand(-1, input_ids.size(1), -1)
inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds inputs_embeds = self.base_model.get_input_embeddings()(input_ids) + source_embeds
return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs) return self.base_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels, **kwargs)
return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs) return self.base_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@ -130,6 +186,11 @@ def main():
# Przygotowanie danych # Przygotowanie danych
catalog_path = "file_catalog.json" catalog_path = "file_catalog.json"
data = prepare_dataset("files", catalog_path, source_mapper) data = prepare_dataset("files", catalog_path, source_mapper)
if not data:
print("\nBrak danych do treningu! Sprawdź pliki w katalogu 'files' i diagnostykę powyżej.")
return
dataset = Dataset.from_list(data) dataset = Dataset.from_list(data)
def tokenize_function(examples): def tokenize_function(examples):
@ -141,13 +202,13 @@ def main():
return_tensors="pt" return_tensors="pt"
) )
return { return {
"input_ids": tokenized["input_ids"], "input_ids": tokenized["input_ids"][0],
"attention_mask": tokenized["attention_mask"], "attention_mask": tokenized["attention_mask"][0],
"labels": tokenized["input_ids"].clone(), "labels": tokenized["input_ids"][0].clone(),
"source_idx": examples["source_idx"] "source_idx": examples["source_idx"]
} }
tokenized_dataset = dataset.map(tokenize_function, batched=True, batch_size=8) tokenized_dataset = dataset.map(tokenize_function, batched=False)
def custom_collate_fn(features): def custom_collate_fn(features):
return { return {
@ -166,16 +227,15 @@ def main():
# Trening # Trening
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir="./results", output_dir="./results",
num_train_epochs=5, num_train_epochs=3,
per_device_train_batch_size=2, per_device_train_batch_size=2,
gradient_accumulation_steps=4, gradient_accumulation_steps=4,
learning_rate=3e-5, learning_rate=2e-5,
fp16=torch.cuda.is_available(), fp16=torch.cuda.is_available(),
logging_steps=10, logging_steps=10,
save_strategy="steps", save_strategy="steps",
save_steps=1000, save_steps=500,
report_to="none", report_to="none",
weight_decay=0.01,
remove_unused_columns=False remove_unused_columns=False
) )
@ -185,13 +245,16 @@ def main():
train_dataset=tokenized_dataset, train_dataset=tokenized_dataset,
data_collator=custom_collate_fn data_collator=custom_collate_fn
) )
print("Rozpoczęcie treningu...") print("\nRozpoczęcie treningu...")
trainer.train() trainer.train()
# Testowanie # Testowanie
def generate_answer(question): def generate_answer(question):
model.eval()
prompt = f"[PYTANIE PRAWNE] {question}"
inputs = tokenizer( inputs = tokenizer(
f"[PYTANIE PRAWNE] {question}", prompt,
return_tensors="pt", return_tensors="pt",
truncation=True, truncation=True,
max_length=512 max_length=512
@ -210,13 +273,13 @@ def main():
) )
answer = tokenizer.decode(outputs[0], skip_special_tokens=True) answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
answer = answer.split("[PYTANIE PRAWNE]")[-1].strip() answer = answer.replace(prompt, "").strip()
sources = set() sources = set()
for match in re.finditer(r'Art\.\s*\d+', answer): for match in re.finditer(r'(?i)art\.?\s*\d+', answer):
article_ref = match.group(0).strip() article_ref = match.group(0).strip()
for idx, source in source_mapper.idx_to_source.items(): for source in source_mapper.idx_to_source.values():
if article_ref in source: if article_ref.lower() in source.lower():
sources.add(source) sources.add(source)
return { return {