"""Tests for shorthand parsing, nested-dict updates and encoding inference."""
import types

import numpy as np
import pandas as pd
import pytest

import altair as alt
from .. import parse_shorthand, update_nested, infer_encoding_types
from ..core import infer_dtype


# Source of a minimal stand-in "channels" module. The ``channels`` fixture
# exec's it into a fresh module object so ``infer_encoding_types`` can be
# exercised against a small, fixed set of field/value channel classes for
# the x, y and strokeWidth encodings.
FAKE_CHANNELS_MODULE = '''
"""Fake channels module for utility tests."""

from altair.utils import schemapi


class FieldChannel(object):
    def __init__(self, shorthand, **kwargs):
        kwargs['shorthand'] = shorthand
        return super(FieldChannel, self).__init__(**kwargs)


class ValueChannel(object):
    def __init__(self, value, **kwargs):
        kwargs['value'] = value
        return super(ValueChannel, self).__init__(**kwargs)


class X(FieldChannel, schemapi.SchemaBase):
    _schema = {}
    _encoding_name = "x"


class XValue(ValueChannel, schemapi.SchemaBase):
    _schema = {}
    _encoding_name = "x"


class Y(FieldChannel, schemapi.SchemaBase):
    _schema = {}
    _encoding_name = "y"


class YValue(ValueChannel, schemapi.SchemaBase):
    _schema = {}
    _encoding_name = "y"


class StrokeWidth(FieldChannel, schemapi.SchemaBase):
    _schema = {}
    _encoding_name = "strokeWidth"


class StrokeWidthValue(ValueChannel, schemapi.SchemaBase):
    _schema = {}
    _encoding_name = "strokeWidth"
'''


@pytest.mark.parametrize(
    "value,expected_type",
    [
        ([1, 2, 3], "integer"),
        ([1.0, 2.0, 3.0], "floating"),
        ([1, 2.0, 3], "mixed-integer-float"),
        (["a", "b", "c"], "string"),
        (["a", "b", np.nan], "mixed"),
    ],
)
def test_infer_dtype(value, expected_type):
    """infer_dtype labels a plain list according to its element types."""
    assert infer_dtype(value) == expected_type


def test_parse_shorthand():
    """parse_shorthand splits field, type, aggregate and timeUnit without data."""

    def check(s, **kwargs):
        assert parse_shorthand(s) == kwargs

    check("")

    # Fields alone
    check("foobar", field="foobar")
    check("blah:(fd ", field="blah:(fd ")

    # Fields with type
    check("foobar:quantitative", type="quantitative", field="foobar")
    check("foobar:nominal", type="nominal", field="foobar")
    check("foobar:ordinal", type="ordinal", field="foobar")
    check("foobar:temporal", type="temporal", field="foobar")
    check("foobar:geojson", type="geojson", field="foobar")

    check("foobar:Q", type="quantitative", field="foobar")
    check("foobar:N", type="nominal", field="foobar")
    check("foobar:O", type="ordinal", field="foobar")
    check("foobar:T", type="temporal", field="foobar")
    check("foobar:G", type="geojson", field="foobar")

    # Fields with aggregate and/or type
    check("average(foobar)", field="foobar", aggregate="average")
    check("min(foobar):temporal", type="temporal", field="foobar", aggregate="min")
    check("sum(foobar):Q", type="quantitative", field="foobar", aggregate="sum")

    # check that invalid arguments are not split-out
    check("invalid(blah)", field="invalid(blah)")
    check("blah:invalid", field="blah:invalid")
    check("invalid(blah):invalid", field="invalid(blah):invalid")

    # check parsing in presence of strange characters
    check(
        "average(a b:(c\nd):Q",
        aggregate="average",
        field="a b:(c\nd",
        type="quantitative",
    )

    # special case: count doesn't need an argument
    check("count()", aggregate="count", type="quantitative")
    check("count():O", aggregate="count", type="ordinal")

    # time units:
    check("month(x)", field="x", timeUnit="month", type="temporal")
    check("year(foo):O", field="foo", timeUnit="year", type="ordinal")
    check("date(date):quantitative", field="date", timeUnit="date", type="quantitative")
    check(
        "yearmonthdate(field)", field="field", timeUnit="yearmonthdate", type="temporal"
    )


def test_parse_shorthand_with_data():
    """With a DataFrame, an omitted type is filled in from the column's values."""

    def check(s, data, **kwargs):
        assert parse_shorthand(s, data) == kwargs

    data = pd.DataFrame(
        {
            "x": [1, 2, 3, 4, 5],
            "y": ["A", "B", "C", "D", "E"],
            "z": pd.date_range("2018-01-01", periods=5, freq="D"),
            "t": pd.date_range("2018-01-01", periods=5, freq="D").tz_localize("UTC"),
        }
    )

    check("x", data, field="x", type="quantitative")
    check("y", data, field="y", type="nominal")
    check("z", data, field="z", type="temporal")
    check("t", data, field="t", type="temporal")
    check("count(x)", data, field="x", aggregate="count", type="quantitative")
    check("count()", data, aggregate="count", type="quantitative")
    check("month(z)", data, timeUnit="month", field="z", type="temporal")
    check("month(t)", data, timeUnit="month", field="t", type="temporal")


def test_parse_shorthand_all_aggregates():
    """Every AggregateOp enum value in the bundled schema parses as an aggregate."""
    aggregates = alt.Root._schema["definitions"]["AggregateOp"]["enum"]
    for aggregate in aggregates:
        shorthand = "{aggregate}(field):Q".format(aggregate=aggregate)
        assert parse_shorthand(shorthand) == {
            "aggregate": aggregate,
            "field": "field",
            "type": "quantitative",
        }


def test_parse_shorthand_all_timeunits():
    """Every local/UTC, single/multi TimeUnit enum value parses as a timeUnit."""
    timeUnits = []
    for loc in ["Local", "Utc"]:
        for typ in ["Single", "Multi"]:
            defn = loc + typ + "TimeUnit"
            timeUnits.extend(alt.Root._schema["definitions"][defn]["enum"])
    for timeUnit in timeUnits:
        shorthand = "{timeUnit}(field):Q".format(timeUnit=timeUnit)
        assert parse_shorthand(shorthand) == {
            "timeUnit": timeUnit,
            "field": "field",
            "type": "quantitative",
        }


def test_parse_shorthand_window_count():
    """With only window-op parsing enabled, ``count()`` yields just an ``op``."""
    shorthand = "count()"
    dct = parse_shorthand(
        shorthand,
        parse_aggregates=False,
        parse_window_ops=True,
        parse_timeunits=False,
        parse_types=False,
    )
    assert dct == {"op": "count"}


def test_parse_shorthand_all_window_ops():
    """Window-only ops and aggregate ops both parse as ``op`` in window mode."""
    window_ops = alt.Root._schema["definitions"]["WindowOnlyOp"]["enum"]
    aggregates = alt.Root._schema["definitions"]["AggregateOp"]["enum"]
    for op in window_ops + aggregates:
        shorthand = "{op}(field)".format(op=op)
        dct = parse_shorthand(
            shorthand,
            parse_aggregates=False,
            parse_window_ops=True,
            parse_timeunits=False,
            parse_types=False,
        )
        assert dct == {"field": "field", "op": op}


def test_update_nested():
    """update_nested merges nested dicts, in place or on a copy."""
    original = {"x": {"b": {"foo": 2}, "c": 4}}
    update = {"x": {"b": {"foo": 5}, "d": 6}, "y": 40}

    output = update_nested(original, update, copy=True)
    assert output is not original
    assert output == {"x": {"b": {"foo": 5}, "c": 4, "d": 6}, "y": 40}
    # copy=True must leave the input untouched at every nesting level: a
    # shallow copy would pass the identity check above while still merging
    # into original["x"] in place.
    assert original == {"x": {"b": {"foo": 2}, "c": 4}}

    # Without copy the merge happens in place and the same object is returned.
    output2 = update_nested(original, update)
    assert output2 is original
    assert output == output2


@pytest.fixture
def channels():
    """Return a module object populated by exec'ing FAKE_CHANNELS_MODULE."""
    module = types.ModuleType("channels")
    exec(FAKE_CHANNELS_MODULE, module.__dict__)
    return module


def _getargs(*args, **kwargs):
    """Return the call's positional and keyword arguments as ``(args, kwargs)``."""
    return args, kwargs


def test_infer_encoding_types(channels):
    """Positional channels, keyword shorthands and mixes all infer the same encodings."""
    expected = dict(
        x=channels.X("xval"),
        y=channels.YValue("yval"),
        strokeWidth=channels.StrokeWidthValue(value=4),
    )

    # All positional args
    args, kwds = _getargs(
        channels.X("xval"), channels.YValue("yval"), channels.StrokeWidthValue(4)
    )
    assert infer_encoding_types(args, kwds, channels) == expected

    # All keyword args
    args, kwds = _getargs(x="xval", y=alt.value("yval"), strokeWidth=alt.value(4))
    assert infer_encoding_types(args, kwds, channels) == expected

    # Mixed positional & keyword
    args, kwds = _getargs(
        channels.X("xval"), channels.YValue("yval"), strokeWidth=alt.value(4)
    )
    assert infer_encoding_types(args, kwds, channels) == expected


def test_infer_encoding_types_with_condition(channels):
    """alt.condition values become channel objects carrying a nested condition."""
    args, kwds = _getargs(
        x=alt.condition("pred1", alt.value(1), alt.value(2)),
        y=alt.condition("pred2", alt.value(1), "yval"),
        strokeWidth=alt.condition("pred3", "sval", alt.value(2)),
    )
    expected = dict(
        x=channels.XValue(2, condition=channels.XValue(1, test="pred1")),
        y=channels.Y("yval", condition=channels.YValue(1, test="pred2")),
        strokeWidth=channels.StrokeWidthValue(
            2, condition=channels.StrokeWidth("sval", test="pred3")
        ),
    )
    assert infer_encoding_types(args, kwds, channels) == expected