author    | Chayim I. Kirshen <c@kirshen.com> | 2021-11-29 20:07:20 +0200
committer | Chayim I. Kirshen <c@kirshen.com> | 2021-11-29 20:07:20 +0200
commit    | 39fc550251d238cdba7966ff153321ca9e488508 (patch)
tree      | e79360ec70feac7f0ab992813f8b2d43f7c67bab /tests/test_search.py
parent    | a924269502b96dc71339cca3dfb20aaa3899a9d0 (diff)
parent    | 4db85ef574a64a2b230a3ae1ff19c9d04065a114 (diff)
merging master (ck-linkdocs)
Diffstat (limited to 'tests/test_search.py')
-rw-r--r-- | tests/test_search.py | 365
1 file changed, 307 insertions(+), 58 deletions(-)
diff --git a/tests/test_search.py b/tests/test_search.py
index d1fc75f..b65ac8d 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -82,8 +82,8 @@ def createIndex(client, num_docs=100, definition=None):
     try:
         client.create_index(
             (TextField("play", weight=5.0),
-            TextField("txt"),
-            NumericField("chapter")),
+             TextField("txt"),
+             NumericField("chapter")),
             definition=definition,
         )
     except redis.ResponseError:
@@ -320,8 +320,8 @@ def test_stopwords(client):
 def test_filters(client):
     client.ft().create_index(
         (TextField("txt"),
-            NumericField("num"),
-            GeoField("loc"))
+         NumericField("num"),
+         GeoField("loc"))
     )
     client.ft().add_document(
         "doc1",
@@ -379,7 +379,7 @@ def test_payloads_with_no_content(client):
 def test_sort_by(client):
     client.ft().create_index(
         (TextField("txt"),
-            NumericField("num", sortable=True))
+         NumericField("num", sortable=True))
     )
     client.ft().add_document("doc1", txt="foo bar", num=1)
     client.ft().add_document("doc2", txt="foo baz", num=2)
@@ -424,7 +424,7 @@ def test_example(client):
     # Creating the index definition and schema
     client.ft().create_index(
         (TextField("title", weight=5.0),
-            TextField("body"))
+         TextField("body"))
     )

     # Indexing a document
@@ -552,8 +552,8 @@ def test_no_index(client):
 def test_partial(client):
     client.ft().create_index(
         (TextField("f1"),
-            TextField("f2"),
-            TextField("f3"))
+         TextField("f2"),
+         TextField("f3"))
     )
     client.ft().add_document("doc1", f1="f1_val", f2="f2_val")
     client.ft().add_document("doc2", f1="f1_val", f2="f2_val")
@@ -574,8 +574,8 @@ def test_partial(client):
 def test_no_create(client):
     client.ft().create_index(
         (TextField("f1"),
-            TextField("f2"),
-            TextField("f3"))
+         TextField("f2"),
+         TextField("f3"))
     )
     client.ft().add_document("doc1", f1="f1_val", f2="f2_val")
     client.ft().add_document("doc2", f1="f1_val", f2="f2_val")
@@ -604,8 +604,8 @@ def test_no_create(client):
 def test_explain(client):
     client.ft().create_index(
         (TextField("f1"),
-            TextField("f2"),
-            TextField("f3"))
+         TextField("f2"),
+         TextField("f3"))
     )
     res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val")
     assert res
@@ -629,8 +629,8 @@ def test_summarize(client):
     doc = sorted(client.ft().search(q).docs)[0]
     assert "<b>Henry</b> IV" == doc.play
     assert (
-            "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
-            == doc.txt
+        "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
+        == doc.txt
     )

     q = Query("king henry").paging(0, 1).summarize().highlight()
@@ -638,8 +638,8 @@ def test_summarize(client):
     doc = sorted(client.ft().search(q).docs)[0]
     assert "<b>Henry</b> ... " == doc.play
     assert (
-            "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
-            == doc.txt
+        "ACT I SCENE I. London. The palace. Enter <b>KING</b> <b>HENRY</b>, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... "  # noqa
+        == doc.txt
     )

@@ -812,10 +812,10 @@ def test_spell_check(client):
     res = client.ft().spellcheck("lorm", include="dict")
     assert len(res["lorm"]) == 3
     assert (
-            res["lorm"][0]["suggestion"],
-            res["lorm"][1]["suggestion"],
-            res["lorm"][2]["suggestion"],
-        ) == ("lorem", "lore", "lorm")
+        res["lorm"][0]["suggestion"],
+        res["lorm"][1]["suggestion"],
+        res["lorm"][2]["suggestion"],
+    ) == ("lorem", "lore", "lorm")
     assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0")

     # test spellcheck exclude
@@ -873,7 +873,7 @@ def test_scorer(client):
     )
     client.ft().add_document(
         "doc2",
-        description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa
+        description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.",  # noqa
     )

     # default scorer is TFIDF
@@ -930,7 +930,7 @@ def test_config(client):


 @pytest.mark.redismod
-def test_aggregations(client):
+def test_aggregations_groupby(client):
     # Creating the index definition and schema
     client.ft().create_index(
         (
@@ -967,36 +967,242 @@ def test_aggregations(client):
     req = aggregations.AggregateRequest("redis").group_by(
         "@parent",
         reducers.count(),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "3"
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.count_distinct("@title"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "3"
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.count_distinctish("@title"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "3"
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.sum("@random_num"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "21"  # 10+8+3
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.min("@random_num"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "3"  # min(10,8,3)
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.max("@random_num"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "10"  # max(10,8,3)
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.avg("@random_num"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "7"  # (10+3+8)/3
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.stddev("random_num"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "3.60555127546"
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.quantile("@random_num", 0.5),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == "10"
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
         reducers.tolist("@title"),
-        reducers.first_value("@title"),
-        reducers.random_sample("@title", 2),
     )

+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[3] == ["RediSearch", "RedisAI", "RedisJson"]
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
+        reducers.first_value("@title").alias("first"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res == ['parent', 'redis', 'first', 'RediSearch']
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
+        reducers.random_sample("@title", 2).alias("random"),
+    )
+
+    res = client.ft().aggregate(req).rows[0]
+    assert res[1] == "redis"
+    assert res[2] == "random"
+    assert len(res[3]) == 2
+    assert res[3][0] in ["RediSearch", "RedisAI", "RedisJson"]
+
+
+@pytest.mark.redismod
+def test_aggregations_sort_by_and_limit(client):
+    client.ft().create_index(
+        (
+            TextField("t1"),
+            TextField("t2"),
+        )
+    )
+
+    client.ft().client.hset("doc1", mapping={'t1': 'a', 't2': 'b'})
+    client.ft().client.hset("doc2", mapping={'t1': 'b', 't2': 'a'})
+
+    # test sort_by using SortDirection
+    req = aggregations.AggregateRequest("*") \
+        .sort_by(aggregations.Asc("@t2"), aggregations.Desc("@t1"))
+    res = client.ft().aggregate(req)
+    assert res.rows[0] == ['t2', 'a', 't1', 'b']
+    assert res.rows[1] == ['t2', 'b', 't1', 'a']
+
+    # test sort_by without SortDirection
+    req = aggregations.AggregateRequest("*") \
+        .sort_by("@t1")
+    res = client.ft().aggregate(req)
+    assert res.rows[0] == ['t1', 'a']
+    assert res.rows[1] == ['t1', 'b']
+
+    # test sort_by with max
+    req = aggregations.AggregateRequest("*") \
+        .sort_by("@t1", max=1)
+    res = client.ft().aggregate(req)
+    assert len(res.rows) == 1
+
+    # test limit
+    req = aggregations.AggregateRequest("*") \
+        .sort_by("@t1").limit(1, 1)
+    res = client.ft().aggregate(req)
+    assert len(res.rows) == 1
+    assert res.rows[0] == ['t1', 'b']
+
+
+@pytest.mark.redismod
+def test_aggregations_load(client):
+    client.ft().create_index(
+        (
+            TextField("t1"),
+            TextField("t2"),
+        )
+    )
+
+    client.ft().client.hset("doc1", mapping={'t1': 'hello', 't2': 'world'})
+
+    # load t1
+    req = aggregations.AggregateRequest("*").load("t1")
+    res = client.ft().aggregate(req)
+    assert res.rows[0] == ['t1', 'hello']
+
+    # load t2
+    req = aggregations.AggregateRequest("*").load("t2")
+    res = client.ft().aggregate(req)
+    assert res.rows[0] == ['t2', 'world']
+
+
+@pytest.mark.redismod
+def test_aggregations_apply(client):
+    client.ft().create_index(
+        (
+            TextField("PrimaryKey", sortable=True),
+            NumericField("CreatedDateTimeUTC", sortable=True),
+        )
+    )
+
+    client.ft().client.hset(
+        "doc1",
+        mapping={
+            'PrimaryKey': '9::362330',
+            'CreatedDateTimeUTC': '637387878524969984'
+        }
+    )
+    client.ft().client.hset(
+        "doc2",
+        mapping={
+            'PrimaryKey': '9::362329',
+            'CreatedDateTimeUTC': '637387875859270016'
+        }
+    )
+
+    req = aggregations.AggregateRequest("*") \
+        .apply(CreatedDateTimeUTC='@CreatedDateTimeUTC * 10')
     res = client.ft().aggregate(req)
+    assert res.rows[0] == ['CreatedDateTimeUTC', '6373878785249699840']
+    assert res.rows[1] == ['CreatedDateTimeUTC', '6373878758592700416']
-    res = res.rows[0]
-    assert len(res) == 26
-    assert "redis" == res[1]
-    assert "3" == res[3]
-    assert "3" == res[5]
-    assert "3" == res[7]
-    assert "21" == res[9]
-    assert "3" == res[11]
-    assert "10" == res[13]
-    assert "7" == res[15]
-    assert "3.60555127546" == res[17]
-    assert "10" == res[19]
-    assert ["RediSearch", "RedisAI", "RedisJson"] == res[21]
-    assert "RediSearch" == res[23]
-    assert 2 == len(res[25])

+
+@pytest.mark.redismod
+def test_aggregations_filter(client):
+    client.ft().create_index(
+        (
+            TextField("name", sortable=True),
+            NumericField("age", sortable=True),
+        )
+    )
+
+    client.ft().client.hset(
+        "doc1",
+        mapping={
+            'name': 'bar',
+            'age': '25'
+        }
+    )
+    client.ft().client.hset(
+        "doc2",
+        mapping={
+            'name': 'foo',
+            'age': '19'
+        }
+    )
+
+    req = aggregations.AggregateRequest("*") \
+        .filter("@name=='foo' && @age < 20")
+    res = client.ft().aggregate(req)
+    assert len(res.rows) == 1
+    assert res.rows[0] == ['name', 'foo', 'age', '19']
+
+    req = aggregations.AggregateRequest("*") \
+        .filter("@age > 15").sort_by("@age")
+    res = client.ft().aggregate(req)
+    assert len(res.rows) == 2
+    assert res.rows[0] == ['age', '19']
+    assert res.rows[1] == ['age', '25']


 @pytest.mark.redismod
@@ -1020,25 +1226,25 @@ def test_index_definition(client):
     )

     assert [
-            "ON",
-            "JSON",
-            "PREFIX",
-            2,
-            "hset:",
-            "henry",
-            "FILTER",
-            "@f1==32",
-            "LANGUAGE_FIELD",
-            "play",
-            "LANGUAGE",
-            "English",
-            "SCORE_FIELD",
-            "chapter",
-            "SCORE",
-            0.5,
-            "PAYLOAD_FIELD",
-            "txt",
-        ] == definition.args
+        "ON",
+        "JSON",
+        "PREFIX",
+        2,
+        "hset:",
+        "henry",
+        "FILTER",
+        "@f1==32",
+        "LANGUAGE_FIELD",
+        "play",
+        "LANGUAGE",
+        "English",
+        "SCORE_FIELD",
+        "chapter",
+        "SCORE",
+        0.5,
+        "PAYLOAD_FIELD",
+        "txt",
+    ] == definition.args

     createIndex(client.ft(), num_docs=500, definition=definition)

@@ -1313,3 +1519,46 @@ def test_json_with_jsonpath(client):
     assert res.docs[0].id == "doc:1"
     with pytest.raises(Exception):
         res.docs[0].name_unsupported
+
+
+@pytest.mark.redismod
+def test_profile(client):
+    client.ft().create_index((TextField('t'),))
+    client.ft().client.hset('1', 't', 'hello')
+    client.ft().client.hset('2', 't', 'world')
+
+    # check using Query
+    q = Query('hello|world').no_content()
+    res, det = client.ft().profile(q)
+    assert det['Iterators profile']['Counter'] == 2.0
+    assert len(det['Iterators profile']['Child iterators']) == 2
+    assert det['Iterators profile']['Type'] == 'UNION'
+    assert det['Parsing time'] < 0.3
+    assert len(res.docs) == 2  # check also the search result
+
+    # check using AggregateRequest
+    req = aggregations.AggregateRequest("*").load("t")\
+        .apply(prefix="startswith(@t, 'hel')")
+    res, det = client.ft().profile(req)
+    assert det['Iterators profile']['Counter'] == 2.0
+    assert det['Iterators profile']['Type'] == 'WILDCARD'
+    assert det['Parsing time'] < 0.3
+    assert len(res.rows) == 2  # check also the search result
+
+
+@pytest.mark.redismod
+def test_profile_limited(client):
+    client.ft().create_index((TextField('t'),))
+    client.ft().client.hset('1', 't', 'hello')
+    client.ft().client.hset('2', 't', 'hell')
+    client.ft().client.hset('3', 't', 'help')
+    client.ft().client.hset('4', 't', 'helowa')
+
+    q = Query('%hell% hel*')
+    res, det = client.ft().profile(q, limited=True)
+    assert det['Iterators profile']['Child iterators'][0]['Child iterators'] \
+        == 'The number of iterators in the union is 3'
+    assert det['Iterators profile']['Child iterators'][1]['Child iterators'] \
+        == 'The number of iterators in the union is 4'
+    assert det['Iterators profile']['Type'] == 'INTERSECT'
+    assert len(res.docs) == 3  # check also the search result