import numpy as np
import torch
from collections import defaultdict as ddict
from copy import deepcopy


def monotonic(vec1, vec2, vec3, ret="bool"):
    # check whether vec3 [interpolation] is monotonic w.r.t. vec1 [baseline]
    # and vec2 [input], i.e. vec3 should lie between vec1 and vec2 in every
    # dimension (whether the baseline is above or below the input there)
increasing_dims = vec1 > vec2 # dims where baseline > input
decreasing_dims = vec1 < vec2 # dims where baseline < input
equal_dims = vec1 == vec2 # dims where baseline == input
vec3_greater_vec1 = vec3 >= vec1
vec3_greater_vec2 = vec3 >= vec2
vec3_lesser_vec1 = vec3 <= vec1
vec3_lesser_vec2 = vec3 <= vec2
vec3_equal_vec1 = vec3 == vec1
vec3_equal_vec2 = vec3 == vec2
# if, for some dim: vec1 > vec2 then vec1 >= vec3 >= vec2
# elif: vec1 < vec2 then vec1 <= vec3 <= vec2
# elif: vec1 == vec2 then vec1 == vec3 == vec2
    valid = (
        (increasing_dims & vec3_lesser_vec1 & vec3_greater_vec2)
        | (decreasing_dims & vec3_greater_vec1 & vec3_lesser_vec2)
        | (equal_dims & vec3_equal_vec1 & vec3_equal_vec2)
    )
    if ret == "bool":
        return valid.sum() == vec1.shape[0]
    elif ret == "count":
        return valid.sum()
    elif ret == "vec":
        return valid
    else:
        raise ValueError(f"unknown ret mode: {ret}")
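
# Example (hypothetical 1-D vectors; a minimal sketch of the semantics):
#   vec1 = np.array([3.0, 1.0, 2.0])  # baseline
#   vec2 = np.array([1.0, 3.0, 2.0])  # input
#   vec3 = np.array([2.0, 2.0, 2.0])  # candidate interpolation
#   monotonic(vec1, vec2, vec3)                 # -> True: every dim in between
#   monotonic(vec1, vec2, np.array([4.0, 2.0, 2.0]), ret="count")  # -> 2
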
def make_monotonic_vec(vec1, vec2, vec3, steps):
# create a new vec4 from vec3 [anchor] which is monotonic w.r.t. vec1
# [baseline] and vec2 [input]
mono_dims = monotonic(vec1, vec2, vec3, ret="vec")
non_mono_dims = ~mono_dims
if non_mono_dims.sum() == 0:
return vec3
    # replace each non-monotonic dim with a point one 1/steps fraction of the
    # way from the input back toward the baseline, which by construction lies
    # between the two
vec4 = deepcopy(vec3)
vec4[non_mono_dims] = vec2[non_mono_dims] - (1.0 / steps) * (
vec2[non_mono_dims] - vec1[non_mono_dims]
)
return vec4
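
# For instance (hypothetical values): with vec1 = [0.0] (baseline),
# vec2 = [1.0] (input) and a non-monotonic anchor vec3 = [2.0], the offending
# dim is replaced by vec2 - (1 / steps) * (vec2 - vec1); with steps=10 this
# yields vec4 = [0.9], a small step from the input toward the baseline.
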
def distance(A, B):
# return euclidean distance between two points
return np.sqrt(np.sum((A - B) ** 2))


def find_next_wrd(
wrd_idx,
ref_idx,
word_path,
word_features,
adj,
strategy="greedy",
steps=30,
):
if wrd_idx == ref_idx:
# If (for some reason) we do select the ref_idx as the previous
# anchor word, then all further anchor words should be ref_idx
return ref_idx
    anchor_map = {}
cx = adj[wrd_idx].tocoo()
    for j in cx.col:
# we should not consider the anchor word to be the ref_idx
# [baseline] unless forced to.
if j == ref_idx:
continue
        if strategy == "greedy":
            # distance between the candidate anchor and its monotonized
            # projection; zero means the candidate is already monotonic
            monotonic_vec = make_monotonic_vec(
                word_features[ref_idx],
                word_features[wrd_idx],
                word_features[j],
                steps,
            )
            anchor_map[j] = distance(word_features[j], monotonic_vec)
        elif strategy == "maxcount":
            # proxy for the number of non-monotonic dimensions (10000 is an
            # arbitrarily high constant; the offset is a hack to stay
            # agnostic of the word_features dimensionality)
            non_mono_count = 10000 - monotonic(
                word_features[ref_idx],
                word_features[wrd_idx],
                word_features[j],
                ret="count",
            )
            anchor_map[j] = non_mono_count
        elif strategy == "non_monotonic":
            # here we just use the distance between the reference and the
            # proposed anchor word
            anchor_map[j] = distance(word_features[ref_idx], word_features[j])
else:
raise NotImplementedError
    # pick the closest candidate that is not already on the path; if no
    # candidate remains, fall back to the baseline word
    for key, _ in sorted(anchor_map.items(), key=lambda item: item[1]):
        if key not in word_path:
            return key
    return ref_idx
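
# Sketch of the expected inputs (assumptions, not defined in this module:
# `adj` is a scipy.sparse adjacency matrix over the vocabulary, e.g. a kNN
# graph in embedding space, and `word_features` is a (vocab_size, dim) array):
#   from scipy.sparse import csr_matrix
#   rng = np.random.default_rng(0)
#   word_features = rng.normal(size=(5, 4))
#   adj = csr_matrix(np.ones((5, 5)) - np.eye(5))  # toy fully connected graph
#   find_next_wrd(1, 0, [1], word_features, adj)   # -> closest valid anchor
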
def find_word_path(
wrd_idx,
ref_idx,
word_idx_map,
word_features,
adj,
steps=30,
strategy="greedy",
):
    # if wrd_idx is [CLS] or [SEP] then the path is just that token repeated
if ("[CLS]" in word_idx_map and wrd_idx == word_idx_map["[CLS]"]) or (
"[SEP]" in word_idx_map and wrd_idx == word_idx_map["[SEP]"]
):
return [wrd_idx] * (steps + 1)
word_path = [wrd_idx]
last_idx = wrd_idx
    for _ in range(steps):
next_idx = find_next_wrd(
last_idx,
ref_idx,
word_path,
word_features=word_features,
adj=adj,
strategy=strategy,
steps=steps,
)
word_path.append(next_idx)
last_idx = next_idx
return word_path
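
# E.g., with the toy graph sketched above, find_word_path(1, 0, {},
# word_features, adj, steps=4) returns steps + 1 = 5 word indices, starting
# at the input word and padding with ref_idx once every reachable neighbour
# has been visited.
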
def upscale(embs):
    # insert the average embedding between each pair of consecutive vectors
    # in embs, turning a path of n points into one of 2n - 1
embs = np.array(embs)
avg_embs = 0.5 * (embs[0:-1] + embs[1:])
final_embs = np.empty(
(embs.shape[0] + avg_embs.shape[0], embs.shape[1]), dtype=embs.dtype
)
final_embs[::2] = embs
final_embs[1::2] = avg_embs
return final_embs
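
# E.g., upscale([[0.0, 0.0], [2.0, 2.0]]) -> [[0., 0.], [1., 1.], [2., 2.]]:
# the midpoint of each consecutive pair is interleaved into the path.
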
def make_monotonic_path(
word_path_ids,
ref_idx,
word_features,
steps=30,
factor=0,
):
monotonic_embs = [word_features[word_path_ids[0]]]
vec1 = word_features[ref_idx]
for idx in range(len(word_path_ids) - 1):
vec2 = monotonic_embs[-1]
vec3 = word_features[word_path_ids[idx + 1]]
vec4 = make_monotonic_vec(vec1, vec2, vec3, steps)
monotonic_embs.append(vec4)
monotonic_embs.append(vec1)
    # reverse the list so that the baseline comes first and the input word last
monotonic_embs.reverse()
final_embs = monotonic_embs
# do upscaling for factor number of times
for _ in range(factor):
final_embs = upscale(final_embs)
    # verify that every consecutive pair of points is monotonic w.r.t. the
    # input word (the last element after the reversal above)
    for i in range(len(final_embs) - 1):
        assert monotonic(
            final_embs[-1], final_embs[i], final_embs[i + 1], ret="bool"
        )
return final_embs
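

if __name__ == "__main__":
    # End-to-end sketch on synthetic data (assumptions, for illustration only:
    # a 5-word vocabulary with 4-d embeddings and a fully connected graph;
    # real use would pass an embedding table and a kNN adjacency matrix)
    from scipy.sparse import csr_matrix

    rng = np.random.default_rng(0)
    word_features = rng.normal(size=(5, 4))
    adj = csr_matrix(np.ones((5, 5)) - np.eye(5))

    path = find_word_path(
        wrd_idx=1,
        ref_idx=0,
        word_idx_map={},  # toy vocab without [CLS]/[SEP]
        word_features=word_features,
        adj=adj,
        steps=4,
    )
    embs = make_monotonic_path(
        path, ref_idx=0, word_features=word_features, steps=4, factor=1
    )
    print("word path:", path)                      # 5 indices, input word first
    print("num interpolation points:", len(embs))  # 6 anchors upscaled to 11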