
Trying out ml_mdm

kaakaa

I saw the news that Apple had released an image-generation AI framework, so let's try running it on my Mac Studio (M2 Max).

https://news.yahoo.co.jp/articles/3ab04e6514859bfe914c880f60b9ea0576d839a9

On August 9, Apple published a new image-generation AI model, "Matryoshka Diffusion Models (MDM)", along with the "ml-mdm" framework on GitHub, making it possible to train and run high-quality image-generation AI efficiently.


https://github.com/apple/ml-mdm

First, install it as described in the README.

$ git clone https://github.com/apple/ml-mdm
$ cd ml-mdm
$ pyenv local 3.10.14
$ pip install .

# Not sure what I need yet, so download only the 64x64 assets for now
# → I eventually downloaded the 256x256 and 1024x1024 assets as well
$ export ASSET_PATH=https://docs-assets.developer.apple.com/ml-research/models/mdm
$ curl $ASSET_PATH/flickr64/vis_model.pth --output vis_model_64x64.pth

Run (error: UserWarning: Failed to initialize NumPy: _ARRAY_API not found)

Running it right away fails with an error.

$ pyenv shell 3.10.14
$ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port 8888

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):
  File "/Users/user/.pyenv/versions/3.10.14/bin/torchrun", line 5, in <module>
    from torch.distributed.run import main
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/__init__.py", line 1477, in <module>
    from .functional import *  # noqa: F403
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/functional.py", line 9, in <module>
    import torch.nn.functional as F
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/__init__.py", line 1, in <module>
    from .modules import *  # noqa: F403
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/__init__.py", line 35, in <module>
    from .transformer import TransformerEncoder, TransformerDecoder, \
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/transformer.py", line 20, in <module>
    device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),
/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/transformer.py:20: UserWarning: Failed to initialize NumPy: _ARRAY_API not found (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/tensor_numpy.cpp:84.)
  device: torch.device = torch.device(torch._C._get_default_device()),  # torch.device('cpu'),
[2024-08-11 10:19:33,051] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs.

A module that was compiled using NumPy 1.x cannot be run in
...

NumPy's _ARRAY_API cannot be found at initialization time. The execution log explains the cause in detail.

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

The cause appears to be running modules compiled against NumPy 1.x under NumPy 2.x. Pinning the runtime NumPy to the 1.x series made the error go away, so that's good enough for now.

pyproject.toml
@@ -22,7 +32,7 @@ dependencies = [
     "imageio[ffmpeg]",
     "matplotlib",
     "mlx-data",
-    "numpy",
+    "numpy<2",
     "pytorch-model-summary",
     "rotary-embedding-torch",
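Rather than waiting for an import-time crash, the NumPy 1.x constraint can also be checked explicitly at startup. A minimal sketch (the helper name is mine, not part of ml-mdm):

```python
def numpy1_compatible(version: str) -> bool:
    """Return True if this NumPy version predates the 2.x ABI break.

    Modules compiled against NumPy 1.x may crash under 2.x, so a guard
    like this gives a clearer failure than the _ARRAY_API warning above.
    """
    return int(version.split(".")[0]) < 2


# Example wiring (assumes numpy is importable):
# import numpy as np
# if not numpy1_compatible(np.__version__):
#     raise RuntimeError(f"numpy {np.__version__} installed; pin 'numpy<2'")
```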

Run (error: ModuleNotFoundError: No module named 'ml_mdm.language_models')

On to the next error.

$ torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port 8888
[2024-08-11 10:21:18,085] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs.
Traceback (most recent call last):
  File "/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py", line 19, in <module>
    from ml_mdm import helpers, reader
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/reader.py", line 13, in <module>
    from ml_mdm.language_models.tokenizer import Tokenizer
ModuleNotFoundError: No module named 'ml_mdm.language_models'
[2024-08-11 10:21:23,324] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 22243) of binary: /Users/user/.pyenv/versions/3.10.14/bin/python3.10
Traceback (most recent call last):
  File "/Users/user/.pyenv/versions/3.10.14/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 347, in wrapper
    return f(*args, **kwargs)
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/distributed/run.py", line 812, in main
    run(args)
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/distributed/run.py", line 803, in run
    elastic_launch(
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 135, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 268, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
...

The error says the ml_mdm submodule language_models cannot be found.

ModuleNotFoundError: No module named 'ml_mdm.language_models'

Sure enough, not just language_models but none of ml_mdm's subpackages appear to have been installed.

$ find ~/.pyenv/versions/3.10.14 -name "ml_mdm"
/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm
$ ls /Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm
__init__.py		config.py       distributed.py	        helpers.py      reader.py       samplers.py
__pycache__		diffusion.py	generate_html.py	lr_scaler.py	s3_helpers.py	trainer.py
$ find ~/.pyenv/versions -name "language_models"   
$

Listing the subpackages explicitly in pyproject.toml, running pip install . again, and then rerunning resolved the error.

pyproject.toml
@@ -3,7 +3,17 @@ requires = ["setuptools>=70.0.0"]
 build-backend = "setuptools.build_meta"
 
 [tool.setuptools]
-packages = ["ml_mdm"]
+# packages = ["ml_mdm"]
+packages = [
+  "ml_mdm",
+  "ml_mdm.clis",
+  "ml_mdm.language_models",
+  "ml_mdm.models",
+  "ml_mdm.utils"
+]
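Hand-listing every subpackage works, but setuptools can also discover them automatically; something like the following should be equivalent (a sketch I haven't tried against this repo):

```toml
[tool.setuptools.packages.find]
# Pick up ml_mdm and all of its subpackages (ml_mdm.clis, ml_mdm.models, ...).
include = ["ml_mdm*"]
```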

Run

It finally starts up.

$  torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port 8888
[2024-08-11 10:44:55,385] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs.
Running on local URL:  http://0.0.0.0:8888
[10:44:58] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/httpx/_client.py:1013} INFO - HTTP Request: GET http://localhost:8888/startup-events "HTTP/1.1 200 OK"
[10:44:58] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/httpx/_client.py:1013} INFO - HTTP Request: GET https://checkip.amazonaws.com/ "HTTP/1.1 200 "
[10:44:59] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/httpx/_client.py:1013} INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
[10:45:00] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/httpx/_client.py:1013} INFO - HTTP Request: HEAD http://localhost:8888/ "HTTP/1.1 200 OK"

Enter some prompt into Input Prompt and press the Run button.

The first run takes a while, since it downloads about 10GB of model files.

Execution log
...
Postive: Generate an icon for Apple's gen-image system / Negative: 
[10:45:17] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/wrappers/dataclass_wrapper.py:212} INFO - group.add_argument(*['--smaller_side_size', '--reader_config.smaller_side_size'], **{'required': False, 'dest': 'reader_config.smaller_side_size', 'default': 64, 'help': 'Smaller side is resized to this value', 'type': <class 'int'>})
[10:45:17] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/wrappers/dataclass_wrapper.py:212} INFO - group.add_argument(*['--max_caption_length', '--reader_config.max_caption_length'], **{'required': False, 'dest': 'reader_config.max_caption_length', 'default': 512, 'help': 'Maximum length of captions', 'type': <class 'int'>})
...(snip)...
[10:45:17] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination diffusion_config
[10:45:17] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/helpers.py:13} INFO - python ml_mdm/clis/generate_sample.py \
	 sample_dir=/mnt/data/samples\
	 sample_image_size=64\
	 loglevel=INFO\
	 device=cuda\
	 fp16=0\
	 seed=-1\
	 output_dir=/mnt/data/outputs\
	 vocab_file=data/t5.vocab\
	 pretrained_vision_file=None\
	 categorical_conditioning=0\
	 text_model=google/flan-t5-xl\
	 model=unet\
	 use_precomputed_text_embeddings=0\
	 batch_size=32\
	 num_training_steps=5000\
	 num_epochs=20000\
	 config_path=['configs/models/cc12m_64x64.yaml']\
	 avg_lm_steps=0\
	 dataset_config=configs/datasets/cc12m.yaml\
	 gradient_clip_norm=2\
	 log_freq=50\
	 loss_factor=1\
	 loss_target_type=HA_STYLE\
	 lr=5e-05\
	 metrics=fid,clip\
	 min_examples=10000\
	 model_output_scale=0\
	 name=cc12m_64x64\
	 num_diffusion_steps=1000\
	 num_eval_batches=500\
	 num_gradient_accumulations=1\
	 predict_variances=False\
	 prediction_length=129\
	 prediction_type=V_PREDICTION\
	 reader_config_file=launch_scripts/reader/latest_eval.yaml\
	 reproject_signal=False\
	 save_freq=5000\
	 schedule_type=DEEPFLOYD\
	 test_file_list=validation.tsv\
	 use_adamw=True\
	 use_lm_mask=1\
	 use_vdm_loss_weights=False\
	 vision_model=unet\
	 warmup_steps=10000\
	 reader_config=ReaderConfig(smaller_side_size=64, max_caption_length=512, max_token_length=128, image_size=64, random_crop=False, num_kept_files=-1, num_readers=16, shuffle_buffer_size=500, reader_buffer_size=500, endpoint_url='', bucket='mlx', prepad_caption_with_space=True, use_tokenizer_scores=True, prepad_bos=False, append_eos=True, padding_token='<pad>', pad_to_max_length=False)\
	 unet_config=UNetConfig(num_resnets_per_resolution=[2, 2, 2], temporal_dim=None, attention_levels=[1, 2], num_attention_layers=[0, 1, 5], num_temporal_attention_layers=None, conditioning_feature_dim=-1, conditioning_feature_proj_dim=2048, num_lm_head_layers=0, masked_cross_attention=0, resolution_channels=[256, 512, 768], skip_mid_blocks=False, skip_cond_emb=False, nesting=False, micro_conditioning='scale:64', temporal_mode=False, temporal_spatial_ds=False, temporal_positional_encoding=False, resnet_config=ResNetConfig(num_channels=-1, output_channels=-1, num_groups_norm=32, dropout=0.0, use_attention_ffn=True))\
	 diffusion_config=DiffusionConfig(sampler_config=SamplerConfig(num_diffusion_steps=1000, reproject_signal=False, schedule_type=deepfloyd, prediction_type=v_prediction, loss_target_type=ddpm, beta_start=0.0001, beta_end=0.02, threshold_function=clip, rescale_schedule=1.0, rescale_signal=None, schedule_shifted=False), model_output_scale=0.0, use_vdm_loss_weights=False)
config.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.44k/1.44k [00:00<00:00, 2.27MB/s]
model.safetensors.index.json: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 53.0k/53.0k [00:00<00:00, 16.7MB/s]
model-00001-of-00002.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████| 9.45G/9.45G [02:26<00:00, 64.6MB/s]
model-00002-of-00002.safetensors: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1.95G/1.95G [00:24<00:00, 78.5MB/s]
Downloading shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [02:51<00:00, 85.91s/it]███████████████████████████████████████████████| 1.95G/1.95G [00:24<00:00, 85.5MB/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.94it/s]
generation_config.json: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 147/147 [00:00<00:00, 539kB/s]
[10:48:13] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py:92} INFO - Diffusion config: DiffusionConfig(sampler_config=SamplerConfig(num_diffusion_steps=1000, reproject_signal=False, schedule_type=deepfloyd, prediction_type=v_prediction, loss_target_type=ddpm, beta_start=0.0001, beta_end=0.02, threshold_function=clip, rescale_schedule=1.0, rescale_signal=None, schedule_shifted=False), model_output_scale=0.0, use_vdm_loss_weights=False)
[10:48:13] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py:179} INFO - Step gammas: tensor([1.0000e+00, 9.9996e-01, 9.9991e-01,  ..., 9.7150e-06, 2.4288e-06,
        2.4288e-09])
[10:48:13] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: vis_model_64x64.pth

Because this isn't a CUDA environment, the CPU is used, and sample generation takes a long time. After waiting nearly 40 minutes, an error. (It's logged as a warning, so it may not be a real error. The WebUI showed "Error" when I came back to my desk, so I assumed the run had failed, but it may just have been the browser session expiring and displaying an error.)
(I stepped away partway through, so the log below is from a different run.)

[14:09:32] {/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py:220} INFO - Starting to sample from the model
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [38:46<00:00, 46.52s/it]
[14:48:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/imageio_ffmpeg/_io.py:561} WARNING - IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (266, 266) to (272, 272) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).

Speeding things up (CPU → Apple Silicon)

Waiting tens of minutes for every result is rough, and since I'm on an M2 Max machine anyway, let's see whether it can be sped up.

To use Apple Silicon as the PyTorch device, specifying mps seems to be the way to go.
https://stackoverflow.com/a/72293634

Let's try changing it.

ml_mdm/clis/generate_sample.py
@@ -20,7 +20,7 @@ from ml_mdm import helpers, reader
 from ml_mdm.config import get_arguments, get_model, get_pipeline
 from ml_mdm.language_models import factory
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device("cuda" if torch.cuda.is_available() else "mps")
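One caveat with this one-line change: the fallback is now hard-coded to mps, which would fail on a Mac without Metal support or on a plain CPU box. A more defensive selection might look like this (a sketch; the helper name is mine, not part of ml-mdm):

```python
def pick_device_name(cuda_ok: bool, mps_ok: bool) -> str:
    """Pick a torch device name by priority: CUDA, then MPS, then CPU.

    Call it as:
        pick_device_name(torch.cuda.is_available(),
                         torch.backends.mps.is_available())
    and pass the result to torch.device().
    """
    if cuda_ok:
        return "cuda"
    if mps_ok:
        return "mps"
    return "cpu"
```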

With the change above in place, restart and rerun the generation.

[14:55:23] {/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py:220} INFO - Starting to sample from the model
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [01:24<00:00,  1.69s/it]
[14:56:48] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/imageio_ffmpeg/_io.py:561} WARNING - IMAGEIO FFMPEG_WRITER WARNING: input image is not divisible by macro_block_size=16, resizing from (266, 266) to (272, 272) to ensure video compatibility with most codecs and players. To prevent resizing, make your input image divisible by the macro_block_size or set the macro_block_size to 1 (risking incompatibility).

Generation that took nearly 40 minutes on the CPU now finishes in about a minute and a half (46.52 s/it down to 1.69 s/it, roughly a 27x speedup).
Since Apple published this, it would have been nice if they had handled this out of the box.


Run

Images are now generated in about a minute (albeit at 64x64).

Because the WebUI's "Show diffusion path as a video" setting is enabled, you can also watch the generation process as a video.


256x256 run (error: AssertionError)

Change the "Select the config file" and "Load checkpoint" settings from 64x64 to 256x256 and run again.

Execution log
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['unet_config.inner_config.resnet_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination unet_config.inner_config.resnet_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['unet_config.resnet_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination unet_config.resnet_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['unet_config.inner_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination unet_config.inner_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['diffusion_config.sampler_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination diffusion_config.sampler_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['reader_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination reader_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['unet_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination unet_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:831} INFO - Instantiating the wrapper with destinations ['diffusion_config']
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/simple_parsing/parsing.py:834} INFO - Instantiating the dataclass at destination diffusion_config
[15:33:18] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/helpers.py:13} INFO - python ml_mdm/clis/generate_sample.py \
	 sample_dir=samples\
	 sample_image_size=-1\
	 loglevel=INFO\
	 device=cuda\
	 fp16=0\
	 seed=-1\
	 output_dir=/mnt/data/outputs\
	 vocab_file=data/t5.vocab\
	 pretrained_vision_file=vis_model_256x256.pth\
	 categorical_conditioning=0\
	 text_model=google/flan-t5-xl\
	 model=nested_unet\
	 use_precomputed_text_embeddings=0\
	 batch_size=2\
	 num_training_steps=1000000\
	 num_epochs=20000\
	 config_path=['configs/models/cc12m_256x256.yaml']\
	 avg_lm_steps=0\
	 dataset_config=configs/datasets/cc12m.yaml\
	 gradient_clip_norm=2\
	 log_freq=50\
	 loss_factor=1\
	 loss_target_type=DDPM\
	 lr=5e-05\
	 metrics=fid,clip\
	 min_examples=10000\
	 mixed_ratio=2:1\
	 model_output_scale=0\
	 name=cc12m_256x256\
	 no_use_residual=True\
	 num_diffusion_steps=1000\
	 num_gradient_accumulations=1\
	 prediction_length=129\
	 prediction_type=V_PREDICTION\
	 random_low_noise=True\
	 reproject_signal=False\
	 rescale_signal=1\
	 sample-dir=/mnt/data/samples\
	 sample_image-size=256\
	 save_freq=5000\
	 schedule_shifted=True\
	 schedule_type=DEEPFLOYD\
	 skip_normalization=True\
	 test_file_list=validation.tsv\
	 use_double_loss=True\
	 use_lm_mask=1\
	 use_vdm_loss_weights=False\
	 vision_model=nested_unet\
	 warmup_steps=10000\
	 reader_config=ReaderConfig(smaller_side_size=256, max_caption_length=512, max_token_length=128, image_size=256, random_crop=False, num_kept_files=-1, num_readers=2, shuffle_buffer_size=2000, reader_buffer_size=2000, endpoint_url='', bucket='mlx', prepad_caption_with_space=True, use_tokenizer_scores=True, prepad_bos=False, append_eos=True, padding_token='<pad>', pad_to_max_length=False)\
	 unet_config=NestedUNetConfig(num_resnets_per_resolution=[2, 2, 1], temporal_dim=1024, attention_levels=[], num_attention_layers=[0, 0, 0], num_temporal_attention_layers=None, conditioning_feature_dim=-1, conditioning_feature_proj_dim=-1, num_lm_head_layers=0, masked_cross_attention=1, resolution_channels=[64, 128, 256], skip_mid_blocks=True, skip_cond_emb=True, nesting=False, micro_conditioning='scale:256', temporal_mode=False, temporal_spatial_ds=False, temporal_positional_encoding=False, resnet_config=ResNetConfig(num_channels=-1, output_channels=-1, num_groups_norm=32, dropout=0.0, use_attention_ffn=False), inner_config=UNetConfig(num_resnets_per_resolution=[2, 2, 2], temporal_dim=None, attention_levels=[1, 2], num_attention_layers=[0, 1, 5], num_temporal_attention_layers=None, conditioning_feature_dim=-1, conditioning_feature_proj_dim=2048, num_lm_head_layers=0, masked_cross_attention=0, resolution_channels=[256, 512, 768], skip_mid_blocks=False, skip_cond_emb=False, nesting=True, micro_conditioning='scale:64', temporal_mode=False, temporal_spatial_ds=False, temporal_positional_encoding=False, resnet_config=ResNetConfig(num_channels=-1, output_channels=-1, num_groups_norm=32, dropout=0.0, use_attention_ffn=True)), skip_inner_unet_input=False, skip_normalization=True, initialize_inner_with_pretrained='None', freeze_inner_unet=False, interp_conditioning=False)\
	 diffusion_config=NestedDiffusionConfig(sampler_config=SamplerConfig(num_diffusion_steps=32, reproject_signal=False, schedule_type=ddpm, prediction_type=ddpm, loss_target_type=None, beta_start=0.0001, beta_end=0.02, threshold_function=clip, rescale_schedule=1.0, rescale_signal=None, schedule_shifted=False), model_output_scale=0, use_vdm_loss_weights=True, use_double_loss=False, multi_res_weights=None, no_use_residual=False, use_random_interp=False, mixed_ratio=None, random_downsample=False, average_downsample=False, mid_downsample=False)
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.36it/s]
<-- load pretrained checkpoint error -->
No module named 'distributed'
[15:33:24] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py:285} INFO - Diffusion config: NestedDiffusionConfig(sampler_config=SamplerConfig(num_diffusion_steps=32, reproject_signal=False, schedule_type=ddpm, prediction_type=ddpm, loss_target_type=None, beta_start=0.0001, beta_end=0.02, threshold_function=clip, rescale_schedule=1.0, rescale_signal=None, schedule_shifted=False), model_output_scale=0, use_vdm_loss_weights=True, use_double_loss=False, multi_res_weights=None, no_use_residual=False, use_random_interp=False, mixed_ratio=None, random_downsample=False, average_downsample=False, mid_downsample=False)
[15:33:24] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py:179} INFO - Step gammas: tensor([1.0000, 0.9999, 0.9992, 0.9978, 0.9958, 0.9931, 0.9898, 0.9859, 0.9814,
        0.9762, 0.9705, 0.9642, 0.9573, 0.9498, 0.9418, 0.9332, 0.9241, 0.9146,
        0.9045, 0.8939, 0.8829, 0.8715, 0.8597, 0.8475, 0.8349, 0.8219, 0.8086,
        0.7951, 0.7812, 0.7671, 0.7527, 0.7382, 0.7234])
[15:33:24] {/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: vis_model_256x256.pth
Traceback (most recent call last):
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/route_utils.py", line 288, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1931, in process_api
    result = await self.call_function(
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1528, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 671, in async_iteration
    return await iterator.__anext__()
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 664, in __anext__
    return await anyio.to_thread.run_sync(
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 647, in run_sync_iterator_async
    return next(iterator)
  File "/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 809, in gen_wrapper
    response = next(iterator)
  File "/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py", line 196, in generate
    assert args.sample_image_size != -1
AssertionError

The cause is a typo in the upstream config file...

configs/models/cc12m_256x256.yaml
@@ -4,7 +4,7 @@ dataset_config: configs/datasets/cc12m.yaml
 min_examples: 10000
 sample-dir: /mnt/data/samples
 # batch-size: 32
-sample_image-size: 256
+sample_image_size: 256
 test_file_list: validation.tsv
 #reader-config-file: configs/datasets/reader_config_eval.yaml
 # shared_arguments
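This kind of underscore/hyphen typo is easy to scan for mechanically. A rough check (illustrative only, not part of ml-mdm):

```python
import re


def mixed_style_keys(yaml_text: str) -> list[str]:
    """Return top-level YAML keys that mix '-' and '_', like sample_image-size."""
    keys = re.findall(r"^([A-Za-z0-9_-]+):", yaml_text, flags=re.M)
    return [k for k in keys if "-" in k and "_" in k]
```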

256x256 run (error: AttributeError: 'NestedDiffusionConfig' object has no attribute 'mixed_batch'. Did you mean: 'mixed_ratio'?)

Just as I was starting to distrust the 256x256 config file, the next error.

Execution log
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.69it/s]
<-- load pretrained checkpoint error -->
No module named 'distributed'
[15:45:16] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py:285} INFO - Diffusion config: NestedDiffusionConfig(sampler_config=SamplerConfig(num_diffusion_steps=1000, reproject_signal=False, schedule_type=deepfloyd, prediction_type=v_prediction, loss_target_type=ddpm, beta_start=0.0001, beta_end=0.02, threshold_function=clip, rescale_schedule=1.0, rescale_signal=None, schedule_shifted=False), model_output_scale=0.0, use_vdm_loss_weights=False, use_double_loss=False, multi_res_weights=None, no_use_residual=False, use_random_interp=False, mixed_ratio=None, random_downsample=False, average_downsample=False, mid_downsample=False)
[15:45:16] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py:179} INFO - Step gammas: tensor([1.0000e+00, 9.9996e-01, 9.9991e-01,  ..., 9.7150e-06, 2.4288e-06,
        2.4288e-09])
[15:45:16] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: vis_model_256x256.pth
[15:45:16] {/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py:220} INFO - Starting to sample from the model
  0%|                                                                                                                                                                                                                           | 0/50 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/route_utils.py", line 288, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1931, in process_api
    result = await self.call_function(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1528, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 671, in async_iteration
    return await iterator.__anext__()
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 664, in __anext__
    return await anyio.to_thread.run_sync(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 647, in run_sync_iterator_async
    return next(iterator)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 809, in gen_wrapper
    response = next(iterator)
  File "/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py", line 222, in generate
    for step, result in enumerate(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 525, in _sample
    x0, x_t, extra = self.get_xt_minus_1(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 655, in get_xt_minus_1
    p_t = self.forward_model(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 752, in forward_model
    p_t = model(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py", line 264, in forward
    self.diffusion_config.mixed_batch is None
AttributeError: 'NestedDiffusionConfig' object has no attribute 'mixed_batch'. Did you mean: 'mixed_ratio'?

It's accessing a field (mixed_batch) that doesn't exist in the NestedDiffusionConfig type definition.
https://github.com/apple/ml-mdm/blob/main/ml_mdm/diffusion.py#L262-L266

This logic exists only to reject an unsupported configuration, so comment it out.

mixed_batch is used nowhere else, and configs/models/cc12m_256x256.yaml doesn't define a diffusion_config in the first place, so commenting it out should be harmless.

ml_mdm/diffusion.py
@@ -257,12 +257,11 @@ class NestedModel(Model):
                 torch.cat([p, p.new_zeros(batch_size - p.size(0), *p.size()[1:])], 0)
                 for p in p_t
             ]

         # recompute the noise from pred_low
         if not self.diffusion_config.no_use_residual:
-            assert (
-                self.diffusion_config.mixed_batch is None
-            ), "do not support mixed-batch"
+            # assert (
+            #     self.diffusion_config.mixed_batch is None
+            # ), "do not support mixed-batch"
             x_t, x_t_low = x_t
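A less invasive alternative (hypothetical, not what I did) would be to look the attribute up with `getattr`, which keeps the "unsupported config" guard working even for config classes that never define `mixed_batch`. A stdlib-only sketch, using `SimpleNamespace` as a stand-in for NestedDiffusionConfig:

```python
from types import SimpleNamespace

# Stand-in for NestedDiffusionConfig: it defines mixed_ratio but not
# mixed_batch (names taken from the AttributeError above).
diffusion_config = SimpleNamespace(mixed_ratio=None)

# `diffusion_config.mixed_batch` would raise AttributeError; getattr with a
# default of None keeps the guard without touching the config class.
assert (
    getattr(diffusion_config, "mixed_batch", None) is None
), "do not support mixed-batch"
print("mixed-batch guard passed")
```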
kaakaakaakaa

Running 256x256 (error: RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 3)

An error about mismatched tensor sizes. At this point I have no idea what's going on anymore.

Execution log
[16:05:40] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py:179} INFO - Step gammas: tensor([1.0000, 0.9999, 0.9992, 0.9978, 0.9958, 0.9931, 0.9898, 0.9859, 0.9814,
        0.9762, 0.9705, 0.9642, 0.9573, 0.9498, 0.9418, 0.9332, 0.9241, 0.9146,
        0.9045, 0.8939, 0.8829, 0.8715, 0.8597, 0.8475, 0.8349, 0.8219, 0.8086,
        0.7951, 0.7812, 0.7671, 0.7527, 0.7382, 0.7234])
[16:05:40] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: vis_model_256x256.pth
[16:05:40] {/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py:220} INFO - Starting to sample from the model
  0%|                                                                                                                                                                                                                           | 0/50 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/route_utils.py", line 288, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1931, in process_api
    result = await self.call_function(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1528, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 671, in async_iteration
    return await iterator.__anext__()
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 664, in __anext__
    return await anyio.to_thread.run_sync(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 647, in run_sync_iterator_async
    return next(iterator)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 809, in gen_wrapper
    response = next(iterator)
  File "/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py", line 222, in generate
    for step, result in enumerate(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 525, in _sample
    x0, x_t, extra = self.get_xt_minus_1(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 655, in get_xt_minus_1
    p_t = self.forward_model(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 752, in forward_model
    p_t = model(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py", line 267, in forward
    pred_x0_low, _ = self.sampler.get_x0_eps_from_pred(x_t_low, pred_low, times)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 342, in get_x0_eps_from_pred
    x0 = (x_t - pred * (1 - g).sqrt()) / g.sqrt()
RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 3
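The failing line subtracts the prediction from x_t elementwise, so both tensors must agree in their trailing spatial dimensions; here one is 64x64 and the other 32x32. The same shape rule can be reproduced with NumPy (a stand-in here for the PyTorch tensors in the traceback):

```python
import numpy as np

x_t = np.zeros((1, 3, 64, 64))   # 64x64 noisy latent
pred = np.zeros((1, 3, 32, 32))  # 32x32 low-resolution prediction

# Elementwise ops require equal (or size-1) dimensions; 64 vs 32 on the
# last axes cannot broadcast, mirroring the RuntimeError above.
try:
    _ = x_t - pred
except ValueError as e:
    print("shape mismatch:", e)
```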

More fundamentally, I started to suspect the real problem is that loading the pretrained checkpoint fails slightly earlier in the log above.

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  5.55it/s]
<-- load pretrained checkpoint error -->
No module named 'distributed'

What on earth is this distributed module...?

kaakaakaakaa

Installing the publicly available distributed package and re-running produces the following error instead.

...
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.25it/s]
<-- load pretrained checkpoint error -->
cannot import name 'get_local_rank' from 'distributed' (/Users/user/.pyenv/versions/3.10.14/lib/python3.10/site-packages/distributed/__init__.py)
...

It apparently expects a get_local_rank function to exist.
ml_mdm/distributed.py contains a get_local_rank function, so my guess is that this file was originally importable as a top-level distributed module.
https://github.com/apple/ml-mdm/blob/main/ml_mdm/distributed.py#L67-L68
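A hypothetical workaround that avoids repackaging entirely: register ml_mdm.distributed under the top-level name via sys.modules before the checkpoint loader imports it. A runnable sketch, using the stdlib json module as a stand-in for ml_mdm.distributed:

```python
import sys
import json as stand_in  # stand-in for `import ml_mdm.distributed`

# Alias the module under the top-level name the checkpoint loader expects,
# so a later `import distributed` resolves without a new package.
sys.modules["distributed"] = stand_in

import distributed  # resolves to the aliased module

assert distributed is stand_in
print("alias works:", distributed.__name__)
```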

I created a distributed directory, dropped ml_mdm/distributed.py into it as __init__.py, declared the distributed module in pyproject.toml, and ran pip install . again.

./
├─ distributed
|   └─ __init__.py ...$ cp ml_mdm/distributed.py distributed/__init__.py
|
├─ ml_mdm/ 
|   ├─ ...
|   └─ distributed.py
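The pyproject.toml change means declaring the new top-level package; assuming a setuptools backend, the fragment might look like this (hypothetical; adjust to the repo's actual table names):

```toml
[tool.setuptools]
# Hypothetical: ship the new top-level `distributed` package alongside ml_mdm
packages = ["ml_mdm", "distributed"]
```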

Running it again changed the error message, but it still fails.

Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  2.99it/s]
[16:46:49] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: None
<-- load pretrained checkpoint error -->
[Errno 2] No such file or directory: 'None'

I tried running the 1024x1024 config as well, but no luck.

To create a public link, set `share=True` in `launch()`.
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.32it/s]
[16:49:31] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: 8rwvbg85tt
<-- load pretrained checkpoint error -->
[Errno 2] No such file or directory: '8rwvbg85tt'

The model it tries to load appears to be specified in the config file.

https://github.com/apple/ml-mdm/blob/main/configs/models/cc12m_1024x1024.yaml#L40

I tried changing this to vis_model_256x256.pth, the file the README instructs you to download, but that produced yet another error.

Execution log
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  7.59it/s]
[16:50:55] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: vis_model_256x256.pth
{'inner_unet.up_blocks.1.resnets.0.norm1.weight', ...}
<-- load pretrained checkpoint error -->
Error(s) in loading state_dict for UNet:
	size mismatch for conv_in.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([256, 3, 3, 3]).
    ...
kaakaakaakaa

https://github.com/apple/ml-mdm/blob/main/configs/models/cc12m_256x256.yaml#L39
For initialize_inner_with_pretrained, specifying null instead of None made this error stop occurring.

https://github.com/apple/ml-mdm/blob/main/configs/models/cc12m_1024x1024.yaml#L46
The 1024x1024 config already specified it as null, so I adopted that.
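The reason None misbehaves: in YAML, a bare `None` is resolved as an ordinary string, while `null` (or `~`) parses to Python's None. That is why the loader earlier tried to open a file literally named "None". Demonstrated with PyYAML (assumed available, since the repo's configs are YAML):

```python
import yaml  # PyYAML

# Bare `None` stays a string; only `null`/`~` become Python None.
print(yaml.safe_load("initialize_inner_with_pretrained: None"))
# {'initialize_inner_with_pretrained': 'None'}
print(yaml.safe_load("initialize_inner_with_pretrained: null"))
# {'initialize_inner_with_pretrained': None}
```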

However, the tensor-size mismatch error still isn't resolved.

...
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00,  3.34it/s]
[17:29:44] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py:284} INFO - Diffusion config: NestedDiffusionConfig(sampler_config=SamplerConfig(num_diffusion_steps=32, reproject_signal=False, schedule_type=ddpm, prediction_type=ddpm, loss_target_type=None, beta_start=0.0001, beta_end=0.02, threshold_function=clip, rescale_schedule=1.0, rescale_signal=None, schedule_shifted=False), model_output_scale=0, use_vdm_loss_weights=True, use_double_loss=False, multi_res_weights=None, no_use_residual=False, use_random_interp=False, mixed_ratio=None, random_downsample=False, average_downsample=False, mid_downsample=False)
[17:29:44] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py:179} INFO - Step gammas: tensor([1.0000, 0.9999, 0.9992, 0.9978, 0.9958, 0.9931, 0.9898, 0.9859, 0.9814,
        0.9762, 0.9705, 0.9642, 0.9573, 0.9498, 0.9418, 0.9332, 0.9241, 0.9146,
        0.9045, 0.8939, 0.8829, 0.8715, 0.8597, 0.8475, 0.8349, 0.8219, 0.8086,
        0.7951, 0.7812, 0.7671, 0.7527, 0.7382, 0.7234])
[17:29:44] {/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/models/unet.py:803} INFO - Loading model file: vis_model_256x256.pth
[17:29:45] {/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py:220} INFO - Starting to sample from the model
  0%|                                                                                                                                                                                                                           | 0/50 [00:02<?, ?it/s]
Traceback (most recent call last):
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/route_utils.py", line 288, in call_process_api
    output = await app.get_blocks().process_api(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1931, in process_api
    result = await self.call_function(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/blocks.py", line 1528, in call_function
    prediction = await utils.async_iteration(iterator)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 671, in async_iteration
    return await iterator.__anext__()
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 664, in __anext__
    return await anyio.to_thread.run_sync(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2177, in run_sync_in_worker_thread
    return await future
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 859, in run
    result = context.run(func, *args)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 647, in run_sync_iterator_async
    return next(iterator)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/gradio/utils.py", line 809, in gen_wrapper
    response = next(iterator)
  File "/Volumes/ssd1/go/src/github.com/apple/ml-mdm/ml_mdm/clis/generate_sample.py", line 222, in generate
    for step, result in enumerate(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 525, in _sample
    x0, x_t, extra = self.get_xt_minus_1(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 655, in get_xt_minus_1
    p_t = self.forward_model(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 752, in forward_model
    p_t = model(
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/diffusion.py", line 267, in forward
    pred_x0_low, _ = self.sampler.get_x0_eps_from_pred(x_t_low, pred_low, times)
  File "/Users/kaakaa/.pyenv/versions/3.10.14/lib/python3.10/site-packages/ml_mdm/samplers.py", line 342, in get_x0_eps_from_pred
    x0 = (x_t - pred * (1 - g).sqrt()) / g.sqrt()
RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 3