[IMP] base_partition: add batch method on base

pull/2615/head
nans 2020-09-08 20:17:41 +02:00 committed by hda
parent 5980c20a3d
commit c252a4e4f4
2 changed files with 85 additions and 40 deletions

View File

@ -2,22 +2,24 @@
# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl.html). # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl.html).
import fnmatch import fnmatch
from odoo import fields, models
from odoo import _, fields, models
from odoo.exceptions import UserError
from odoo.osv import expression from odoo.osv import expression
LIKE_COMPARATORS = ( LIKE_COMPARATORS = (
'like', "like",
'ilike', "ilike",
'=like', "=like",
'=ilike', "=ilike",
'not ilike', "not ilike",
'not like', "not like",
) )
class Base(models.AbstractModel): class Base(models.AbstractModel):
_inherit = 'base' _inherit = "base"
def partition(self, accessor): def partition(self, accessor):
"""Returns a dictionary forming a partition of self into a dictionary """Returns a dictionary forming a partition of self into a dictionary
@ -52,6 +54,19 @@ class Base(models.AbstractModel):
return partition return partition
def batch(self, batch_size=None):
"""Yield successive batches of size batch_size, or ."""
if not (batch_size or "_default_batch_size" in dir(self)):
raise UserError(
_(
"Either set up a '_default_batch_size' on the model"
" or provide a batch_size parameter."
)
)
batch_size = batch_size or self._default_batch_size
for i in range(0, len(self), batch_size):
yield self[i : i + batch_size]
def filtered_domain(self, domain): def filtered_domain(self, domain):
"""Backport from standard. """Backport from standard.
""" """
@ -59,11 +74,11 @@ class Base(models.AbstractModel):
return self return self
result = [] result = []
for d in reversed(domain): for d in reversed(domain):
if d == '|': if d == "|":
result.append(result.pop() | result.pop()) result.append(result.pop() | result.pop())
elif d == '!': elif d == "!":
result.append(self - result.pop()) result.append(self - result.pop())
elif d == '&': elif d == "&":
result.append(result.pop() & result.pop()) result.append(result.pop() & result.pop())
elif d == expression.TRUE_LEAF: elif d == expression.TRUE_LEAF:
result.append(self) result.append(self)
@ -71,40 +86,38 @@ class Base(models.AbstractModel):
result.append(self.browse()) result.append(self.browse())
else: else:
(key, comparator, value) = d (key, comparator, value) = d
if key.endswith('.id'): if key.endswith(".id"):
key = key[:-3] key = key[:-3]
if key == 'id': if key == "id":
key = '' key = ""
# determine the field with the final type for values # determine the field with the final type for values
field = None field = None
if key: if key:
model = self.browse() model = self.browse()
for fname in key.split('.'): for fname in key.split("."):
field = model._fields[fname] field = model._fields[fname]
model = model[fname] model = model[fname]
if comparator in LIKE_COMPARATORS: if comparator in LIKE_COMPARATORS:
value_esc = ( value_esc = (
value.replace('_', '?') value.replace("_", "?").replace("%", "*").replace("[", "?")
.replace('%', '*')
.replace('[', '?')
) )
records = self.browse() records = self.browse()
for rec in self: for rec in self:
data = rec.mapped(key) data = rec.mapped(key)
if comparator in ('child_of', 'parent_of'): if comparator in ("child_of", "parent_of"):
records = data.search([(data._parent_name, comparator, value)]) records = data.search([(data._parent_name, comparator, value)])
value = records.ids value = records.ids
comparator = 'in' comparator = "in"
if isinstance(data, models.BaseModel): if isinstance(data, models.BaseModel):
v = value v = value
if isinstance(value, (list, tuple)) and len(value): if isinstance(value, (list, tuple)) and len(value):
v = value[0] v = value[0]
if isinstance(v, str): if isinstance(v, str):
data = data.mapped('display_name') data = data.mapped("display_name")
else: else:
data = data.ids if data else [False] data = data.ids if data else [False]
elif field and field.type in ('date', 'datetime'): elif field and field.type in ("date", "datetime"):
# convert all date and datetime values to datetime # convert all date and datetime values to datetime
normalize = fields.Datetime.to_datetime normalize = fields.Datetime.to_datetime
if isinstance(value, (list, tuple)): if isinstance(value, (list, tuple)):
@ -112,41 +125,43 @@ class Base(models.AbstractModel):
else: else:
value = normalize(value) value = normalize(value)
data = [normalize(d) for d in data] data = [normalize(d) for d in data]
if comparator in ('in', 'not in'): if comparator in ("in", "not in"):
if not (isinstance(value, list) or isinstance(value, tuple)): if not (isinstance(value, list) or isinstance(value, tuple)):
value = [value] value = [value]
if comparator == '=': if comparator == "=":
ok = value in data ok = value in data
elif comparator == 'in': elif comparator == "in":
ok = any(map(lambda x: x in data, value)) ok = any(map(lambda x: x in data, value))
elif comparator == '<': elif comparator == "<":
ok = any(map(lambda x: x is not None and x < value, data)) ok = any(map(lambda x: x is not None and x < value, data))
elif comparator == '>': elif comparator == ">":
ok = any(map(lambda x: x is not None and x > value, data)) ok = any(map(lambda x: x is not None and x > value, data))
elif comparator == '<=': elif comparator == "<=":
ok = any(map(lambda x: x is not None and x <= value, data)) ok = any(map(lambda x: x is not None and x <= value, data))
elif comparator == '>=': elif comparator == ">=":
ok = any(map(lambda x: x is not None and x >= value, data)) ok = any(map(lambda x: x is not None and x >= value, data))
elif comparator in ('!=', '<>'): elif comparator in ("!=", "<>"):
ok = value not in data ok = value not in data
elif comparator == 'not in': elif comparator == "not in":
ok = all(map(lambda x: x not in data, value)) ok = all(map(lambda x: x not in data, value))
elif comparator == 'not ilike': elif comparator == "not ilike":
ok = all(map(lambda x: value.lower() not in x.lower(), data)) ok = all(map(lambda x: value.lower() not in x.lower(), data))
elif comparator == 'ilike': elif comparator == "ilike":
data = [x.lower() for x in data] data = [x.lower() for x in data]
match = fnmatch.filter(data, '*'+(value_esc or '').lower()+'*') match = fnmatch.filter(
data, "*" + (value_esc or "").lower() + "*"
)
ok = bool(match) ok = bool(match)
elif comparator == 'not like': elif comparator == "not like":
ok = all(map(lambda x: value not in x, data)) ok = all(map(lambda x: value not in x, data))
elif comparator == 'like': elif comparator == "like":
ok = bool(fnmatch.filter(data, value and '*'+value_esc+'*')) ok = bool(fnmatch.filter(data, value and "*" + value_esc + "*"))
elif comparator == '=?': elif comparator == "=?":
ok = (value in data) or not value ok = (value in data) or not value
elif comparator == '=like': elif comparator == "=like":
ok = bool(fnmatch.filter(data, value_esc)) ok = bool(fnmatch.filter(data, value_esc))
elif comparator == '=ilike': elif comparator == "=ilike":
data = [x.lower() for x in data] data = [x.lower() for x in data]
ok = bool(fnmatch.filter(data, value and value_esc.lower())) ok = bool(fnmatch.filter(data, value and value_esc.lower()))
else: else:

View File

@ -2,7 +2,9 @@
# License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl). # License AGPL-3.0 or later (http://www.gnu.org/licenses/agpl).
import functools import functools
import math
from odoo.exceptions import UserError
from odoo.tests.common import TransactionCase from odoo.tests.common import TransactionCase
@ -86,6 +88,34 @@ class TestPartition(TransactionCase):
records = functools.reduce(sum, partition.values()) records = functools.reduce(sum, partition.values())
self.assertEqual(self.xyz, records) # we get the same recordset self.assertEqual(self.xyz, records) # we get the same recordset
def test_batch(self):
"""The sum of all batches should be the original recordset;
an empty recordset should return no batch;
without a batch parameter, the model's _default_batch_size should be used.
"""
records = self.xyz
batch_size = 2
assert len(records) # only makes sense with nonempty recordset
batches = list(records.batch(batch_size))
self.assertEqual(len(batches), math.ceil(len(records) / batch_size))
for batch in batches[:-1]:
self.assertEqual(len(batch), batch_size)
last_batch_size = len(records) % batch_size or batch_size
self.assertEqual(len(batches[-1]), last_batch_size)
self.assertEqual(functools.reduce(sum, batches), records)
empty_recordset = records.browse()
no_batches = list(empty_recordset.batch(batch_size))
self.assertEqual(no_batches, [])
with self.assertRaises(UserError):
list(records.batch())
records.__class__._default_batch_size = batch_size
batches_from_default = list(records.batch())
self.assertEqual(batches_from_default, batches)
def test_filtered_domain(self): def test_filtered_domain(self):
"""Initially yo satisfy the coverage tools, this test actually documents """Initially yo satisfy the coverage tools, this test actually documents
a number of pitfalls of filtered_domain and the differences with a search. a number of pitfalls of filtered_domain and the differences with a search.