""" Tools for generating forms based on SQLAlchemy models. """ import inspect from wtforms import fields as wtforms_fields from wtforms import validators from wtforms.form import Form from .fields import QuerySelectField from .fields import QuerySelectMultipleField __all__ = ( "model_fields", "model_form", ) def converts(*args): def _inner(func): func._converter_for = frozenset(args) return func return _inner class ModelConversionError(Exception): def __init__(self, message): Exception.__init__(self, message) class ModelConverterBase: def __init__(self, converters, use_mro=True): self.use_mro = use_mro if not converters: converters = {} for name in dir(self): obj = getattr(self, name) if hasattr(obj, "_converter_for"): for classname in obj._converter_for: converters[classname] = obj self.converters = converters def get_converter(self, column): """ Searches `self.converters` for a converter method with an argument that matches the column's type. """ if self.use_mro: types = inspect.getmro(type(column.type)) else: types = [type(column.type)] # Search by module + name for col_type in types: type_string = f"{col_type.__module__}.{col_type.__name__}" # remove the 'sqlalchemy.' prefix for sqlalchemy <0.7 compatibility if type_string.startswith("sqlalchemy."): type_string = type_string[11:] if type_string in self.converters: return self.converters[type_string] # Search by name for col_type in types: if col_type.__name__ in self.converters: return self.converters[col_type.__name__] raise ModelConversionError( "Could not find field converter for column %s (%r)." % (column.name, types[0]) ) def convert(self, model, mapper, prop, field_args, db_session=None): if not hasattr(prop, "columns") and not hasattr(prop, "direction"): return elif not hasattr(prop, "direction") and len(prop.columns) != 1: raise TypeError( "Do not know how to convert multiple-column properties currently" ) kwargs = { "validators": [], "filters": [], "default": None, "description": prop.doc, } if field_args: kwargs.update(field_args) if kwargs["validators"]: # Copy to prevent modifying nested mutable values of the original kwargs["validators"] = list(kwargs["validators"]) converter = None column = None if not hasattr(prop, "direction"): column = prop.columns[0] # Support sqlalchemy.schema.ColumnDefault, so users can benefit # from setting defaults for fields, e.g.: # field = Column(DateTimeField, default=datetime.utcnow) default = getattr(column, "default", None) if default is not None: # Only actually change default if it has an attribute named # 'arg' that's callable. callable_default = getattr(default, "arg", None) if callable_default is not None: # ColumnDefault(val).arg can be also a plain value default = ( callable_default(None) if callable(callable_default) else callable_default ) kwargs["default"] = default if column.nullable: kwargs["validators"].append(validators.Optional()) else: kwargs["validators"].append(validators.Required()) converter = self.get_converter(column) else: # We have a property with a direction. if db_session is None: raise ModelConversionError( "Cannot convert field %s, need DB session." % prop.key ) foreign_model = prop.mapper.class_ nullable = True for pair in prop.local_remote_pairs: if not pair[0].nullable: nullable = False kwargs.update( { "allow_blank": nullable, "query_factory": lambda: db_session.query(foreign_model).all(), } ) converter = self.converters[prop.direction.name] return converter( model=model, mapper=mapper, prop=prop, column=column, field_args=kwargs ) class ModelConverter(ModelConverterBase): def __init__(self, extra_converters=None, use_mro=True): super().__init__(extra_converters, use_mro=use_mro) @classmethod def _string_common(cls, column, field_args, **extra): if isinstance(column.type.length, int) and column.type.length: field_args["validators"].append(validators.Length(max=column.type.length)) @converts("String") # includes Unicode def conv_String(self, field_args, **extra): self._string_common(field_args=field_args, **extra) return wtforms_fields.StringField(**field_args) @converts("Text", "LargeBinary", "Binary") # includes UnicodeText def conv_Text(self, field_args, **extra): self._string_common(field_args=field_args, **extra) return wtforms_fields.TextAreaField(**field_args) @converts("Boolean", "dialects.mssql.base.BIT") def conv_Boolean(self, field_args, **extra): return wtforms_fields.BooleanField(**field_args) @converts("Date") def conv_Date(self, field_args, **extra): return wtforms_fields.DateField(**field_args) @converts("DateTime") def conv_DateTime(self, field_args, **extra): return wtforms_fields.DateTimeField(**field_args) @converts("Enum") def conv_Enum(self, column, field_args, **extra): field_args["choices"] = [(e, e) for e in column.type.enums] return wtforms_fields.SelectField(**field_args) @converts("Integer") # includes BigInteger and SmallInteger def handle_integer_types(self, column, field_args, **extra): unsigned = getattr(column.type, "unsigned", False) if unsigned: field_args["validators"].append(validators.NumberRange(min=0)) return wtforms_fields.IntegerField(**field_args) @converts("Numeric") # includes DECIMAL, Float/FLOAT, REAL, and DOUBLE def handle_decimal_types(self, column, field_args, **extra): # override default decimal places limit, use database defaults instead field_args.setdefault("places", None) return wtforms_fields.DecimalField(**field_args) @converts("dialects.mysql.types.YEAR", "dialects.mysql.base.YEAR") def conv_MSYear(self, field_args, **extra): field_args["validators"].append(validators.NumberRange(min=1901, max=2155)) return wtforms_fields.StringField(**field_args) @converts("dialects.postgresql.base.INET") def conv_PGInet(self, field_args, **extra): field_args.setdefault("label", "IP Address") field_args["validators"].append(validators.IPAddress()) return wtforms_fields.StringField(**field_args) @converts("dialects.postgresql.base.MACADDR") def conv_PGMacaddr(self, field_args, **extra): field_args.setdefault("label", "MAC Address") field_args["validators"].append(validators.MacAddress()) return wtforms_fields.StringField(**field_args) @converts("dialects.postgresql.base.UUID") def conv_PGUuid(self, field_args, **extra): field_args.setdefault("label", "UUID") field_args["validators"].append(validators.UUID()) return wtforms_fields.StringField(**field_args) @converts("MANYTOONE") def conv_ManyToOne(self, field_args, **extra): return QuerySelectField(**field_args) @converts("MANYTOMANY", "ONETOMANY") def conv_ManyToMany(self, field_args, **extra): return QuerySelectMultipleField(**field_args) def model_fields( model, db_session=None, only=None, exclude=None, field_args=None, converter=None, exclude_pk=False, exclude_fk=False, ): """ Generate a dictionary of fields for a given SQLAlchemy model. See `model_form` docstring for description of parameters. """ mapper = model._sa_class_manager.mapper converter = converter or ModelConverter() field_args = field_args or {} properties = [] for prop in mapper.iterate_properties: if getattr(prop, "columns", None): if exclude_fk and prop.columns[0].foreign_keys: continue elif exclude_pk and prop.columns[0].primary_key: continue properties.append((prop.key, prop)) # ((p.key, p) for p in mapper.iterate_properties) if only: properties = (x for x in properties if x[0] in only) elif exclude: properties = (x for x in properties if x[0] not in exclude) field_dict = {} for name, prop in properties: field = converter.convert(model, mapper, prop, field_args.get(name), db_session) if field is not None: field_dict[name] = field return field_dict def model_form( model, db_session=None, base_class=Form, only=None, exclude=None, field_args=None, converter=None, exclude_pk=True, exclude_fk=True, type_name=None, ): """ Create a wtforms Form for a given SQLAlchemy model class:: from wtalchemy.orm import model_form from myapp.models import User UserForm = model_form(User) :param model: A SQLAlchemy mapped model class. :param db_session: An optional SQLAlchemy Session. :param base_class: Base form class to extend from. Must be a ``wtforms.Form`` subclass. :param only: An optional iterable with the property names that should be included in the form. Only these properties will have fields. :param exclude: An optional iterable with the property names that should be excluded from the form. All other properties will have fields. :param field_args: An optional dictionary of field names mapping to keyword arguments used to construct each field object. :param converter: A converter to generate the fields based on the model properties. If not set, ``ModelConverter`` is used. :param exclude_pk: An optional boolean to force primary key exclusion. :param exclude_fk: An optional boolean to force foreign keys exclusion. :param type_name: An optional string to set returned type name. """ if not hasattr(model, "_sa_class_manager"): raise TypeError("model must be a sqlalchemy mapped model") type_name = type_name or str(model.__name__ + "Form") field_dict = model_fields( model, db_session, only, exclude, field_args, converter, exclude_pk=exclude_pk, exclude_fk=exclude_fk, ) return type(type_name, (base_class,), field_dict)