base_search_fuzzy: black, isort

pull/2530/head
Ernesto Tejeda 2020-03-25 17:09:14 -04:00 committed by Daniel Reis
parent d36af2cc80
commit 37fb022bee
5 changed files with 153 additions and 155 deletions

View File

@ -2,20 +2,17 @@
# Copyright 2016 Serpent Consulting Services Pvt. Ltd. # Copyright 2016 Serpent Consulting Services Pvt. Ltd.
# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl).
{ {
'name': "Fuzzy Search", "name": "Fuzzy Search",
'summary': "Fuzzy search with the PostgreSQL trigram extension", "summary": "Fuzzy search with the PostgreSQL trigram extension",
'category': 'Uncategorized', "category": "Uncategorized",
'version': '12.0.1.0.0', "version": "12.0.1.0.0",
'website': 'https://github.com/OCA/server-tools', "website": "https://github.com/OCA/server-tools",
'author': 'bloopark systems GmbH & Co. KG, ' "author": "bloopark systems GmbH & Co. KG, "
'Eficent, ' "Eficent, "
'Serpent CS, ' "Serpent CS, "
'Odoo Community Association (OCA)', "Odoo Community Association (OCA)",
'license': 'AGPL-3', "license": "AGPL-3",
'depends': ['base'], "depends": ["base"],
'data': [ "data": ["views/trgm_index.xml", "security/ir.model.access.csv",],
'views/trgm_index.xml', "installable": True,
'security/ir.model.access.csv',
],
'installable': True,
} }

View File

@ -7,7 +7,6 @@ import logging
from odoo import _, api, models from odoo import _, api, models
from odoo.osv import expression from odoo.osv import expression
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -18,26 +17,23 @@ def patch_leaf_trgm(method):
left, operator, right = leaf left, operator, right = leaf
table_alias = '"%s"' % (eleaf.generate_alias()) table_alias = '"%s"' % (eleaf.generate_alias())
if operator == '%': if operator == "%":
sql_operator = '%%' sql_operator = "%%"
params = [] params = []
if left in model._fields: if left in model._fields:
column = '%s.%s' % (table_alias, expression._quote(left)) column = "{}.{}".format(table_alias, expression._quote(left))
query = '(%s %s %s)' % ( query = "({} {} {})".format(
column, column,
sql_operator, sql_operator,
model._fields[left].column_format, model._fields[left].column_format,
) )
elif left in models.MAGIC_COLUMNS: elif left in models.MAGIC_COLUMNS:
query = "(%s.\"%s\" %s %%s)" % ( query = '({}."{}" {} %s)'.format(table_alias, left, sql_operator)
table_alias, left, sql_operator)
params = right params = right
else: # Must not happen else: # Must not happen
raise ValueError(_( raise ValueError(_("Invalid field {!r} in domain term {!r}".format(left, leaf)))
"Invalid field %r in domain term %r" % (left, leaf)
))
if left in model._fields: if left in model._fields:
params = str(right) params = str(right)
@ -45,8 +41,8 @@ def patch_leaf_trgm(method):
if isinstance(params, str): if isinstance(params, str):
params = [params] params = [params]
return query, params return query, params
elif operator == 'inselect': elif operator == "inselect":
right = (right[0].replace(' % ', ' %% '), right[1]) right = (right[0].replace(" % ", " %% "), right[1])
eleaf.leaf = (left, operator, right) eleaf.leaf = (left, operator, right)
return method(self, eleaf) return method(self, eleaf)
@ -56,11 +52,10 @@ def patch_leaf_trgm(method):
def patch_generate_order_by(method): def patch_generate_order_by(method):
@api.model @api.model
def decorate_generate_order_by(self, order_spec, query): def decorate_generate_order_by(self, order_spec, query):
if order_spec and order_spec.startswith('similarity('): if order_spec and order_spec.startswith("similarity("):
return ' ORDER BY ' + order_spec return " ORDER BY " + order_spec
return method(self, order_spec, query) return method(self, order_spec, query)
decorate_generate_order_by.__decorated__ = True decorate_generate_order_by.__decorated__ = True
@ -70,22 +65,22 @@ def patch_generate_order_by(method):
class IrModel(models.Model): class IrModel(models.Model):
_inherit = 'ir.model' _inherit = "ir.model"
@api.model_cr @api.model_cr
def _register_hook(self): def _register_hook(self):
# We have to prevent wrapping the function twice to avoid recursion # We have to prevent wrapping the function twice to avoid recursion
# errors # errors
if not hasattr(expression.expression._expression__leaf_to_sql, if not hasattr(expression.expression._expression__leaf_to_sql, "__decorated__"):
'__decorated__'):
expression.expression._expression__leaf_to_sql = patch_leaf_trgm( expression.expression._expression__leaf_to_sql = patch_leaf_trgm(
expression.expression._expression__leaf_to_sql) expression.expression._expression__leaf_to_sql
)
if '%' not in expression.TERM_OPERATORS: if "%" not in expression.TERM_OPERATORS:
expression.TERM_OPERATORS += ('%',) expression.TERM_OPERATORS += ("%",)
if not hasattr(models.BaseModel._generate_order_by, if not hasattr(models.BaseModel._generate_order_by, "__decorated__"):
'__decorated__'):
models.BaseModel._generate_order_by = patch_generate_order_by( models.BaseModel._generate_order_by = patch_generate_order_by(
models.BaseModel._generate_order_by) models.BaseModel._generate_order_by
)
return super(IrModel, self)._register_hook() return super(IrModel, self)._register_hook()

View File

@ -4,10 +4,10 @@
# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl).
import logging import logging
from odoo import _, api, exceptions, fields, models
from psycopg2.extensions import AsIs from psycopg2.extensions import AsIs
from odoo import _, api, exceptions, fields, models
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -15,76 +15,81 @@ class TrgmIndex(models.Model):
"""Model for Trigram Index.""" """Model for Trigram Index."""
_name = 'trgm.index' _name = "trgm.index"
_rec_name = 'field_id' _rec_name = "field_id"
_description = 'Trigram Index' _description = "Trigram Index"
field_id = fields.Many2one( field_id = fields.Many2one(
comodel_name='ir.model.fields', comodel_name="ir.model.fields",
string='Field', string="Field",
required=True, required=True,
help='You can either select a field of type "text" or "char".' help='You can either select a field of type "text" or "char".',
) )
index_name = fields.Char( index_name = fields.Char(
string='Index Name', string="Index Name",
readonly=True, readonly=True,
help='The index name is automatically generated like ' help="The index name is automatically generated like "
'fieldname_indextype_idx. If the index already exists and the ' "fieldname_indextype_idx. If the index already exists and the "
'index is located in the same table then this index is reused. ' "index is located in the same table then this index is reused. "
'If the index is located in another table then a number is added ' "If the index is located in another table then a number is added "
'at the end of the index name.' "at the end of the index name.",
) )
index_type = fields.Selection( index_type = fields.Selection(
selection=[('gin', 'GIN'), ('gist', 'GiST')], selection=[("gin", "GIN"), ("gist", "GiST")],
string='Index Type', string="Index Type",
default='gin', default="gin",
required=True, required=True,
help='Cite from PostgreSQL documentation: "As a rule of thumb, a ' help='Cite from PostgreSQL documentation: "As a rule of thumb, a '
'GIN index is faster to search than a GiST index, but slower to ' "GIN index is faster to search than a GiST index, but slower to "
'build or update; so GIN is better suited for static data and ' "build or update; so GIN is better suited for static data and "
'GiST for often-updated data."' 'GiST for often-updated data."',
) )
@api.model_cr @api.model_cr
def _trgm_extension_exists(self): def _trgm_extension_exists(self):
self.env.cr.execute(""" self.env.cr.execute(
"""
SELECT name, installed_version SELECT name, installed_version
FROM pg_available_extensions FROM pg_available_extensions
WHERE name = 'pg_trgm' WHERE name = 'pg_trgm'
LIMIT 1; LIMIT 1;
""") """
)
extension = self.env.cr.fetchone() extension = self.env.cr.fetchone()
if extension is None: if extension is None:
return 'missing' return "missing"
if extension[1] is None: if extension[1] is None:
return 'uninstalled' return "uninstalled"
return 'installed' return "installed"
@api.model_cr @api.model_cr
def _is_postgres_superuser(self): def _is_postgres_superuser(self):
self.env.cr.execute("SHOW is_superuser;") self.env.cr.execute("SHOW is_superuser;")
superuser = self.env.cr.fetchone() superuser = self.env.cr.fetchone()
return superuser is not None and superuser[0] == 'on' or False return superuser is not None and superuser[0] == "on" or False
@api.model_cr @api.model_cr
def _install_trgm_extension(self): def _install_trgm_extension(self):
extension = self._trgm_extension_exists() extension = self._trgm_extension_exists()
if extension == 'missing': if extension == "missing":
_logger.warning('To use pg_trgm you have to install the ' _logger.warning(
'postgres-contrib module.') "To use pg_trgm you have to install the " "postgres-contrib module."
elif extension == 'uninstalled': )
elif extension == "uninstalled":
if self._is_postgres_superuser(): if self._is_postgres_superuser():
self.env.cr.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;") self.env.cr.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm;")
return True return True
else: else:
_logger.warning('To use pg_trgm you have to create the ' _logger.warning(
'extension pg_trgm in your database or you ' "To use pg_trgm you have to create the "
'have to be the superuser.') "extension pg_trgm in your database or you "
"have to be the superuser."
)
else: else:
return True return True
return False return False
@ -93,8 +98,10 @@ class TrgmIndex(models.Model):
def _auto_init(self): def _auto_init(self):
res = super(TrgmIndex, self)._auto_init() res = super(TrgmIndex, self)._auto_init()
if self._install_trgm_extension(): if self._install_trgm_extension():
_logger.info('The pg_trgm is loaded in the database and the ' _logger.info(
'fuzzy search can be used.') "The pg_trgm is loaded in the database and the "
"fuzzy search can be used."
)
return res return res
@api.model_cr @api.model_cr
@ -103,18 +110,20 @@ class TrgmIndex(models.Model):
new_index_name = index_name + str(inc) new_index_name = index_name + str(inc)
else: else:
new_index_name = index_name new_index_name = index_name
self.env.cr.execute(""" self.env.cr.execute(
"""
SELECT tablename, indexname SELECT tablename, indexname
FROM pg_indexes FROM pg_indexes
WHERE indexname = %(index)s; WHERE indexname = %(index)s;
""", {'index': new_index_name}) """,
{"index": new_index_name},
)
indexes = self.env.cr.fetchone() indexes = self.env.cr.fetchone()
if indexes is not None and indexes[0] == table_name: if indexes is not None and indexes[0] == table_name:
return True, index_name return True, index_name
elif indexes is not None: elif indexes is not None:
return self.get_not_used_index(index_name, table_name, return self.get_not_used_index(index_name, table_name, inc + 1)
inc + 1)
return False, new_index_name return False, new_index_name
@ -123,39 +132,42 @@ class TrgmIndex(models.Model):
self.ensure_one() self.ensure_one()
if not self._install_trgm_extension(): if not self._install_trgm_extension():
raise exceptions.UserError(_( raise exceptions.UserError(
'The pg_trgm extension does not exists or cannot be ' _("The pg_trgm extension does not exists or cannot be " "installed.")
'installed.')) )
table_name = self.env[self.field_id.model_id.model]._table table_name = self.env[self.field_id.model_id.model]._table
column_name = self.field_id.name column_name = self.field_id.name
index_type = self.index_type index_type = self.index_type
index_name = '%s_%s_idx' % (column_name, index_type) index_name = "{}_{}_idx".format(column_name, index_type)
index_exists, index_name = self.get_not_used_index( index_exists, index_name = self.get_not_used_index(index_name, table_name)
index_name, table_name)
if not index_exists: if not index_exists:
self.env.cr.execute(""" self.env.cr.execute(
"""
CREATE INDEX %(index)s CREATE INDEX %(index)s
ON %(table)s ON %(table)s
USING %(indextype)s (%(column)s %(indextype)s_trgm_ops); USING %(indextype)s (%(column)s %(indextype)s_trgm_ops);
""", { """,
'table': AsIs(table_name), {
'index': AsIs(index_name), "table": AsIs(table_name),
'column': AsIs(column_name), "index": AsIs(index_name),
'indextype': AsIs(index_type) "column": AsIs(column_name),
}) "indextype": AsIs(index_type),
},
)
return index_name return index_name
@api.model @api.model
def index_exists(self, model_name, field_name): def index_exists(self, model_name, field_name):
field = self.env['ir.model.fields'].search([ field = self.env["ir.model.fields"].search(
('model', '=', model_name), ('name', '=', field_name)], limit=1) [("model", "=", model_name), ("name", "=", field_name)], limit=1
)
if not field: if not field:
return False return False
trgm_index = self.search([('field_id', '=', field.id)], limit=1) trgm_index = self.search([("field_id", "=", field.id)], limit=1)
return bool(trgm_index) return bool(trgm_index)
@api.model @api.model
@ -167,9 +179,10 @@ class TrgmIndex(models.Model):
@api.multi @api.multi
def unlink(self): def unlink(self):
for rec in self: for rec in self:
self.env.cr.execute(""" self.env.cr.execute(
"""
DROP INDEX IF EXISTS %(index)s; DROP INDEX IF EXISTS %(index)s;
""", { """,
'index': AsIs(rec.index_name), {"index": AsIs(rec.index_name),},
}) )
return super(TrgmIndex, self).unlink() return super(TrgmIndex, self).unlink()

View File

@ -8,22 +8,20 @@ from odoo.tests.common import TransactionCase, at_install, post_install
@at_install(False) @at_install(False)
@post_install(True) @post_install(True)
class QueryGenerationCase(TransactionCase): class QueryGenerationCase(TransactionCase):
def setUp(self): def setUp(self):
super(QueryGenerationCase, self).setUp() super(QueryGenerationCase, self).setUp()
self.ResPartner = self.env['res.partner'] self.ResPartner = self.env["res.partner"]
self.TrgmIndex = self.env['trgm.index'] self.TrgmIndex = self.env["trgm.index"]
self.ResPartnerCategory = self.env['res.partner.category'] self.ResPartnerCategory = self.env["res.partner.category"]
def test_fuzzy_where_generation(self): def test_fuzzy_where_generation(self):
"""Check the generation of the where clause.""" """Check the generation of the where clause."""
# the added fuzzy search operator should be available in the allowed # the added fuzzy search operator should be available in the allowed
# operators # operators
self.assertIn('%', expression.TERM_OPERATORS) self.assertIn("%", expression.TERM_OPERATORS)
# create new query with fuzzy search operator # create new query with fuzzy search operator
query = self.ResPartner._where_calc( query = self.ResPartner._where_calc([("name", "%", "test")], active_test=False)
[('name', '%', 'test')], active_test=False)
from_clause, where_clause, where_clause_params = query.get_sql() from_clause, where_clause, where_clause_params = query.get_sql()
# the % parameter has to be escaped (%%) for the string replation # the % parameter has to be escaped (%%) for the string replation
@ -32,70 +30,65 @@ class QueryGenerationCase(TransactionCase):
# test the right sql query statement creation # test the right sql query statement creation
# now there should be only one '%' # now there should be only one '%'
complete_where = self.env.cr.mogrify( complete_where = self.env.cr.mogrify(
"SELECT FROM %s WHERE %s" % (from_clause, where_clause), "SELECT FROM {} WHERE {}".format(from_clause, where_clause), where_clause_params
where_clause_params) )
self.assertEqual( self.assertEqual(
complete_where, complete_where,
b'SELECT FROM "res_partner" WHERE ' b'SELECT FROM "res_partner" WHERE ' b'("res_partner"."name" % \'test\')',
b'("res_partner"."name" % \'test\')') )
def test_fuzzy_where_generation_translatable(self): def test_fuzzy_where_generation_translatable(self):
"""Check the generation of the where clause for translatable fields.""" """Check the generation of the where clause for translatable fields."""
ctx = {'lang': 'de_DE'} ctx = {"lang": "de_DE"}
# create new query with fuzzy search operator # create new query with fuzzy search operator
query = self.ResPartnerCategory.with_context(ctx)\ query = self.ResPartnerCategory.with_context(ctx)._where_calc(
._where_calc([('name', '%', 'Goschaeftlic')], active_test=False) [("name", "%", "Goschaeftlic")], active_test=False
)
from_clause, where_clause, where_clause_params = query.get_sql() from_clause, where_clause, where_clause_params = query.get_sql()
# the % parameter has to be escaped (%%) for the string replation # the % parameter has to be escaped (%%) for the string replation
self.assertIn("""SELECT id FROM temp_irt_current WHERE name %% %s""", self.assertIn(
where_clause) """SELECT id FROM temp_irt_current WHERE name %% %s""", where_clause
)
complete_where = self.env.cr.mogrify( complete_where = self.env.cr.mogrify(
"SELECT FROM %s WHERE %s" % (from_clause, where_clause), "SELECT FROM {} WHERE {}".format(from_clause, where_clause), where_clause_params
where_clause_params) )
self.assertIn( self.assertIn(
b"""SELECT id FROM temp_irt_current WHERE name % 'Goschaeftlic'""", b"""SELECT id FROM temp_irt_current WHERE name % 'Goschaeftlic'""",
complete_where) complete_where,
)
def test_fuzzy_order_generation(self): def test_fuzzy_order_generation(self):
"""Check the generation of the where clause.""" """Check the generation of the where clause."""
order = "similarity(%s.name, 'test') DESC" % self.ResPartner._table order = "similarity(%s.name, 'test') DESC" % self.ResPartner._table
query = self.ResPartner._where_calc( query = self.ResPartner._where_calc([("name", "%", "test")], active_test=False)
[('name', '%', 'test')], active_test=False)
order_by = self.ResPartner._generate_order_by(order, query) order_by = self.ResPartner._generate_order_by(order, query)
self.assertEqual(' ORDER BY %s' % order, order_by) self.assertEqual(" ORDER BY %s" % order, order_by)
def test_fuzzy_search(self): def test_fuzzy_search(self):
"""Test the fuzzy search itself.""" """Test the fuzzy search itself."""
if self.TrgmIndex._trgm_extension_exists() != 'installed': if self.TrgmIndex._trgm_extension_exists() != "installed":
return return
if not self.TrgmIndex.index_exists('res.partner', 'name'): if not self.TrgmIndex.index_exists("res.partner", "name"):
field_partner_name = self.env.ref('base.field_res_partner__name') field_partner_name = self.env.ref("base.field_res_partner__name")
self.TrgmIndex.create({ self.TrgmIndex.create(
'field_id': field_partner_name.id, {"field_id": field_partner_name.id, "index_type": "gin",}
'index_type': 'gin', )
})
partner1 = self.ResPartner.create({ partner1 = self.ResPartner.create({"name": "John Smith"})
'name': 'John Smith' partner2 = self.ResPartner.create({"name": "John Smizz"})
}) partner3 = self.ResPartner.create({"name": "Linus Torvalds"})
partner2 = self.ResPartner.create(
{'name': 'John Smizz'}
)
partner3 = self.ResPartner.create({
'name': 'Linus Torvalds'
})
res = self.ResPartner.search([('name', '%', 'Jon Smith')]) res = self.ResPartner.search([("name", "%", "Jon Smith")])
self.assertIn(partner1.id, res.ids) self.assertIn(partner1.id, res.ids)
self.assertIn(partner2.id, res.ids) self.assertIn(partner2.id, res.ids)
self.assertNotIn(partner3.id, res.ids) self.assertNotIn(partner3.id, res.ids)
res = self.ResPartner.search([('name', '%', 'Smith John')]) res = self.ResPartner.search([("name", "%", "Smith John")])
self.assertIn(partner1.id, res.ids) self.assertIn(partner1.id, res.ids)
self.assertIn(partner2.id, res.ids) self.assertIn(partner2.id, res.ids)
self.assertNotIn(partner3.id, res.ids) self.assertNotIn(partner3.id, res.ids)

View File

@ -1,6 +1,5 @@
<?xml version="1.0" encoding="utf-8"?> <?xml version="1.0" encoding="utf-8" ?>
<odoo> <odoo>
<record model="ir.ui.view" id="trgm_index_view_form"> <record model="ir.ui.view" id="trgm_index_view_form">
<field name="name">trgm.index.view.form</field> <field name="name">trgm.index.view.form</field>
<field name="model">trgm.index</field> <field name="model">trgm.index</field>
@ -8,27 +7,28 @@
<form string="Trigram Index"> <form string="Trigram Index">
<sheet> <sheet>
<group col="4"> <group col="4">
<field name="field_id" domain="[('ttype', 'in', ['char', 'text'])]"/> <field
<field name="index_name"/> name="field_id"
<field name="index_type"/> domain="[('ttype', 'in', ['char', 'text'])]"
/>
<field name="index_name" />
<field name="index_type" />
</group> </group>
</sheet> </sheet>
</form> </form>
</field> </field>
</record> </record>
<record model="ir.ui.view" id="trgm_index_view_tree"> <record model="ir.ui.view" id="trgm_index_view_tree">
<field name="name">trgm.index.view.tree</field> <field name="name">trgm.index.view.tree</field>
<field name="model">trgm.index</field> <field name="model">trgm.index</field>
<field name="arch" type="xml"> <field name="arch" type="xml">
<tree string="Trigram Index"> <tree string="Trigram Index">
<field name="field_id"/> <field name="field_id" />
<field name="index_name"/> <field name="index_name" />
<field name="index_type"/> <field name="index_type" />
</tree> </tree>
</field> </field>
</record> </record>
<record model="ir.actions.act_window" id="trgm_index_action"> <record model="ir.actions.act_window" id="trgm_index_action">
<field name="name">Trigram Index</field> <field name="name">Trigram Index</field>
<field name="res_model">trgm.index</field> <field name="res_model">trgm.index</field>
@ -36,10 +36,10 @@
<field name="view_mode">tree,form</field> <field name="view_mode">tree,form</field>
<field name="type">ir.actions.act_window</field> <field name="type">ir.actions.act_window</field>
</record> </record>
<menuitem
<menuitem id="trgm_index_menu" id="trgm_index_menu"
parent="base.next_id_9" parent="base.next_id_9"
action="trgm_index_action" action="trgm_index_action"
groups="base.group_no_one"/> groups="base.group_no_one"
/>
</odoo> </odoo>