Commit 3490be6

add search code
1 parent fee20eb commit 3490be6

6 files changed: 177 additions, 26 deletions

code/eda.py

Lines changed: 2 additions & 0 deletions
@@ -9,6 +9,8 @@
 
 data = pd.read_json(out_name)
 
+data = data.sample(n=10000)
+
 img_repr = data['image_repr'].tolist()
 img_repr_random = data['image_repr'].tolist()
 shuffle(img_repr_random)
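The only functional change here is the `data.sample(n=10000)` call, which subsamples the stored representations so the EDA runs on a manageable slice. A minimal sketch of the same idea, with a hypothetical `random_state` (not in the commit) so repeated runs see the same rows:

import pandas as pd

# Hypothetical variant of the commit's sampling: fixing the seed makes
# the 10k-row subsample reproducible across EDA runs.
data = pd.read_json("../output/test_representations.json")
data = data.sample(n=10000, random_state=0)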

code/model_triplet.py

Lines changed: 4 additions & 2 deletions
@@ -98,12 +98,14 @@ def cap_sequences(list_sequences, max_len, append):
     return capped
 
 
-def read_img(path):
+def read_img(path, preprocess=True):
     img = cv2.imread(path)
     if img is None or img.size<10:
         img = np.zeros((222, 171))
     img = cv2.resize(img, (171, 222))
-    return preprocess_input(img)
+    if preprocess:
+        img = preprocess_input(img)
+    return img
 
 
 def gen(list_images, list_captions, batch_size=16, aug=False):
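The new `preprocess` flag lets callers reuse `read_img` for display as well as inference: training keeps the default model-space preprocessing, while the plotting scripts below ask for raw resized pixels. A hedged usage sketch (the image path is hypothetical):

from model_triplet import read_img

model_input = read_img("img.jpg")                    # preprocessed for the network
display_img = read_img("img.jpg", preprocess=False)  # raw resized BGR pixels

# cv2 loads channels as BGR; flip to RGB before showing with matplotlib,
# exactly as the search scripts below do:
display_img = display_img[:, :, ::-1]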

code/predict_model_triplet.py

Lines changed: 27 additions & 24 deletions
@@ -31,37 +31,40 @@ def chunker(seq, size):
 t_model.load_weights(file_path, by_name=True)
 i_model.load_weights(file_path, by_name=True)
 
-target_image_encoding = []
-
-for img_paths in tqdm(chunker(list_images_test, 128), total=len(list_images_test)//128):
-    images = np.array([read_img(file_path) for file_path in img_paths])
-    e = i_model.predict(images)
-    target_image_encoding += e.tolist()
-
-target_text_encoding = t_model.predict(np.array(captions_test), verbose=1, batch_size=128)
-
-target_text_encoding = target_text_encoding.tolist()
-
-df = pd.DataFrame({"images": list_images_test, "text": _captions_test, "image_repr": target_image_encoding,
-                   "text_repr": target_text_encoding})
-
-df.to_json(out_name, orient='records')
-
-data = json.load(open(out_name, 'r'))
-json.dump(data, open(out_name, 'w'), indent=4)
-
-# New queries
-
-out_name = "../output/queries_representations.json"
-
-_captions_test = ['blue shirt', 'red dress', 'halloween outfit', 'baggy jeans', 'pokemon']
+# target_image_encoding = []
+#
+# for img_paths in tqdm(chunker(list_images_test, 128), total=len(list_images_test)//128):
+#     images = np.array([read_img(file_path) for file_path in img_paths])
+#     e = i_model.predict(images)
+#     target_image_encoding += e.tolist()
+#
+# target_text_encoding = t_model.predict(np.array(captions_test), verbose=1, batch_size=128)
+#
+# target_text_encoding = target_text_encoding.tolist()
+#
+# df = pd.DataFrame({"images": list_images_test, "text": _captions_test, "image_repr": target_image_encoding,
+#                    "text_repr": target_text_encoding})
+#
+# df.to_json(out_name, orient='records')
+#
+# data = json.load(open(out_name, 'r'))
+# json.dump(data, open(out_name, 'w'), indent=4)
+#
+# # New queries
+#
+# out_name = "../output/queries_representations.json"
+
+_captions_test = ['blue tshirt', 'blue shirt', 'red dress', 'halloween outfit', 'baggy jeans', 'ring',
+                  'Black trousers', 'heart Pendant']
 
 captions_test = [tokenize(x) for x in _captions_test]
 captions_test = map_sentences(captions_test, mapping)
 captions_test = cap_sequences(captions_test, 70, 0)
 
 target_text_encoding = t_model.predict(np.array(captions_test), verbose=1, batch_size=128)
 
+target_text_encoding = target_text_encoding.tolist()
+
 
 df = pd.DataFrame({"text": _captions_test,
                    "text_repr": target_text_encoding})
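The bulk image/text encoding is commented out rather than deleted, so re-running the script only refreshes the small set of query embeddings instead of re-encoding the whole test set. The shown hunk ends at the DataFrame; presumably it is still written out as in the commented block, since search_by_keywords.py below reads ../output/queries_representations.json. A hedged sketch of that assumed continuation:

# Assumed continuation, mirroring the commented-out write above;
# the hunk ends before any write, so this is not shown in the commit.
out_name = "../output/queries_representations.json"
df.to_json(out_name, orient='records')
# Each record would then look like:
# {"text": "blue tshirt", "text_repr": [0.12, -0.03, ...]}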

code/search.py

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+import json
+from random import shuffle
+import pandas as pd
+import numpy as np
+from matplotlib import pyplot
+from sklearn.neighbors import NearestNeighbors
+
+repr_json = "../output/test_representations.json"
+
+data = pd.read_json(repr_json)
+
+data = data.sample(n=1000)
+
+img_repr = data['image_repr'].tolist()
+text_repr = data['text_repr'].tolist()
+
+nn = NearestNeighbors(n_jobs=-1, n_neighbors=1000)
+
+nn.fit(text_repr)
+
+preds = nn.kneighbors(img_repr, return_distance=False).tolist()
+ranks = []
+
+for i, x in enumerate(preds):
+    rank = x.index(i)+1
+    ranks.append(rank)
+
+print("Average rank :", np.mean(ranks))

code/search_by_image.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+import json
+from random import shuffle
+import pandas as pd
+import numpy as np
+from matplotlib import pyplot
+from sklearn.neighbors import NearestNeighbors
+import matplotlib.pyplot as plt
+from model_triplet import read_img
+import cv2
+from uuid import uuid4
+
+
+repr_json = "../output/test_representations.json"
+
+data = pd.read_json(repr_json)
+
+data = data.sample(n=50000)
+
+img_repr = data['image_repr'].tolist()
+img_paths = data['images'].tolist()
+text_repr = data['text_repr'].tolist()
+
+nn = NearestNeighbors(n_jobs=-1, n_neighbors=9)
+
+nn.fit(img_repr)
+
+preds = nn.kneighbors(img_repr[:100], return_distance=False).tolist()
+
+most_similar_images = []
+query_image = []
+
+
+for i, x in enumerate(preds):
+    preds_paths = [img_paths[i] for i in x]
+    query_image.append(preds_paths[0])
+    most_similar_images.append(preds_paths[1:])
+
+for q, similar in zip(query_image, most_similar_images):
+    fig, axes = plt.subplots(3, 3)
+    all_images = [q]+similar
+
+    for idx, img_path in enumerate(all_images):
+        i = idx % 3  # Get subplot row
+        j = idx // 3  # Get subplot column
+        image = read_img(img_path, preprocess=False)
+        image = image[:, :, ::-1]
+        axes[i, j].imshow(image/255)
+        axes[i, j].axis('off')
+        axes[i, j].axis('off')
+        if idx == 0:
+            axes[i, j].set_title('Query Image')
+        else:
+            axes[i, j].set_title('Result Image %s'%i)
+
+    plt.subplots_adjust(wspace=0.2, hspace=0.2)
+    plt.savefig('../output/images/%s.png'%uuid4().hex)
+
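Since the index is both fit on and queried with the same `img_repr`, each query's first neighbor is itself (distance 0); that is why `preds_paths[0]` becomes the query panel and the remaining 8 hits fill out the 3x3 grid. A hedged variant that also keeps the distances, e.g. to report a score per result (the distance handling is an assumption, not in the commit):

# Assumed variant: kneighbors also returns distances when
# return_distance is left at its default of True.
dist, ind = nn.kneighbors(img_repr[:100], n_neighbors=9)
for d_row, i_row in zip(dist, ind):
    # i_row[0] is the query itself; pair each of the 8 results with its distance
    results = [(img_paths[j], d) for j, d in zip(i_row[1:], d_row[1:])]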

code/search_by_keywords.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+import json
+from random import shuffle
+import pandas as pd
+import numpy as np
+from matplotlib import pyplot
+from sklearn.neighbors import NearestNeighbors
+import matplotlib.pyplot as plt
+from model_triplet import read_img
+import cv2
+from uuid import uuid4
+
+
+repr_json = "../output/test_representations.json"
+
+data = pd.read_json(repr_json)
+
+queries_repr_json = "../output/queries_representations.json"
+
+queries_data = pd.read_json(queries_repr_json)
+
+data = data.sample(n=50000)
+
+img_repr = data['image_repr'].tolist()
+img_paths = data['images'].tolist()
+text_repr = queries_data['text_repr'].tolist()
+
+nn = NearestNeighbors(n_jobs=-1, n_neighbors=9)
+
+nn.fit(img_repr)
+
+preds = nn.kneighbors(text_repr, return_distance=False).tolist()
+
+most_similar_images = []
+query_image = []
+
+
+for i, x in enumerate(preds):
+    preds_paths = [img_paths[i] for i in x]
+    most_similar_images.append(preds_paths)
+
+for q, all_images in zip(queries_data['text'], most_similar_images):
+    fig, axes = plt.subplots(3, 3)
+
+    for idx, img_path in enumerate(all_images):
+        i = idx % 3  # Get subplot row
+        j = idx // 3  # Get subplot column
+        image = read_img(img_path, preprocess=False)
+        image = image[:, :, ::-1]
+        axes[i, j].imshow(image/255)
+        axes[i, j].axis('off')
+        axes[i, j].axis('off')
+        axes[i, j].set_title('Result Image %s'%i)
+
+
+    plt.subplots_adjust(wspace=0.2, hspace=0.2)
+    fig.suptitle('Query : %s'%q)
+
+    plt.savefig('../output/queries/%s.png'%uuid4().hex)
+
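Note: the committed line was `fig.title('Query : %s'%q)`, which would raise an AttributeError since matplotlib figures have no `title` method; `fig.suptitle` is the standard call for a figure-level title.

This is the text-to-image direction: the index is fit on 50,000 image embeddings but queried with the 8 text embeddings from queries_representations.json, which works only because the triplet model maps captions and images into a shared space. A minimal end-to-end sketch for one ad-hoc query, assuming `t_model`, `tokenize`, `map_sentences`, `mapping`, and `cap_sequences` are loaded as in predict_model_triplet.py (hypothetical glue code, not in the commit):

# Hypothetical: embed one new query string and fetch its nearest images.
query = 'red dress'
seq = cap_sequences(map_sentences([tokenize(query)], mapping), 70, 0)
q_repr = t_model.predict(np.array(seq))   # shape (1, embedding_dim)

hits = nn.kneighbors(q_repr, return_distance=False)[0]
print([img_paths[j] for j in hits])       # paths of the 9 nearest product images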
