😭

【Flutter】Azure open AIを利用してStreamで結果を取得する【非推奨】

2023/10/26に公開

非推奨っぽいです。
そもそもアプリ側に環境変数でもopen AIのアクセストークンのような機密情報を含めてはいけないらしいです。
理由は、アプリのバイナリデータは逆コンパイルされることがあり、その場合に機密情報が漏洩する危険があるからなのだそう。

マジかよこっわ。アプリって逆コンパイルなんてできるんすか。。。
https://chat.openai.com/share/39dc66ad-42dd-495e-943f-f30730fa0aa5

というわけで、Open AIのアクセストークンはサーバーサイドに置いておき、クライアント側には置かないようにしましょう。

FlutterのOpenAI系のパッケージはいくつかありますが、これ全部環境変数にアクセストークンを設定しているので、セキュリティリスクあるってことですよね。
めっちゃスターついとるのに。。。

まぁ、ダウンサイドリスクはトークン盗まれてAPI使い込まれるくらいなので、ユーザーに迷惑かかるわけでもないし、そのリスクを取れるならパッケージを使うのも全然良いとは思いますが。

自分は一応、Rails API経由でopenAIにアクセスすることにしました。

RailsでストリームリクエストをするためにはActionCable使わなきゃいけないんですが、諸事情でActionCableが使えなくなってるので、とりあえずは普通のHTTPリクエストで作り直そうと思います。

なので以下は自分の無駄な努力です。

せっかく作ったので勿体無いし、あとでstreamを再実装するときに参考にするためにメモ

Flutterから直接Azure open AIにアクセスする方法がなかったので、自前で実装したものを公開します。
こういうの作りました。
AIにストリームリクエストで聞いて、ChatGPTみたく逐次的に回答を出力するやつです。

かなりハマったのは、openAIからstreamで返されるのは全部Stringであり、さらに返されたStringも不完全なJSONフォーマットである場合が多いことでした。
そのため最終的なAIの出力結果で文字が欠落することが多く、意味が通じなくなる場合も多かったのですね。

bufferedChunk はそのための変数です。
不完全なjsonを格納して、次に返される不完全なjsonと組み合わせて完全なjsonフォーマットを構築してからparseするためにあります。

以下、備忘録なのでコードだけ載せます。

ai_search.dart
class DictionaryAISearch extends StatefulWidget {
  const DictionaryAISearch(
      {super.key, required this.keyword, required this.dictionary});
  final String keyword;
  final Dictionary dictionary;

  
  State<DictionaryAISearch> createState() => _DictionaryAISearchState();
}

class _DictionaryAISearchState extends State<DictionaryAISearch> {
  bool _isRequesting = false;
  List<String> collectedChunks = []; // このリストにストリームからのチャンクを保存する。

  Stream<String?> chatGPTStream(
      {required String prompt, required int version}) {
    final StreamController<String?> controller = StreamController<String?>();
    String bufferedChunk = ""; // 不完全なJSONへの対策用

    bool isValidJson(String input) {
      try {
        json.decode(input);
        return true;
      } catch (e) {
        return false;
      }
    }

    Future<void> fetchData() async {
      try {
        final http.StreamedResponse streamedResponse =
            await ChatService.stream(prompt: prompt, version: version);

        if (streamedResponse.statusCode == 200) {
          streamedResponse.stream.transform(utf8.decoder).listen((chunk) {
            bufferedChunk += chunk; // 不完全なJSON対策に、バッファにデータを追加

            // バッファ内のチャンクを取得
            List<String> splitChunks = bufferedChunk
                .split('\n')
                .where((line) => line.trim().startsWith('data: '))
                .toList();

            for (String singleChunk in splitChunks) {
              String modifiedText =
                  singleChunk.replaceFirst("data:", "").trim();

              if (modifiedText == '[DONE]') {
                // END OF STREAM
                continue;
              }

              if (!isValidJson(modifiedText)) {
                // このチャンクが不完全なJSONである場合、このイテレーションをスキップして次のチャンクを待ちます。
                continue;
              }

              Map<String, dynamic> parsedData = json.decode(modifiedText);

              if (parsedData['choices'] != null &&
                  parsedData['choices'].isNotEmpty) {
                String? content = parsedData['choices'][0]['delta']['content'];
                print('content: $content');

                if (content != null) {
                  collectedChunks.add(content);
                }
                controller.add(content);
              }

              // ここでバッファからこのチャンクを削除
              bufferedChunk = bufferedChunk.replaceFirst(singleChunk, "");
            }
          }, onDone: () {
            controller.close();
          }, onError: (error) {
            controller.addError(error);
          });
        } else {
          throw Exception('Failed to load data');
        }
      } catch (e) {
        print("error $e");
        controller.addError(e);
        controller.close();
      }
    }

    fetchData();

    return controller.stream;
  }

  void _performAISearch() {
    setState(() => _isRequesting = true);
  }

  
  Widget build(BuildContext context) {
    if (_isRequesting) {
      return StreamBuilder<String?>(
        stream:
            chatGPTStream(prompt: '${widget.keyword}の意味を教えてください。', version: 3),
        builder: (context, snapshot) {
          // ストリームからの最初のデータが届いたときに、LoadingSpinnerを非表示にする
          if (snapshot.connectionState == ConnectionState.waiting) {
            return const LoadingSpinner();
          }

          if (snapshot.hasError) {
            return Text('Error: ${snapshot.error}');
          }
          // ここで結果を出力する。
          String fullReplyContent = collectedChunks.join();
          return SelectableText(fullReplyContent,
              style: const TextStyle(fontSize: 14, color: Colors.black87));
        },
      );
    } else {
      return ElevatedButton(
        onPressed: _performAISearch,
        child: const Text('AI Search'),
      );
    }
  }
}

Azure open AIにストリーミングリクエストを飛ばすサービスクラス。

chat_service.dart

// ChatGPT
// Azure open AIに対応しているSDKがないので、自前で実装する。
// ref: https://pub.aimind.so/step-by-step-guide-to-integrating-openai-api-in-flutter-f85cb0856a9d
// ref: https://blog.jbs.co.jp/entry/2023/06/21/130111
// ref: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_stream_completions.ipynb
class ChatService {
  // streamで取得する場合
  static Future<http.StreamedResponse> stream(
      {required String prompt, required int version}) async {
        final String? endpoint = dotenv.env['AZURE_OPENAI_endpoint'];
    final Map<String, String> deploymentAndModel =
        ChatService.deploymentAndModel(version: version);
    final String deployment = deploymentAndModel['deployment']!;
    final String model = deploymentAndModel['model']!;
    final Uri chatUri = Uri.parse(
        "https://$endpoint.openai.azure.com/openai/deployments/$deployment/chat/completions?api-version=2023-07-01-preview");

    final Map<String, String> headers = header();
    final ChatRequest request = ChatRequest(
        model: model,
        maxTokens: 500,
        stream: true,
        messages: [Message(role: "system", content: prompt)]);

    final http.Client client = http.Client();
    final http.Request httpRequest = http.Request('POST', chatUri)
      ..headers.addAll(headers)
      ..body = request.toJson();
    final http.StreamedResponse streamedResponse =
        await client.send(httpRequest);
    return streamedResponse;
  }

  static Map<String, String> header() {
    final String? openAIApiKey = dotenv.env['AZURE_OPENAI_ACCESS_TOKEN'];
    final Map<String, String> headers = {
      'Content-Type': 'application/json',
      'api-key': '$openAIApiKey',
    };
    return headers;
  }

  // バージョンごとのデプロイメントとモデルを取得する
  static Map<String, String> deploymentAndModel({required int version}) {
    switch (version) {
      case 3:
        return {
          'deployment': 'diqt-gpt-35-turbo',
          'model': 'gpt-35-turbo-0613'
        };
      case 4:
        return {'deployment': 'diqt-gpt-4', 'model': 'gpt-4-0613'};
      default:
        return {
          'deployment': 'diqt-gpt-35-turbo',
          'model': 'gpt-35-turbo-0613'
        };
    }
  }
}

モデル。

models/chat/request.dart
// ref: https://pub.aimind.so/step-by-step-guide-to-integrating-openai-api-in-flutter-f85cb0856a9d
class ChatRequest {
  final String model;
  final List<Message> messages;
  final int? maxTokens;
  final double? temperature;
  final int? topP;
  final int? n;
  final bool? stream;
  final double? presencePenalty;
  final double? frequencyPenalty;
  final String? stop;

  ChatRequest({
    required this.model,
    required this.messages,
    this.maxTokens,
    this.temperature,
    this.topP,
    this.n,
    this.stream,
    this.presencePenalty,
    this.frequencyPenalty,
    this.stop,
  });

  String toJson() {
    Map<String, dynamic> jsonBody = {
      'model': model,
      'messages': List<Map<String, dynamic>>.from(
          messages.map((message) => message.toJson())),
    };
    if (maxTokens != null) {
      jsonBody.addAll({'max_tokens': maxTokens});
    }

    if (temperature != null) {
      jsonBody.addAll({'temperature': temperature});
    }

    if (topP != null) {
      jsonBody.addAll({'top_p': topP});
    }

    if (n != null) {
      jsonBody.addAll({'n': n});
    }

    if (stream != null) {
      jsonBody.addAll({'stream': stream});
    }

    if (presencePenalty != null) {
      jsonBody.addAll({'presence_penalty': presencePenalty});
    }

    if (frequencyPenalty != null) {
      jsonBody.addAll({'frequency_penalty': frequencyPenalty});
    }

    if (stop != null) {
      jsonBody.addAll({'stop': stop});
    }

    return json.encode(jsonBody);
  }
}

class Message {
  final String? role;
  final String? content;

  Message({this.role, this.content});

  factory Message.fromJson(Map<String, dynamic> json) {
    return Message(
      role: json['role'],
      content: json['content'],
    );
  }

  Map<String, dynamic> toJson() {
    return {
      'role': role,
      'content': content,
    };
  }
}

Discussion