Skip to content

Commit

Permalink
Fix using auto_paging_iter() with expand: [...] (#1434)
Browse files Browse the repository at this point in the history
* deduplicate querystring using a pre-made url

* fix tests
  • Loading branch information
xavdid-stripe authored Dec 17, 2024
1 parent 43d0937 commit ef9d5b0
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 11 deletions.
43 changes: 36 additions & 7 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Unpack,
)
import uuid
from urllib.parse import urlsplit, urlunsplit
from urllib.parse import urlsplit, urlunsplit, parse_qs

# breaking circular dependency
import stripe # noqa: IMP101
Expand Down Expand Up @@ -556,6 +556,35 @@ def _args_for_request_with_retries(
url,
)

params = params or {}
if params and (method == "get" or method == "delete"):
# if we're sending params in the querystring, then we have to make sure we're not
# duplicating anything we got back from the server already (like in a list iterator)
# so, we parse the querystring the server sends back so we can merge with what we (or the user) are trying to send
existing_params = {}
for k, v in parse_qs(urlsplit(url).query).items():
# note: server sends back "expand[]" but users supply "expand", so we strip the brackets from the key name
if k.endswith("[]"):
existing_params[k[:-2]] = v
else:
# all querystrings are pulled out as lists.
# We want to keep the querystrings that actually are lists, but flatten the ones that are single values
existing_params[k] = v[0] if len(v) == 1 else v

# if a user is expanding something that wasn't expanded before, add (and deduplicate) it
# this could theoretically work for other lists that we want to merge too, but that doesn't seem to be a use case
# it never would have worked before, so I think we can start with `expand` and go from there
if "expand" in existing_params and "expand" in params:
params["expand"] = list( # type:ignore - this is a dict
set([*existing_params["expand"], *params["expand"]])
)

params = {
**existing_params,
# user_supplied params take precedence over server params
**params,
}

encoded_params = urlencode(list(_api_encode(params or {}, api_mode)))

# Don't use strict form encoding by changing the square bracket control
Expand Down Expand Up @@ -586,13 +615,13 @@ def _args_for_request_with_retries(

if method == "get" or method == "delete":
if params:
query = encoded_params
scheme, netloc, path, base_query, fragment = urlsplit(abs_url)
# if we're sending query params, we've already merged the incoming ones with the server's "url"
# so we can overwrite the whole thing
scheme, netloc, path, _, fragment = urlsplit(abs_url)

if base_query:
query = "%s&%s" % (base_query, query)

abs_url = urlunsplit((scheme, netloc, path, query, fragment))
abs_url = urlunsplit(
(scheme, netloc, path, encoded_params, fragment)
)
post_data = None
elif method == "post":
if (
Expand Down
53 changes: 53 additions & 0 deletions tests/api_resources/test_list_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest

import stripe
from tests.http_client_mock import HTTPClientMock


class TestListObject(object):
Expand Down Expand Up @@ -439,6 +440,58 @@ def test_forwards_api_key_to_nested_resources(self, http_client_mock):
)
assert lo.data[0].api_key == "sk_test_iter_forwards_options"

def test_iter_with_params(self, http_client_mock: HTTPClientMock):
http_client_mock.stub_request(
"get",
path="/v1/invoices/upcoming/lines",
query_string="customer=cus_123&expand[0]=data.price&limit=1",
rbody=json.dumps(
{
"object": "list",
"data": [
{
"id": "prod_001",
"object": "product",
"price": {"object": "price", "id": "price_123"},
}
],
"url": "/v1/invoices/upcoming/lines?customer=cus_123&expand[]=data.price",
"has_more": True,
}
),
)
# second page
http_client_mock.stub_request(
"get",
path="/v1/invoices/upcoming/lines",
query_string="customer=cus_123&expand[0]=data.price&limit=1&starting_after=prod_001",
rbody=json.dumps(
{
"object": "list",
"data": [
{
"id": "prod_002",
"object": "product",
"price": {"object": "price", "id": "price_123"},
}
],
"url": "/v1/invoices/upcoming/lines?customer=cus_123&expand[]=data.price",
"has_more": False,
}
),
)

lo = stripe.Invoice.upcoming_lines(
api_key="sk_test_invoice_lines",
customer="cus_123",
expand=["data.price"],
limit=1,
)

seen = [item["id"] for item in lo.auto_paging_iter()]

assert seen == ["prod_001", "prod_002"]


class TestAutoPagingAsync:
@staticmethod
Expand Down
9 changes: 5 additions & 4 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,17 @@ def test_ordereddict_encoding(self):

def test_url_construction(self, requestor, http_client_mock):
CASES = (
("%s?foo=bar" % stripe.api_base, "", {"foo": "bar"}),
("%s?foo=bar" % stripe.api_base, "?", {"foo": "bar"}),
(f"{stripe.api_base}?foo=bar", "", {"foo": "bar"}),
(f"{stripe.api_base}?foo=bar", "?", {"foo": "bar"}),
(stripe.api_base, "", {}),
(
"%s/%%20spaced?foo=bar%$&baz=5" % stripe.api_base,
f"{stripe.api_base}/%20spaced?baz=5&foo=bar%24",
"/ spaced?foo=bar$",
{"baz": "5"},
),
# duplicate query params keys should be deduped
(
"%s?foo=bar&foo=bar" % stripe.api_base,
f"{stripe.api_base}?foo=bar",
"?foo=bar",
{"foo": "bar"},
),
Expand Down

0 comments on commit ef9d5b0

Please sign in to comment.