64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
import csv
|
|
import json
|
|
from math import sqrt
|
|
import re
|
|
|
|
# Minimum length a normalized word must have to be kept in a quote's index
# (compared with `>=` in index()).
WORD_LENGTH_THRESHOLD = 4
|
|
|
|
class RowDecoder(json.JSONDecoder):
    """JSON decoder that turns each decoded row's 'index' value into a set."""

    def decode(self, s):
        """Decode the JSON document *s* into a list of row dicts.

        The document is decoded with the stock decoder, then every row is
        copied with its 'index' entry replaced by a set built from it.
        A row without an 'index' key raises KeyError.
        """
        rows = []
        for record in super().decode(s):
            restored = {**record}
            restored['index'] = set(record['index'])
            rows.append(restored)
        return rows
|
|
|
|
class RowEncoder(json.JSONEncoder):
    """JSON encoder that serializes sets as JSON arrays."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for *obj*.

        Sets are converted to lists (element order follows set iteration
        order). Any other unsupported type is delegated to the base class,
        which raises TypeError.
        """
        if isinstance(obj, set):
            return list(obj)
        # Delegate through super() so `self` is passed along; calling
        # json.JSONEncoder.default(obj) unbound dropped it and produced a
        # misleading "missing positional argument" error.
        return super().default(obj)
|
|
|
|
def keepOnlyAlphaChars(word):
    """Return *word* with every non-alphabetic character removed."""
    letters = filter(str.isalpha, word)
    return ''.join(letters)
|
|
|
|
def index(text):
    """Return the set of indexable words found in *text*.

    The text is split on every whitespace character; each piece is stripped
    of non-alphabetic characters and lower-cased, and only pieces at least
    WORD_LENGTH_THRESHOLD characters long are kept.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    words = re.split(r'\s', text)
    normalized_words = (keepOnlyAlphaChars(word).lower() for word in words)
    return {w for w in normalized_words if len(w) >= WORD_LENGTH_THRESHOLD}
|
|
|
|
def insert(db, row):
    """Append to *db* a record built from the first two fields of *row*.

    row[0] is stored as 'title', row[1] as 'quote', and the word index of
    the quote as 'index'.
    """
    title, quote = row[0], row[1]
    record = {'title': title, 'quote': quote, 'index': index(quote)}
    db.append(record)
|
|
|
|
def build_db(inputCSV):
    """Build the quote database from the CSV file at *inputCSV*.

    The first row is treated as a header and skipped; every following row
    is passed to insert(). Returns the resulting list (empty when the file
    has no data rows).
    """
    db = []
    # newline='' lets the csv module handle line endings itself, so quoted
    # fields containing embedded newlines are parsed correctly.
    with open(inputCSV, 'r', newline='') as file:
        csv_reader = csv.reader(file, delimiter=',')
        # Skip the header row; the default avoids StopIteration on an
        # empty file.
        next(csv_reader, None)
        for row in csv_reader:
            insert(db, row)
    return db
|
|
|
|
def save_db(db, outputJSON):
    """Serialize *db* as JSON into the file *outputJSON* using RowEncoder."""
    with open(outputJSON, 'w') as sink:
        json.dump(db, fp=sink, cls=RowEncoder)
|
|
|
|
def open_db(filePath):
    """Load and return the database stored as JSON in *filePath*.

    Decoding goes through RowDecoder.
    """
    with open(filePath, 'r') as source:
        rows = json.load(source, cls=RowDecoder)
    return rows
|
|
|
|
def scalar(a, b):
    """Return a normalized similarity score between the sets *a* and *b*.

    The score is the number of elements the two sets have in common,
    divided by sqrt(len(a) * len(b)). It is 1.0 for identical non-empty
    sets and 0.0 for disjoint ones.

    If either set is empty the score is defined as 0.0, rather than
    dividing by zero.
    """
    if not a or not b:
        return 0.0
    return len(a.intersection(b)) / sqrt(len(a) * len(b))
|
|
|
|
def find_best_quote(db, user_input):
|
|
indexed_input = index(user_input)
|
|
max_score = None
|
|
for entry in db:
|
|
score = scalar(indexed_input, entry
|