Skip to content

Fix for wrong dimension and token alignment #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 81 additions & 11 deletions tfsdg/pipelines/tfsdg_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import logging
from dataclasses import dataclass
from typing import List, Optional, Union
from typing import List, Optional, Union, Dict

import numpy as np
import stanza
Expand Down Expand Up @@ -73,7 +73,15 @@ def __init__(
def preprocess_prompt(self, prompt: str) -> str:
    """Normalise a raw prompt before it is parsed and tokenized.

    The prompt is lower-cased, surrounding whitespace is removed, any
    periods at either end are stripped, and whitespace exposed by removing
    those periods is trimmed as well.

    Args:
        prompt: The user-supplied prompt text.

    Returns:
        The normalised prompt string.
    """
    lowered = prompt.lower()
    # The second ``strip()`` handles inputs such as "a dog ." where removing
    # the period leaves a dangling space behind.
    return lowered.strip().strip(".").strip()

def get_sub_nps(self, tree: Tree, left: int, right: int) -> List[SubNP]:
def get_sub_nps(
self,
tree: Tree,
full_sent: str,
left: int,
right: int,
idx_map: Dict[int, List[int]],
highest_only: bool = False,
) -> List[SubNP]:

if isinstance(tree, str) or len(tree.leaves()) == 1:
return []
Expand All @@ -88,25 +96,81 @@ def get_sub_nps(self, tree: Tree, left: int, right: int) -> List[SubNP]:
if tree.label() == "NP" and n_leaves > 1:
sub_np = SubNP(
text=" ".join(tree.leaves()),
span=Span(left=int(left), right=int(right)),
span=Span(left=int(min(idx_map[left])), right=int(min(idx_map[right]))),
)
sub_nps.append(sub_np)

if highest_only and sub_nps[-1].text != full_sent:
return sub_nps

for i, subtree in enumerate(tree):
sub_nps += self.get_sub_nps(
subtree,
full_sent,
left=left + offset[i],
right=left + offset[i] + n_subtree_leaves[i],
idx_map=idx_map,
)
return sub_nps

def get_all_nps(self, tree: Tree, full_sent: Optional[str] = None) -> AllNPs:
def get_token_alignment_map(
    self, tree: Tree, tokens: Optional[List[str]]
) -> Dict[int, List[int]]:
    """Map each parse-tree leaf index to the token indices that cover it.

    The leaves of ``tree`` and the entries of ``tokens`` are two
    segmentations of the same sentence that need not line up one-to-one:
    a leaf may be spread over several tokens, and one token may span
    several consecutive leaves.

    Args:
        tree: Object exposing ``leaves()``, which returns the words of the
            sentence in order.
        tokens: Tokenizer output for the same sentence. A trailing
            ``"</w>"`` on a token is ignored when comparing text
            (presumably the tokenizer's end-of-word marker). When ``None``
            the identity mapping is returned.

    Returns:
        A dict with one entry per leaf index ``i`` listing the token
        indices aligned with that leaf, plus a final entry keyed by
        ``len(tree.leaves())`` holding the index one past the last consumed
        token, so the map can also be indexed with an exclusive right
        boundary.

    Raises:
        IndexError: If the tokens are exhausted before every leaf has been
            aligned (the two segmentations do not describe the same text).
    """
    leaves = tree.leaves()
    if tokens is None:
        return {i: [i] for i in range(len(leaves) + 1)}

    def _token_at(pos: int) -> str:
        # Fail with context rather than a bare "list index out of range".
        if pos >= len(tokens):
            raise IndexError(
                f"ran out of tokens while aligning leaves {leaves!r} "
                f"with tokens {tokens!r}"
            )
        token = tokens[pos]
        return token[:-4] if token.endswith("</w>") else token

    idx_map: Dict[int, List[int]] = {}
    j = 0
    # Text of earlier leaves that were only a prefix of the current token.
    tree_prev_leaf = ""
    for i, w in enumerate(leaves):
        token = _token_at(j)
        idx_map[i] = [j]
        if token == tree_prev_leaf + w:
            # The token is completed by this leaf: move on to the next token.
            tree_prev_leaf = ""
            j += 1
        elif len(token) < len(w):
            # The leaf is split across several tokens: consume tokens until
            # their concatenation spells the leaf.
            # NOTE: no bound on ``j - i`` is asserted here. Split and merged
            # words can both occur in one sentence, so the drift between the
            # two indices is not limited by the difference in lengths.
            prev = ""
            while prev + token != w:
                prev += token
                j += 1
                token = _token_at(j)
                idx_map[i].append(j)
            j += 1
        else:
            # The token covers more than this leaf: remember the leaf text
            # and stay on the same token for the next leaf.
            tree_prev_leaf += w
    # Keyed by the leaf count (not the loop variable) so an empty tree works.
    idx_map[len(leaves)] = [j]
    return idx_map

def get_all_nps(
self,
tree: Tree,
full_sent: str,
tokens: Optional[List[str]] = None,
highest_only: bool = False,
lowest_only: bool = False,
) -> AllNPs:
start = 0
end = len(tree.leaves())

all_sub_nps = self.get_sub_nps(tree, left=start, right=end)
idx_map = self.get_token_alignment_map(tree=tree, tokens=tokens)

lowest_nps = []
all_sub_nps = self.get_sub_nps(
tree,
full_sent,
left=start,
right=end,
idx_map=idx_map,
highest_only=highest_only,
)

lowest_nps: List[SubNP] = []
for i in range(len(all_sub_nps)):
span = all_sub_nps[i].span
lowest = True
Expand All @@ -118,6 +182,9 @@ def get_all_nps(self, tree: Tree, full_sent: Optional[str] = None) -> AllNPs:
if lowest:
lowest_nps.append(all_sub_nps[i])

if lowest_only:
all_nps = [lowest_np.text for lowest_np in lowest_nps]

all_nps = [all_sub_np.text for all_sub_np in all_sub_nps]
spans = [all_sub_np.span for all_sub_np in all_sub_nps]

Expand Down Expand Up @@ -201,14 +268,14 @@ def _align_sequence(
start, end = span.left + 1, span.right + 1
seg_length = end - start

full_seq[start:end] = seq[1 : 1 + seg_length]
full_seq[:, start:end] = seq[:, 1 : 1 + seg_length]
if zero_out:
full_seq[1:start] = 0
full_seq[end:eos_loc] = 0
full_seq[:, 1:start] = 0
full_seq[:, end:eos_loc] = 0

if replace_pad:
pad_length = len(full_seq) - eos_loc
full_seq[eos_loc:] = seq[1 + seg_length : 1 + seg_length + pad_length]
full_seq[:, eos_loc:] = seq[:, 1 + seg_length : 1 + seg_length + pad_length]

# shape: (768, 77) -> (77, 768)
return full_seq.transpose(0, dim)
Expand Down Expand Up @@ -318,8 +385,11 @@ def __call__(

doc = self.nlp(preprocessed_prompt)
tree = Tree.fromstring(str(doc.sentences[0].constituency))
tokens = self.tokenizer.tokenize(preprocessed_prompt)

all_nps = self.get_all_nps(tree=tree, full_sent=preprocessed_prompt)
all_nps = self.get_all_nps(
tree=tree, full_sent=preprocessed_prompt, tokens=tokens
)
cond_embeddings = self.apply_text_encoder(
struct_attention=struct_attention,
prompt=preprocessed_prompt,
Expand Down