
The implementation is wrong #11

Open
@elvisnava

Description


I would advise anyone against using this implementation until these issues are fixed.

In the sequence alignment function (and the same applies to _expand_sequence), we have:

    def _align_sequence(
            self,
            full_seq: torch.Tensor,
            seq: torch.Tensor,
            span: Span,
            eos_loc: int,
            dim: int = 1,
            zero_out: bool = False,
            replace_pad: bool = False,
    ) -> torch.Tensor:
        # shape: (77, 768) -> (768, 77)
        seq = seq.transpose(0, dim)

        # shape: (77, 768) -> (768, 77)
        full_seq = full_seq.transpose(0, dim)

        start, end = span.left + 1, span.right + 1
        seg_length = end - start

        full_seq[start:end] = seq[1 : 1 + seg_length]
        if zero_out:
            full_seq[1:start] = 0
            full_seq[end:eos_loc] = 0

        if replace_pad:
            pad_length = len(full_seq) - eos_loc
            full_seq[eos_loc:] = seq[1 + seg_length : 1 + seg_length + pad_length]

        # shape: (768, 77) -> (77, 768)
        return full_seq.transpose(0, dim)

This function is supposed to replace the embeddings in full_seq (shape (77, 768)) between start and end with the corresponding ones from seq. However, with the default dim=1, the initial transpose turns full_seq into shape (768, 77), so the assignment full_seq[start:end] slices the embedding dimension rather than the token dimension. seq is indexed over the wrong dimension for the same reason.
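For reference, here is a minimal sketch of what I believe the intended behavior is, assuming 2D inputs of shape (77, 768) with the token dimension first (the name _align_sequence_fixed and the removal of the transposes are mine, not from the repository):

    import torch

    def _align_sequence_fixed(
            full_seq: torch.Tensor,  # (77, 768): tokens along dim 0
            seq: torch.Tensor,       # (77, 768): tokens along dim 0
            span,                    # object with .left / .right word-span bounds
            eos_loc: int,
            zero_out: bool = False,
            replace_pad: bool = False,
    ) -> torch.Tensor:
        # Offset by 1 to skip the BOS token at position 0.
        start, end = span.left + 1, span.right + 1
        seg_length = end - start

        # Slice along the token dimension (dim 0), not the embedding dimension.
        full_seq[start:end] = seq[1 : 1 + seg_length]
        if zero_out:
            full_seq[1:start] = 0
            full_seq[end:eos_loc] = 0

        if replace_pad:
            pad_length = len(full_seq) - eos_loc
            full_seq[eos_loc:] = seq[1 + seg_length : 1 + seg_length + pad_length]

        return full_seq

With the transposes gone, every slice runs over token positions, which is what the span indices refer to.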

Moreover, I believe the span calculation is also incorrect: it operates on words without accounting for the possibility that a single word is broken into multiple tokens. In the paper author's repository, the following function

    import numpy as np

    def get_token_alignment_map(tree, tokens):
        if tokens is None:
            return {i: [i] for i in range(len(tree.leaves()) + 1)}

        def get_token(token):
            return token[:-4] if token.endswith("</w>") else token

        idx_map = {}
        j = 0
        max_offset = np.abs(len(tokens) - len(tree.leaves()))
        mytree_prev_leaf = ""
        for i, w in enumerate(tree.leaves()):
            token = get_token(tokens[j])
            idx_map[i] = [j]
            if token == mytree_prev_leaf + w:
                mytree_prev_leaf = ""
                j += 1
            else:
                if len(token) < len(w):
                    prev = ""
                    while prev + token != w:
                        prev += token
                        j += 1
                        token = get_token(tokens[j])
                        idx_map[i].append(j)
                        # assert j - i <= max_offset
                else:
                    mytree_prev_leaf += w
                    j -= 1
                j += 1
        idx_map[i + 1] = [j]
        return idx_map

is used to map word spans to the corresponding token spans.
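To make the behavior concrete, here is a small usage example (the parse tree and the BPE segmentation of "astronaut" are made up for illustration; the actual split depends on the tokenizer's vocabulary):

    from nltk import Tree

    # Hypothetical prompt in which one word is split into two BPE tokens.
    tree = Tree.fromstring(
        "(S (DT a) (NN photograph) (IN of) (DT an) (NN astronaut))"
    )
    tokens = ["a</w>", "photograph</w>", "of</w>", "an</w>", "astro", "naut</w>"]

    idx_map = get_token_alignment_map(tree, tokens)
    # idx_map == {0: [0], 1: [1], 2: [2], 3: [3], 4: [4, 5], 5: [6]}
    # Word 4 ("astronaut") maps to token indices [4, 5], so its word span
    # expands to the token span (4, 6) instead of a single position.

A word-level span over "astronaut" therefore covers two token positions, which the current span calculation misses.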
