Snowflake で Python プロファイリング機能を活用してメモリボトルネックを改善する


estie で Staff Engineer をしている @Ryosuke839 です!前回のブログからしばらく間が開きましたが、ふたたび Snowflake ネタです!今回は Snowflake 上で Python を使っていて遭遇した問題と、その解決策についてご紹介します。

Snowflake で Python がメモリを使い果たしてしまう

estie では以前のブログ記事(post post)でもご紹介したように、dbt Python model を Snowflake 上で多用しています。

マルチプロダクト戦略のためにデータパイプラインの数も増加しており、dbt の実行時間も気になるようになってきたため dbt の並列数(threads)の値を上げてみたところ、次のようなエラーが頻発するようになってしまいました。

04:33:15 Database Error in model fuga (models/hoge/fuga.py)
000603 (XX000): SQL execution internal error:
Processing aborted due to error 300005:4227727749; incident 4305017.

Snowflake の Knowledge Base によると、これは warehouse のメモリが不足していると起こるエラーのようです。

調査

メモリ使用量の確認

Snowflake の Query Profile 機能を使い実際のメモリ使用量を確かめてみます。


Statistics の中の Max Python process memory usage を確認すると、このモデルは 1.2 GB のメモリを消費していることがわかりました。このモデルでは数百万行の大量のデータを処理しているのですが、すべてのデータはメモリに乗り切らないため以下のように to_local_iterator() を用いて行ごとに処理するというメモリの使用量に配慮した実装にしたつもりでした。

source_data = dbt.source("source", "foo")
for row in source_data.to_local_iterator():
  processed_data = process(row)
  if more_process(processed_data):
    ...

実装上改善できるところは見つからないためここで一旦手詰まりとなってしまいます。

Snowflake 上の Python プロファイリング

Snowflake の Preview Feature として、Python procedure の実行時間・メモリ使用量のプロファイリング が可能です。dbt Python model も Snowflake 上では一時 stored procedure として実行されるため、この仕組みでメモリ使用量をプロファイリングすることができます。

利用するにはプロファイリングの結果を保存する stage を作成したうえでセッションパラメータ PYTHON_PROFILER_TARGET_STAGE を設定し、Python stored procedure を実行すればよく、dbt Python model では pre_hook を使ってセッションパラメータを設定することができます。

dbt.config(
  python_version="3.11",
  ...,
  pre_hook=[
    "ALTER SESSION SET PYTHON_PROFILER_TARGET_STAGE = foo.bar.profile",
    "ALTER SESSION SET ACTIVE_PYTHON_PROFILER = 'MEMORY'",
  ],
)

実行後に stage を見てみると、プロファイリングの結果を格納した .mprof ファイルが保存されているのがわかります。


File: _udf_code.py
Function: model at line 441

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   441    341.3 MiB    341.3 MiB           1       def model(dbt, session):
   442    341.3 MiB      0.0 MiB           1         source_data = dbt.source("source", "foo")
   443    343.3 MiB      2.0 MiB           1         for row in source_data.to_local_iterator():
   444    874.9 MiB -1634678.5 MiB       87140           processed_data = process(row)
   445    870.9 MiB 49881229.4 MiB       65537           if more_process(processed_data):
   446    870.9 MiB -2484493.8 MiB      131072             ...

プロファイリングの結果をダウンロードし確認すると、to_local_iterator で 500 MB あまりが消費されていることがわかりました。この関数は Snowpark Python API の関数で、ドキュメントには

Unlike collect(), this method does not load all data into memory at once.

と書かれているためメモリを大量に消費するとは思えないものでした。

追記: なお、ライブラリ関数も PYTHON_PROFILER_MODULES パラメータを指定することでプロファイリング対象にもできますが、パフォーマンスが著しく悪化してしまい実行が終わらなくなってしまいました。

ソースコードを読む

さいわい Snowpark Python API は Python で記述されており、ソースコードも公開されているため自身で処理を追うことができます。

コードを追っていくと snowflake.snowpark._internal.server_connection.ServerConnection._to_data_or_iter が呼ばれていることがわかりました。snowpark-python/src/snowflake/snowpark/_internal/server_connection.py at v1.28.0 · snowflakedb/snowpark-python · GitHub

def _to_data_or_iter(
  self,
  results_cursor: SnowflakeCursor,
  to_pandas: bool = False,
  to_iter: bool = False,
  to_arrow: bool = False,
) -> Dict[str, Any]:
  qid = results_cursor.sfqid
  if to_iter:
    new_cursor = results_cursor.connection.cursor()
    new_cursor.get_results_from_sfqid(qid)
    results_cursor = new_cursor
  ...

to_local_iterator などから呼ばれた際には to_iter = True となり、get_results_from_sfqid が呼ばれることがわかります。この関数は Snowflake Connector for Python の関数で、やはり Python で記述されソースコードも公開されているためさらにコードを追っていくことができます。

コードを追っていくと snowflake.connector.result_set.result_set_iterator が呼ばれていることがわかりました。 snowflake-connector-python/src/snowflake/connector/result_set.py at v3.13.2 · snowflakedb/snowflake-connector-python · GitHub

def result_set_iterator(
  ...
) -> Iterator[dict | Exception] | Iterator[tuple | Exception] | Iterator[Table]:
  is_fetch_all = kw.pop("is_fetch_all", False)
  if is_fetch_all:
      ...
  else:
    with ThreadPoolExecutor(prefetch_thread_num) as pool:
      for _ in range(min(prefetch_thread_num, len(unfetched_batches))):
        unconsumed_batches.append(
          pool.submit(unfetched_batches.popleft().create_iter, **kw)
        )

      yield from first_batch_iter

      while unconsumed_batches:
        # Submit the next un-fetched batch to the pool
        if unfetched_batches:
          future = pool.submit(unfetched_batches.popleft().create_iter, **kw)
          unconsumed_batches.append(future)
        future = unconsumed_batches.popleft()
        # this will raise an exception if one has occurred
        batch_iterator = future.result()
        yield from batch_iterator

このコードを読み解くと、最初のバッチを返しつつ prefetch_thread_num スレッドで後続のバッチをフェッチしに行っていることがわかります。このバッチとは、クエリの結果が大きかった際に Snowflake SQL REST API が自動的に分割した結果を指しています。create_iter の中ではバッチをダウンロードしたうえで JSON をパースし dict に変換しているため、ここでバッチの大きさに応じてメモリが消費されてしまいます。

データの流れを図示すると以下のようになります。


この中で、Python Stored Procedure で囲われた範囲にあるデータが Python のメモリを消費しています。すなわち、Snowflake Connector for Python の中で prefetch_thread_num とバッチの大きさそれぞれに比例したメモリが消費されてしまっています。

解決編

パラメータの調整

prefetch_thread_num はセッションパラメータ CLIENT_PREFETCH_THREADS で設定することができ、バッチの大きさはセッションパラメータ RESULT_CHUNK_SIZE で設定することができます。ドキュメントでは

ほとんどのユーザーは、このパラメーターを設定する必要はありません。このパラメーターがユーザーによって設定されていない場合、ドライバーは上記で指定されたデフォルトで起動しますが、使用可能なメモリすべてを使い果たすことを避けるため、積極的に管理します。

と言及されてはいますが、コードを読む限り Python 系のドライバにおいては自動調整はサポートされていないようです。

これらはセッションパラメータなので、dbt ではプロファイリングの設定と同様に pre_hook で設定することができます。

dbt.config(
  python_version="3.11",
  ...,
  pre_hook=[
    "ALTER SESSION SET CLIENT_RESULT_CHUNK_SIZE = 16", # 16 (MB) が最小
    "ALTER SESSION SET CLIENT_PREFETCH_THREADS = 1", # 1 が最小
  ],
)
Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
   441    336.8 MiB    336.8 MiB           1       def model(dbt, session):
   442    336.8 MiB      0.0 MiB           1         source_data = dbt.source("source", "foo")
   443    356.8 MiB     20.0 MiB           1         for row in source_data.to_local_iterator():
   444    579.4 MiB -104320.0 MiB       21814           processed_data = process(row)
   445    579.4 MiB 8796390.5 MiB       16384           if more_process(processed_data):
   446    579.4 MiB -203284.5 MiB       32768             ...

この際 to_local_iteratorstatement_params={"USE_CACHED_RESULT": False}) と指定しないとクエリのキャッシュが効いてしまい CLIENT_RESULT_CHUNK_SIZE の変更が反映されないことがあるので注意が必要になります。

プロファイリング付きで実行を試してみると、to_local_iterator のメモリ使用量が 200 MB にまで抑えられていることが確認できました。


ちなみに、これらのパラメータの値を下げるとデータの読み出しが遅くなる可能性はあるものの、今回のユースケースではデータを読み出したあとの処理にかかる時間が支配的であったため実行時間に有意な差は出ませんでした。

今回の学び

  • プロファイリング機能を使うことでコードのボトルネックを特定することができる
  • 必ずしもドキュメントに全てが書かれているわけではない
    • ソースコードを読むと仕様を確実に知ることができる

最後に

今回はプロファイラを活用しライブラリのソースコードを追うことで問題の解明、解決を行いました。普段からここまで深い調査が必要になるわけではありませんが、膨大なデータを扱うコードを書くにあたりボトルネックは常に意識しておく必要はあります。

このようなことを考えるのが好きな方はぜひバックエンドエンジニア(データ)にご応募ください。そうでない方でも様々なポジションで採用中ですのでカジュアル面談お申し込みフォーム(エンジニア職種)からカジュアルにお話ししましょう。

hrmos.co

hrmos.co

© 2019- estie, inc.