diff --git a/base_partition/models/models.py b/base_partition/models/models.py index ed818ac24..d96613cb3 100644 --- a/base_partition/models/models.py +++ b/base_partition/models/models.py @@ -1,7 +1,19 @@ # © 2020 Acsone (http://www.acsone.eu) # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl.html). -from odoo import models +import fnmatch +from odoo import fields, models +from odoo.osv import expression + +LIKE_COMPARATORS = ( + 'like', + 'ilike', + '=like', + '=ilike', + 'not ilike', + 'not like', +) + class Base(models.AbstractModel): @@ -39,3 +51,109 @@ class Base(models.AbstractModel): partition[key] += record return partition + + def filtered_domain(self, domain): + """Backport from standard. + """ + if not domain: + return self + result = [] + for d in reversed(domain): + if d == '|': + result.append(result.pop() | result.pop()) + elif d == '!': + result.append(self - result.pop()) + elif d == '&': + result.append(result.pop() & result.pop()) + elif d == expression.TRUE_LEAF: + result.append(self) + elif d == expression.FALSE_LEAF: + result.append(self.browse()) + else: + (key, comparator, value) = d + if key.endswith('.id'): + key = key[:-3] + if key == 'id': + key = '' + # determine the field with the final type for values + field = None + if key: + model = self.browse() + for fname in key.split('.'): + field = model._fields[fname] + model = model[fname] + + if comparator in LIKE_COMPARATORS: + value_esc = ( + value.replace('_', '?') + .replace('%', '*') + .replace('[', '?') + ) + records = self.browse() + for rec in self: + data = rec.mapped(key) + if comparator in ('child_of', 'parent_of'): + records = data.search([(data._parent_name, comparator, value)]) + value = records.ids + comparator = 'in' + if isinstance(data, models.BaseModel): + v = value + if isinstance(value, (list, tuple)) and len(value): + v = value[0] + if isinstance(v, str): + data = data.mapped('display_name') + else: + data = data.ids if data else [False] + elif field and field.type in ('date', 'datetime'): + # convert all date and datetime values to datetime + normalize = fields.Datetime.to_datetime + if isinstance(value, (list, tuple)): + value = [normalize(v) for v in value] + else: + value = normalize(value) + data = [normalize(d) for d in data] + if comparator in ('in', 'not in'): + if not (isinstance(value, list) or isinstance(value, tuple)): + value = [value] + + if comparator == '=': + ok = value in data + elif comparator == 'in': + ok = any(map(lambda x: x in data, value)) + elif comparator == '<': + ok = any(map(lambda x: x is not None and x < value, data)) + elif comparator == '>': + ok = any(map(lambda x: x is not None and x > value, data)) + elif comparator == '<=': + ok = any(map(lambda x: x is not None and x <= value, data)) + elif comparator == '>=': + ok = any(map(lambda x: x is not None and x >= value, data)) + elif comparator in ('!=', '<>'): + ok = value not in data + elif comparator == 'not in': + ok = all(map(lambda x: x not in data, value)) + elif comparator == 'not ilike': + ok = all(map(lambda x: value.lower() not in x.lower(), data)) + elif comparator == 'ilike': + data = [x.lower() for x in data] + match = fnmatch.filter(data, '*'+(value_esc or '').lower()+'*') + ok = bool(match) + elif comparator == 'not like': + ok = all(map(lambda x: value not in x, data)) + elif comparator == 'like': + ok = bool(fnmatch.filter(data, value and '*'+value_esc+'*')) + elif comparator == '=?': + ok = (value in data) or not value + elif comparator == '=like': + ok = bool(fnmatch.filter(data, value_esc)) + elif comparator == '=ilike': + data = [x.lower() for x in data] + ok = bool(fnmatch.filter(data, value and value_esc.lower())) + else: + raise ValueError + if ok: + records |= rec + result.append(records) + while len(result) > 1: + result.append(result.pop() & result.pop()) + return result[0] diff --git a/base_partition/tests/test_partition.py b/base_partition/tests/test_partition.py index 942521db7..f8146d6fb 100644 --- a/base_partition/tests/test_partition.py +++ b/base_partition/tests/test_partition.py @@ -2,6 +2,7 @@ # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). import functools + from odoo.tests.common import TransactionCase @@ -14,33 +15,39 @@ class TestPartition(TransactionCase): self.c2 = self.Category.create({"name": "c2"}) self.c3 = self.Category.create({"name": "c3"}) - self.Partner = self.env['res.partner'] + self.Partner = self.env["res.partner"] self.parent1 = self.Partner.create({"name": "parent1"}) self.parent2 = self.Partner.create({"name": "parent2"}) self.child1 = self.Partner.create({"name": "child1"}) self.child2 = self.Partner.create({"name": "child2"}) self.child3 = self.Partner.create({"name": "child3"}) - self.x = self.Partner.create({ - "name": "x", - "customer": True, - "category_id": [(6, 0, [self.c1.id, self.c2.id])], - "child_ids": [(6, 0, [self.child1.id, self.child2.id])], - "parent_id": self.parent1.id, - }) - self.y = self.Partner.create({ - "name": "y", - "customer": False, - "category_id": [(6, 0, [self.c2.id, self.c3.id])], - "child_ids": [(6, 0, [self.child2.id, self.child3.id])], - "parent_id": self.parent2.id, - }) - self.z = self.Partner.create({ - "name": "z", - "customer": False, - "category_id": [(6, 0, [self.c1.id, self.c3.id])], - "child_ids": [(6, 0, [self.child1.id, self.child3.id])], - "parent_id": self.parent2.id, - }) + self.x = self.Partner.create( + { + "name": "x", + "customer": True, + "category_id": [(6, 0, [self.c1.id, self.c2.id])], + "child_ids": [(6, 0, [self.child1.id, self.child2.id])], + "parent_id": self.parent1.id, + } + ) + self.y = self.Partner.create( + { + "name": "y", + "customer": False, + "category_id": [(6, 0, [self.c2.id, self.c3.id])], + "child_ids": [(6, 0, [self.child2.id, self.child3.id])], + "parent_id": self.parent2.id, + } + ) + self.z = self.Partner.create( + { + "name": "z", + "customer": False, + "category_id": [(6, 0, [self.c1.id, self.c3.id])], + "child_ids": [(6, 0, [self.child1.id, self.child3.id])], + "parent_id": self.parent2.id, + } + ) self.xyz = self.x + self.y + self.z def test_partition_many2many(self): @@ -78,3 +85,60 @@ class TestPartition(TransactionCase): records = functools.reduce(sum, partition.values()) self.assertEqual(self.xyz, records) # we get the same recordset + + def test_filtered_domain(self): + """Initially yo satisfy the coverage tools, this test actually documents + a number of pitfalls of filtered_domain and the differences with a search. + Commented examples would cause warnings, and even though these are edge-cases + these behaviours should be known. + """ + + records = self.xyz + empty_recordset = records.browse() + + def filtered_search(domain): + search = self.xyz.search(domain) + return search.filtered(lambda r: r.id in self.xyz.ids) + + self.assertEqual(records, records.filtered_domain([])) + self.assertEqual(empty_recordset, records.filtered_domain([(0, "=", 1)])) + + for field in ["name"]: + for r in self.xyz: + domain = [(field, "=", r[field])] + self.assertEqual(self.xyz.filtered_domain(domain), r) + self.assertEqual(filtered_search(domain), r) + + domain = [(field, "in", r[field])] + self.assertTrue(self.xyz.filtered_domain(domain), r) + with self.assertRaises(ValueError): + filtered_search(domain) + + for field in ["customer"]: + for r in [self.x, self.y | self.z]: + value = r[0][field] + domain = [(field, "=", value)] + self.assertEqual(self.xyz.filtered_domain(domain), r) + self.assertEqual(filtered_search(domain), r) + # domain = [(field, "in", value)] + # self.assertEqual(self.xyz.filtered_domain(domain), r) + # expected_result = r if value else empty_recordset # ! + # self.assertEqual(filtered_search(domain), expected_result) + + for field in ["parent_id"]: + for r in [self.x, self.y | self.z]: + domain = [(field, "=", r[0][field].id)] + self.assertEqual(self.xyz.filtered_domain(domain), r) + self.assertEqual(filtered_search(domain), r) + domain = [(field, "in", r[0][field].ids)] + self.assertEqual(self.xyz.filtered_domain(domain), r) + self.assertEqual(filtered_search(domain), r) + + for r in self.xyz: + field = "category_id" + in_domain = [(field, "in", r[field].ids)] + self.assertEqual(self.xyz.filtered_domain(in_domain), self.xyz) + self.assertEqual(self.xyz.search(in_domain), self.xyz) + # eq_domain = [(field, "=", r[field].ids)] + # self.assertEqual(self.xyz.search(eq_domain), self.xyz) + # self.assertEqual(self.xyz.filtered_domain(eq_domain), empty_recordset)