From c252a4e4f45ab543fd8d2c94bb96ccd4ce37ec0c Mon Sep 17 00:00:00 2001 From: nans Date: Tue, 8 Sep 2020 20:17:41 +0200 Subject: [PATCH] [IMP] base_partition: add batch method on base --- base_partition/models/models.py | 95 +++++++++++++++----------- base_partition/tests/test_partition.py | 30 ++++++++ 2 files changed, 85 insertions(+), 40 deletions(-) diff --git a/base_partition/models/models.py b/base_partition/models/models.py index d96613cb3..8100969cd 100644 --- a/base_partition/models/models.py +++ b/base_partition/models/models.py @@ -2,22 +2,24 @@ # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl.html). import fnmatch -from odoo import fields, models + +from odoo import _, fields, models +from odoo.exceptions import UserError from odoo.osv import expression LIKE_COMPARATORS = ( - 'like', - 'ilike', - '=like', - '=ilike', - 'not ilike', - 'not like', + "like", + "ilike", + "=like", + "=ilike", + "not ilike", + "not like", ) class Base(models.AbstractModel): - _inherit = 'base' + _inherit = "base" def partition(self, accessor): """Returns a dictionary forming a partition of self into a dictionary @@ -52,6 +54,19 @@ class Base(models.AbstractModel): return partition + def batch(self, batch_size=None): + """Yield successive batches of size batch_size, or .""" + if not (batch_size or "_default_batch_size" in dir(self)): + raise UserError( + _( + "Either set up a '_default_batch_size' on the model" + " or provide a batch_size parameter." + ) + ) + batch_size = batch_size or self._default_batch_size + for i in range(0, len(self), batch_size): + yield self[i : i + batch_size] + def filtered_domain(self, domain): """Backport from standard. """ @@ -59,11 +74,11 @@ class Base(models.AbstractModel): return self result = [] for d in reversed(domain): - if d == '|': + if d == "|": result.append(result.pop() | result.pop()) - elif d == '!': + elif d == "!": result.append(self - result.pop()) - elif d == '&': + elif d == "&": result.append(result.pop() & result.pop()) elif d == expression.TRUE_LEAF: result.append(self) @@ -71,40 +86,38 @@ class Base(models.AbstractModel): result.append(self.browse()) else: (key, comparator, value) = d - if key.endswith('.id'): + if key.endswith(".id"): key = key[:-3] - if key == 'id': - key = '' + if key == "id": + key = "" # determine the field with the final type for values field = None if key: model = self.browse() - for fname in key.split('.'): + for fname in key.split("."): field = model._fields[fname] model = model[fname] if comparator in LIKE_COMPARATORS: value_esc = ( - value.replace('_', '?') - .replace('%', '*') - .replace('[', '?') + value.replace("_", "?").replace("%", "*").replace("[", "?") ) records = self.browse() for rec in self: data = rec.mapped(key) - if comparator in ('child_of', 'parent_of'): + if comparator in ("child_of", "parent_of"): records = data.search([(data._parent_name, comparator, value)]) value = records.ids - comparator = 'in' + comparator = "in" if isinstance(data, models.BaseModel): v = value if isinstance(value, (list, tuple)) and len(value): v = value[0] if isinstance(v, str): - data = data.mapped('display_name') + data = data.mapped("display_name") else: data = data.ids if data else [False] - elif field and field.type in ('date', 'datetime'): + elif field and field.type in ("date", "datetime"): # convert all date and datetime values to datetime normalize = fields.Datetime.to_datetime if isinstance(value, (list, tuple)): @@ -112,41 +125,43 @@ class Base(models.AbstractModel): else: value = normalize(value) data = [normalize(d) for d in data] - if comparator in ('in', 'not in'): + if comparator in ("in", "not in"): if not (isinstance(value, list) or isinstance(value, tuple)): value = [value] - if comparator == '=': + if comparator == "=": ok = value in data - elif comparator == 'in': + elif comparator == "in": ok = any(map(lambda x: x in data, value)) - elif comparator == '<': + elif comparator == "<": ok = any(map(lambda x: x is not None and x < value, data)) - elif comparator == '>': + elif comparator == ">": ok = any(map(lambda x: x is not None and x > value, data)) - elif comparator == '<=': + elif comparator == "<=": ok = any(map(lambda x: x is not None and x <= value, data)) - elif comparator == '>=': + elif comparator == ">=": ok = any(map(lambda x: x is not None and x >= value, data)) - elif comparator in ('!=', '<>'): + elif comparator in ("!=", "<>"): ok = value not in data - elif comparator == 'not in': + elif comparator == "not in": ok = all(map(lambda x: x not in data, value)) - elif comparator == 'not ilike': + elif comparator == "not ilike": ok = all(map(lambda x: value.lower() not in x.lower(), data)) - elif comparator == 'ilike': + elif comparator == "ilike": data = [x.lower() for x in data] - match = fnmatch.filter(data, '*'+(value_esc or '').lower()+'*') + match = fnmatch.filter( + data, "*" + (value_esc or "").lower() + "*" + ) ok = bool(match) - elif comparator == 'not like': + elif comparator == "not like": ok = all(map(lambda x: value not in x, data)) - elif comparator == 'like': - ok = bool(fnmatch.filter(data, value and '*'+value_esc+'*')) - elif comparator == '=?': + elif comparator == "like": + ok = bool(fnmatch.filter(data, value and "*" + value_esc + "*")) + elif comparator == "=?": ok = (value in data) or not value - elif comparator == '=like': + elif comparator == "=like": ok = bool(fnmatch.filter(data, value_esc)) - elif comparator == '=ilike': + elif comparator == "=ilike": data = [x.lower() for x in data] ok = bool(fnmatch.filter(data, value and value_esc.lower())) else: diff --git a/base_partition/tests/test_partition.py b/base_partition/tests/test_partition.py index f8146d6fb..352ac1824 100644 --- a/base_partition/tests/test_partition.py +++ b/base_partition/tests/test_partition.py @@ -2,7 +2,9 @@ # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). import functools +import math +from odoo.exceptions import UserError from odoo.tests.common import TransactionCase @@ -86,6 +88,34 @@ class TestPartition(TransactionCase): records = functools.reduce(sum, partition.values()) self.assertEqual(self.xyz, records) # we get the same recordset + def test_batch(self): + """The sum of all batches should be the original recordset; + an empty recordset should return no batch; + without a batch parameter, the model's _default_batch_size should be used. + """ + records = self.xyz + batch_size = 2 + + assert len(records) # only makes sense with nonempty recordset + batches = list(records.batch(batch_size)) + self.assertEqual(len(batches), math.ceil(len(records) / batch_size)) + for batch in batches[:-1]: + self.assertEqual(len(batch), batch_size) + last_batch_size = len(records) % batch_size or batch_size + self.assertEqual(len(batches[-1]), last_batch_size) + self.assertEqual(functools.reduce(sum, batches), records) + + empty_recordset = records.browse() + no_batches = list(empty_recordset.batch(batch_size)) + self.assertEqual(no_batches, []) + + with self.assertRaises(UserError): + list(records.batch()) + + records.__class__._default_batch_size = batch_size + batches_from_default = list(records.batch()) + self.assertEqual(batches_from_default, batches) + def test_filtered_domain(self): """Initially yo satisfy the coverage tools, this test actually documents a number of pitfalls of filtered_domain and the differences with a search.